From 2b8d89e30ebfe2272229a1eddd7542d7437c9924 Mon Sep 17 00:00:00 2001 From: Cheng Hao Date: Mon, 28 Jul 2014 10:59:53 -0700 Subject: [PATCH 001/170] [SPARK-2523] [SQL] Hadoop table scan bug fixing In HiveTableScan.scala, ObjectInspector was created for all of the partition based records, which probably causes ClassCastException if the object inspector is not identical among table & partitions. This is the follow up with: https://github.com/apache/spark/pull/1408 https://github.com/apache/spark/pull/1390 I've run a micro benchmark in my local with 15000000 records totally, and got the result as below: With This Patch | Partition-Based Table | Non-Partition-Based Table ------------ | ------------- | ------------- No | 1927 ms | 1885 ms Yes | 1541 ms | 1524 ms It showed this patch will also improve the performance. PS: the benchmark code is also attached. (thanks liancheng ) ``` package org.apache.spark.sql.hive import org.apache.spark.SparkContext import org.apache.spark.SparkConf import org.apache.spark.sql._ object HiveTableScanPrepare extends App { case class Record(key: String, value: String) val sparkContext = new SparkContext( new SparkConf() .setMaster("local") .setAppName(getClass.getSimpleName.stripSuffix("$"))) val hiveContext = new LocalHiveContext(sparkContext) val rdd = sparkContext.parallelize((1 to 3000000).map(i => Record(s"$i", s"val_$i"))) import hiveContext._ hql("SHOW TABLES") hql("DROP TABLE if exists part_scan_test") hql("DROP TABLE if exists scan_test") hql("DROP TABLE if exists records") rdd.registerAsTable("records") hql("""CREATE TABLE part_scan_test (key STRING, value STRING) PARTITIONED BY (part1 string, part2 STRING) | ROW FORMAT SERDE | 'org.apache.hadoop.hive.serde2.columnar.LazyBinaryColumnarSerDe' | STORED AS RCFILE """.stripMargin) hql("""CREATE TABLE scan_test (key STRING, value STRING) | ROW FORMAT SERDE | 'org.apache.hadoop.hive.serde2.columnar.LazyBinaryColumnarSerDe' | STORED AS RCFILE """.stripMargin) for (part1 <- 2000 until 2001) { for (part2 <- 1 to 5) { hql(s"""from records | insert into table part_scan_test PARTITION (part1='$part1', part2='2010-01-$part2') | select key, value """.stripMargin) hql(s"""from records | insert into table scan_test select key, value """.stripMargin) } } } object HiveTableScanTest extends App { val sparkContext = new SparkContext( new SparkConf() .setMaster("local") .setAppName(getClass.getSimpleName.stripSuffix("$"))) val hiveContext = new LocalHiveContext(sparkContext) import hiveContext._ hql("SHOW TABLES") val part_scan_test = hql("select key, value from part_scan_test") val scan_test = hql("select key, value from scan_test") val r_part_scan_test = (0 to 5).map(i => benchmark(part_scan_test)) val r_scan_test = (0 to 5).map(i => benchmark(scan_test)) println("Scanning Partition-Based Table") r_part_scan_test.foreach(printResult) println("Scanning Non-Partition-Based Table") r_scan_test.foreach(printResult) def printResult(result: (Long, Long)) { println(s"Duration: ${result._1} ms Result: ${result._2}") } def benchmark(srdd: SchemaRDD) = { val begin = System.currentTimeMillis() val result = srdd.count() val end = System.currentTimeMillis() ((end - begin), result) } } ``` Author: Cheng Hao Closes #1439 from chenghao-intel/hadoop_table_scan and squashes the following commits: 888968f [Cheng Hao] Fix issues in code style 27540ba [Cheng Hao] Fix the TableScan Bug while partition serde differs 40a24a7 [Cheng Hao] Add Unit Test --- .../apache/spark/sql/hive/TableReader.scala | 113 +++++++++++++----- .../sql/hive/execution/HiveTableScan.scala | 90 ++------------ ...t_serde-0-8caed2a6e80250a6d38a59388679c298 | 2 + .../hive/execution/HiveTableScanSuite.scala | 48 ++++++++ 4 files changed, 138 insertions(+), 115 deletions(-) create mode 100644 sql/hive/src/test/resources/golden/partition_based_table_scan_with_different_serde-0-8caed2a6e80250a6d38a59388679c298 create mode 100644 sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala index c3942578d6b5a..82c88280d7754 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala @@ -24,6 +24,8 @@ import org.apache.hadoop.hive.ql.exec.Utilities import org.apache.hadoop.hive.ql.metadata.{Partition => HivePartition, Table => HiveTable} import org.apache.hadoop.hive.ql.plan.TableDesc import org.apache.hadoop.hive.serde2.Deserializer +import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector + import org.apache.hadoop.io.Writable import org.apache.hadoop.mapred.{FileInputFormat, InputFormat, JobConf} @@ -31,13 +33,16 @@ import org.apache.spark.SerializableWritable import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.{EmptyRDD, HadoopRDD, RDD, UnionRDD} +import org.apache.spark.sql.catalyst.expressions.{Attribute, Row, GenericMutableRow, Literal, Cast} +import org.apache.spark.sql.catalyst.types.DataType + /** * A trait for subclasses that handle table scans. */ private[hive] sealed trait TableReader { - def makeRDDForTable(hiveTable: HiveTable): RDD[_] + def makeRDDForTable(hiveTable: HiveTable): RDD[Row] - def makeRDDForPartitionedTable(partitions: Seq[HivePartition]): RDD[_] + def makeRDDForPartitionedTable(partitions: Seq[HivePartition]): RDD[Row] } @@ -46,7 +51,10 @@ private[hive] sealed trait TableReader { * data warehouse directory. */ private[hive] -class HadoopTableReader(@transient _tableDesc: TableDesc, @transient sc: HiveContext) +class HadoopTableReader( + @transient attributes: Seq[Attribute], + @transient relation: MetastoreRelation, + @transient sc: HiveContext) extends TableReader { // Choose the minimum number of splits. If mapred.map.tasks is set, then use that unless @@ -63,10 +71,10 @@ class HadoopTableReader(@transient _tableDesc: TableDesc, @transient sc: HiveCon def hiveConf = _broadcastedHiveConf.value.value - override def makeRDDForTable(hiveTable: HiveTable): RDD[_] = + override def makeRDDForTable(hiveTable: HiveTable): RDD[Row] = makeRDDForTable( hiveTable, - _tableDesc.getDeserializerClass.asInstanceOf[Class[Deserializer]], + relation.tableDesc.getDeserializerClass.asInstanceOf[Class[Deserializer]], filterOpt = None) /** @@ -81,14 +89,14 @@ class HadoopTableReader(@transient _tableDesc: TableDesc, @transient sc: HiveCon def makeRDDForTable( hiveTable: HiveTable, deserializerClass: Class[_ <: Deserializer], - filterOpt: Option[PathFilter]): RDD[_] = { + filterOpt: Option[PathFilter]): RDD[Row] = { assert(!hiveTable.isPartitioned, """makeRDDForTable() cannot be called on a partitioned table, since input formats may differ across partitions. Use makeRDDForTablePartitions() instead.""") // Create local references to member variables, so that the entire `this` object won't be // serialized in the closure below. - val tableDesc = _tableDesc + val tableDesc = relation.tableDesc val broadcastedHiveConf = _broadcastedHiveConf val tablePath = hiveTable.getPath @@ -99,23 +107,20 @@ class HadoopTableReader(@transient _tableDesc: TableDesc, @transient sc: HiveCon .asInstanceOf[java.lang.Class[InputFormat[Writable, Writable]]] val hadoopRDD = createHadoopRdd(tableDesc, inputPathStr, ifc) + val attrsWithIndex = attributes.zipWithIndex + val mutableRow = new GenericMutableRow(attrsWithIndex.length) val deserializedHadoopRDD = hadoopRDD.mapPartitions { iter => val hconf = broadcastedHiveConf.value.value val deserializer = deserializerClass.newInstance() deserializer.initialize(hconf, tableDesc.getProperties) - // Deserialize each Writable to get the row value. - iter.map { - case v: Writable => deserializer.deserialize(v) - case value => - sys.error(s"Unable to deserialize non-Writable: $value of ${value.getClass.getName}") - } + HadoopTableReader.fillObject(iter, deserializer, attrsWithIndex, mutableRow) } deserializedHadoopRDD } - override def makeRDDForPartitionedTable(partitions: Seq[HivePartition]): RDD[_] = { + override def makeRDDForPartitionedTable(partitions: Seq[HivePartition]): RDD[Row] = { val partitionToDeserializer = partitions.map(part => (part, part.getDeserializer.getClass.asInstanceOf[Class[Deserializer]])).toMap makeRDDForPartitionedTable(partitionToDeserializer, filterOpt = None) @@ -132,9 +137,9 @@ class HadoopTableReader(@transient _tableDesc: TableDesc, @transient sc: HiveCon * subdirectory of each partition being read. If None, then all files are accepted. */ def makeRDDForPartitionedTable( - partitionToDeserializer: Map[HivePartition, Class[_ <: Deserializer]], - filterOpt: Option[PathFilter]): RDD[_] = { - + partitionToDeserializer: Map[HivePartition, + Class[_ <: Deserializer]], + filterOpt: Option[PathFilter]): RDD[Row] = { val hivePartitionRDDs = partitionToDeserializer.map { case (partition, partDeserializer) => val partDesc = Utilities.getPartitionDesc(partition) val partPath = partition.getPartitionPath @@ -156,33 +161,42 @@ class HadoopTableReader(@transient _tableDesc: TableDesc, @transient sc: HiveCon } // Create local references so that the outer object isn't serialized. - val tableDesc = _tableDesc + val tableDesc = relation.tableDesc val broadcastedHiveConf = _broadcastedHiveConf val localDeserializer = partDeserializer + val mutableRow = new GenericMutableRow(attributes.length) + + // split the attributes (output schema) into 2 categories: + // (partition keys, ordinal), (normal attributes, ordinal), the ordinal mean the + // index of the attribute in the output Row. + val (partitionKeys, attrs) = attributes.zipWithIndex.partition(attr => { + relation.partitionKeys.indexOf(attr._1) >= 0 + }) + + def fillPartitionKeys(parts: Array[String], row: GenericMutableRow) = { + partitionKeys.foreach { case (attr, ordinal) => + // get partition key ordinal for a given attribute + val partOridinal = relation.partitionKeys.indexOf(attr) + row(ordinal) = Cast(Literal(parts(partOridinal)), attr.dataType).eval(null) + } + } + // fill the partition key for the given MutableRow Object + fillPartitionKeys(partValues, mutableRow) val hivePartitionRDD = createHadoopRdd(tableDesc, inputPathStr, ifc) hivePartitionRDD.mapPartitions { iter => val hconf = broadcastedHiveConf.value.value - val rowWithPartArr = new Array[Object](2) - - // The update and deserializer initialization are intentionally - // kept out of the below iter.map loop to save performance. - rowWithPartArr.update(1, partValues) val deserializer = localDeserializer.newInstance() deserializer.initialize(hconf, partProps) - // Map each tuple to a row object - iter.map { value => - val deserializedRow = deserializer.deserialize(value) - rowWithPartArr.update(0, deserializedRow) - rowWithPartArr.asInstanceOf[Object] - } + // fill the non partition key attributes + HadoopTableReader.fillObject(iter, deserializer, attrs, mutableRow) } }.toSeq // Even if we don't use any partitions, we still need an empty RDD if (hivePartitionRDDs.size == 0) { - new EmptyRDD[Object](sc.sparkContext) + new EmptyRDD[Row](sc.sparkContext) } else { new UnionRDD(hivePartitionRDDs(0).context, hivePartitionRDDs) } @@ -225,10 +239,9 @@ class HadoopTableReader(@transient _tableDesc: TableDesc, @transient sc: HiveCon // Only take the value (skip the key) because Hive works only with values. rdd.map(_._2) } - } -private[hive] object HadoopTableReader { +private[hive] object HadoopTableReader extends HiveInspectors { /** * Curried. After given an argument for 'path', the resulting JobConf => Unit closure is used to * instantiate a HadoopRDD. @@ -241,4 +254,40 @@ private[hive] object HadoopTableReader { val bufferSize = System.getProperty("spark.buffer.size", "65536") jobConf.set("io.file.buffer.size", bufferSize) } + + /** + * Transform the raw data(Writable object) into the Row object for an iterable input + * @param iter Iterable input which represented as Writable object + * @param deserializer Deserializer associated with the input writable object + * @param attrs Represents the row attribute names and its zero-based position in the MutableRow + * @param row reusable MutableRow object + * + * @return Iterable Row object that transformed from the given iterable input. + */ + def fillObject( + iter: Iterator[Writable], + deserializer: Deserializer, + attrs: Seq[(Attribute, Int)], + row: GenericMutableRow): Iterator[Row] = { + val soi = deserializer.getObjectInspector().asInstanceOf[StructObjectInspector] + // get the field references according to the attributes(output of the reader) required + val fieldRefs = attrs.map { case (attr, idx) => (soi.getStructFieldRef(attr.name), idx) } + + // Map each tuple to a row object + iter.map { value => + val raw = deserializer.deserialize(value) + var idx = 0; + while (idx < fieldRefs.length) { + val fieldRef = fieldRefs(idx)._1 + val fieldIdx = fieldRefs(idx)._2 + val fieldValue = soi.getStructFieldData(raw, fieldRef) + + row(fieldIdx) = unwrapData(fieldValue, fieldRef.getFieldObjectInspector()) + + idx += 1 + } + + row: Row + } + } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala index e7016fa16eea9..8920e2a76a27f 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala @@ -34,7 +34,6 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.types.{BooleanType, DataType} import org.apache.spark.sql.execution._ import org.apache.spark.sql.hive._ -import org.apache.spark.util.MutablePair /** * :: DeveloperApi :: @@ -50,8 +49,7 @@ case class HiveTableScan( relation: MetastoreRelation, partitionPruningPred: Option[Expression])( @transient val context: HiveContext) - extends LeafNode - with HiveInspectors { + extends LeafNode { require(partitionPruningPred.isEmpty || relation.hiveQlTable.isPartitioned, "Partition pruning predicates only supported for partitioned tables.") @@ -67,42 +65,7 @@ case class HiveTableScan( } @transient - private[this] val hadoopReader = new HadoopTableReader(relation.tableDesc, context) - - /** - * The hive object inspector for this table, which can be used to extract values from the - * serialized row representation. - */ - @transient - private[this] lazy val objectInspector = - relation.tableDesc.getDeserializer.getObjectInspector.asInstanceOf[StructObjectInspector] - - /** - * Functions that extract the requested attributes from the hive output. Partitioned values are - * casted from string to its declared data type. - */ - @transient - protected lazy val attributeFunctions: Seq[(Any, Array[String]) => Any] = { - attributes.map { a => - val ordinal = relation.partitionKeys.indexOf(a) - if (ordinal >= 0) { - val dataType = relation.partitionKeys(ordinal).dataType - (_: Any, partitionKeys: Array[String]) => { - castFromString(partitionKeys(ordinal), dataType) - } - } else { - val ref = objectInspector.getAllStructFieldRefs - .find(_.getFieldName == a.name) - .getOrElse(sys.error(s"Can't find attribute $a")) - val fieldObjectInspector = ref.getFieldObjectInspector - - (row: Any, _: Array[String]) => { - val data = objectInspector.getStructFieldData(row, ref) - unwrapData(data, fieldObjectInspector) - } - } - } - } + private[this] val hadoopReader = new HadoopTableReader(attributes, relation, context) private[this] def castFromString(value: String, dataType: DataType) = { Cast(Literal(value), dataType).eval(null) @@ -114,6 +77,7 @@ case class HiveTableScan( val columnInternalNames = neededColumnIDs.map(HiveConf.getColumnInternalName(_)).mkString(",") if (attributes.size == relation.output.size) { + // SQLContext#pruneFilterProject guarantees no duplicated value in `attributes` ColumnProjectionUtils.setFullyReadColumns(hiveConf) } else { ColumnProjectionUtils.appendReadColumnIDs(hiveConf, neededColumnIDs) @@ -140,12 +104,6 @@ case class HiveTableScan( addColumnMetadataToConf(context.hiveconf) - private def inputRdd = if (!relation.hiveQlTable.isPartitioned) { - hadoopReader.makeRDDForTable(relation.hiveQlTable) - } else { - hadoopReader.makeRDDForPartitionedTable(prunePartitions(relation.hiveQlPartitions)) - } - /** * Prunes partitions not involve the query plan. * @@ -169,44 +127,10 @@ case class HiveTableScan( } } - override def execute() = { - inputRdd.mapPartitions { iterator => - if (iterator.isEmpty) { - Iterator.empty - } else { - val mutableRow = new GenericMutableRow(attributes.length) - val mutablePair = new MutablePair[Any, Array[String]]() - val buffered = iterator.buffered - - // NOTE (lian): Critical path of Hive table scan, unnecessary FP style code and pattern - // matching are avoided intentionally. - val rowsAndPartitionKeys = buffered.head match { - // With partition keys - case _: Array[Any] => - buffered.map { case array: Array[Any] => - val deserializedRow = array(0) - val partitionKeys = array(1).asInstanceOf[Array[String]] - mutablePair.update(deserializedRow, partitionKeys) - } - - // Without partition keys - case _ => - val emptyPartitionKeys = Array.empty[String] - buffered.map { deserializedRow => - mutablePair.update(deserializedRow, emptyPartitionKeys) - } - } - - rowsAndPartitionKeys.map { pair => - var i = 0 - while (i < attributes.length) { - mutableRow(i) = attributeFunctions(i)(pair._1, pair._2) - i += 1 - } - mutableRow: Row - } - } - } + override def execute() = if (!relation.hiveQlTable.isPartitioned) { + hadoopReader.makeRDDForTable(relation.hiveQlTable) + } else { + hadoopReader.makeRDDForPartitionedTable(prunePartitions(relation.hiveQlPartitions)) } override def output = attributes diff --git a/sql/hive/src/test/resources/golden/partition_based_table_scan_with_different_serde-0-8caed2a6e80250a6d38a59388679c298 b/sql/hive/src/test/resources/golden/partition_based_table_scan_with_different_serde-0-8caed2a6e80250a6d38a59388679c298 new file mode 100644 index 0000000000000..f369f21e1833f --- /dev/null +++ b/sql/hive/src/test/resources/golden/partition_based_table_scan_with_different_serde-0-8caed2a6e80250a6d38a59388679c298 @@ -0,0 +1,2 @@ +100 100 2010-01-01 +200 200 2010-01-02 diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala new file mode 100644 index 0000000000000..bcb00f871d185 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.execution + +import org.scalatest.{BeforeAndAfterAll, FunSuite} + +import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.sql.hive.test.TestHive + +class HiveTableScanSuite extends HiveComparisonTest { + // MINOR HACK: You must run a query before calling reset the first time. + TestHive.hql("SHOW TABLES") + TestHive.reset() + + TestHive.hql("""CREATE TABLE part_scan_test (key STRING, value STRING) PARTITIONED BY (ds STRING) + | ROW FORMAT SERDE + | 'org.apache.hadoop.hive.serde2.columnar.LazyBinaryColumnarSerDe' + | STORED AS RCFILE + """.stripMargin) + TestHive.hql("""FROM src + | INSERT INTO TABLE part_scan_test PARTITION (ds='2010-01-01') + | SELECT 100,100 LIMIT 1 + """.stripMargin) + TestHive.hql("""ALTER TABLE part_scan_test SET SERDE + | 'org.apache.hadoop.hive.serde2.columnar.ColumnarSerDe' + """.stripMargin) + TestHive.hql("""FROM src INSERT INTO TABLE part_scan_test PARTITION (ds='2010-01-02') + | SELECT 200,200 LIMIT 1 + """.stripMargin) + + createQueryTest("partition_based_table_scan_with_different_serde", + "SELECT * from part_scan_test", false) +} From 255b56f9f530e8594a7e6055ae07690454c66799 Mon Sep 17 00:00:00 2001 From: DB Tsai Date: Mon, 28 Jul 2014 11:34:19 -0700 Subject: [PATCH 002/170] [SPARK-2479][MLlib] Comparing floating-point numbers using relative error in UnitTests Floating point math is not exact, and most floating-point numbers end up being slightly imprecise due to rounding errors. Simple values like 0.1 cannot be precisely represented using binary floating point numbers, and the limited precision of floating point numbers means that slight changes in the order of operations or the precision of intermediates can change the result. That means that comparing two floats to see if they are equal is usually not what we want. As long as this imprecision stays small, it can usually be ignored. Based on discussion in the community, we have implemented two different APIs for relative tolerance, and absolute tolerance. It makes sense that test writers should know which one they need depending on their circumstances. Developers also need to explicitly specify the eps, and there is no default value which will sometimes cause confusion. When comparing against zero using relative tolerance, a exception will be raised to warn users that it's meaningless. For relative tolerance, users can now write assert(23.1 ~== 23.52 relTol 0.02) assert(23.1 ~== 22.74 relTol 0.02) assert(23.1 ~= 23.52 relTol 0.02) assert(23.1 ~= 22.74 relTol 0.02) assert(!(23.1 !~= 23.52 relTol 0.02)) assert(!(23.1 !~= 22.74 relTol 0.02)) // This will throw exception with the following message. // "Did not expect 23.1 and 23.52 to be within 0.02 using relative tolerance." assert(23.1 !~== 23.52 relTol 0.02) // "Expected 23.1 and 22.34 to be within 0.02 using relative tolerance." assert(23.1 ~== 22.34 relTol 0.02) For absolute error, assert(17.8 ~== 17.99 absTol 0.2) assert(17.8 ~== 17.61 absTol 0.2) assert(17.8 ~= 17.99 absTol 0.2) assert(17.8 ~= 17.61 absTol 0.2) assert(!(17.8 !~= 17.99 absTol 0.2)) assert(!(17.8 !~= 17.61 absTol 0.2)) // This will throw exception with the following message. // "Did not expect 17.8 and 17.99 to be within 0.2 using absolute error." assert(17.8 !~== 17.99 absTol 0.2) // "Expected 17.8 and 17.59 to be within 0.2 using absolute error." assert(17.8 ~== 17.59 absTol 0.2) Authors: DB Tsai Marek Kolodziej Author: DB Tsai Closes #1425 from dbtsai/SPARK-2479_comparing_floating_point and squashes the following commits: 8c7cbcc [DB Tsai] Alpine Data Labs --- .../LogisticRegressionSuite.scala | 12 +- .../spark/mllib/clustering/KMeansSuite.scala | 63 +++--- .../evaluation/AreaUnderCurveSuite.scala | 13 +- .../BinaryClassificationMetricsSuite.scala | 40 ++-- .../optimization/GradientDescentSuite.scala | 16 +- .../spark/mllib/optimization/LBFGSSuite.scala | 17 +- .../spark/mllib/optimization/NNLSSuite.scala | 6 +- .../MultivariateOnlineSummarizerSuite.scala | 68 +++---- .../spark/mllib/util/TestingUtils.scala | 151 +++++++++++++-- .../spark/mllib/util/TestingUtilsSuite.scala | 182 ++++++++++++++++++ 10 files changed, 438 insertions(+), 130 deletions(-) create mode 100644 mllib/src/test/scala/org/apache/spark/mllib/util/TestingUtilsSuite.scala diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala index 3f6ff859374c7..da7c633bbd2af 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala @@ -26,6 +26,7 @@ import org.scalatest.Matchers import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression._ import org.apache.spark.mllib.util.{LocalClusterSparkContext, LocalSparkContext} +import org.apache.spark.mllib.util.TestingUtils._ object LogisticRegressionSuite { @@ -81,9 +82,8 @@ class LogisticRegressionSuite extends FunSuite with LocalSparkContext with Match val model = lr.run(testRDD) // Test the weights - val weight0 = model.weights(0) - assert(weight0 >= -1.60 && weight0 <= -1.40, weight0 + " not in [-1.6, -1.4]") - assert(model.intercept >= 1.9 && model.intercept <= 2.1, model.intercept + " not in [1.9, 2.1]") + assert(model.weights(0) ~== -1.52 relTol 0.01) + assert(model.intercept ~== 2.00 relTol 0.01) val validationData = LogisticRegressionSuite.generateLogisticInput(A, B, nPoints, 17) val validationRDD = sc.parallelize(validationData, 2) @@ -113,9 +113,9 @@ class LogisticRegressionSuite extends FunSuite with LocalSparkContext with Match val model = lr.run(testRDD, initialWeights) - val weight0 = model.weights(0) - assert(weight0 >= -1.60 && weight0 <= -1.40, weight0 + " not in [-1.6, -1.4]") - assert(model.intercept >= 1.9 && model.intercept <= 2.1, model.intercept + " not in [1.9, 2.1]") + // Test the weights + assert(model.weights(0) ~== -1.50 relTol 0.01) + assert(model.intercept ~== 1.97 relTol 0.01) val validationData = LogisticRegressionSuite.generateLogisticInput(A, B, nPoints, 17) val validationRDD = sc.parallelize(validationData, 2) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala index 34bc4537a7b3a..afa1f79b95a12 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala @@ -21,8 +21,9 @@ import scala.util.Random import org.scalatest.FunSuite -import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.util.{LocalClusterSparkContext, LocalSparkContext} +import org.apache.spark.mllib.util.TestingUtils._ class KMeansSuite extends FunSuite with LocalSparkContext { @@ -41,26 +42,26 @@ class KMeansSuite extends FunSuite with LocalSparkContext { // centered at the mean of the points var model = KMeans.train(data, k = 1, maxIterations = 1) - assert(model.clusterCenters.head === center) + assert(model.clusterCenters.head ~== center absTol 1E-5) model = KMeans.train(data, k = 1, maxIterations = 2) - assert(model.clusterCenters.head === center) + assert(model.clusterCenters.head ~== center absTol 1E-5) model = KMeans.train(data, k = 1, maxIterations = 5) - assert(model.clusterCenters.head === center) + assert(model.clusterCenters.head ~== center absTol 1E-5) model = KMeans.train(data, k = 1, maxIterations = 1, runs = 5) - assert(model.clusterCenters.head === center) + assert(model.clusterCenters.head ~== center absTol 1E-5) model = KMeans.train(data, k = 1, maxIterations = 1, runs = 5) - assert(model.clusterCenters.head === center) + assert(model.clusterCenters.head ~== center absTol 1E-5) model = KMeans.train(data, k = 1, maxIterations = 1, runs = 1, initializationMode = RANDOM) - assert(model.clusterCenters.head === center) + assert(model.clusterCenters.head ~== center absTol 1E-5) model = KMeans.train( data, k = 1, maxIterations = 1, runs = 1, initializationMode = K_MEANS_PARALLEL) - assert(model.clusterCenters.head === center) + assert(model.clusterCenters.head ~== center absTol 1E-5) } test("no distinct points") { @@ -104,26 +105,26 @@ class KMeansSuite extends FunSuite with LocalSparkContext { var model = KMeans.train(data, k = 1, maxIterations = 1) assert(model.clusterCenters.size === 1) - assert(model.clusterCenters.head === center) + assert(model.clusterCenters.head ~== center absTol 1E-5) model = KMeans.train(data, k = 1, maxIterations = 2) - assert(model.clusterCenters.head === center) + assert(model.clusterCenters.head ~== center absTol 1E-5) model = KMeans.train(data, k = 1, maxIterations = 5) - assert(model.clusterCenters.head === center) + assert(model.clusterCenters.head ~== center absTol 1E-5) model = KMeans.train(data, k = 1, maxIterations = 1, runs = 5) - assert(model.clusterCenters.head === center) + assert(model.clusterCenters.head ~== center absTol 1E-5) model = KMeans.train(data, k = 1, maxIterations = 1, runs = 5) - assert(model.clusterCenters.head === center) + assert(model.clusterCenters.head ~== center absTol 1E-5) model = KMeans.train(data, k = 1, maxIterations = 1, runs = 1, initializationMode = RANDOM) - assert(model.clusterCenters.head === center) + assert(model.clusterCenters.head ~== center absTol 1E-5) model = KMeans.train(data, k = 1, maxIterations = 1, runs = 1, initializationMode = K_MEANS_PARALLEL) - assert(model.clusterCenters.head === center) + assert(model.clusterCenters.head ~== center absTol 1E-5) } test("single cluster with sparse data") { @@ -149,31 +150,39 @@ class KMeansSuite extends FunSuite with LocalSparkContext { val center = Vectors.sparse(n, Seq((0, 1.0), (1, 3.0), (2, 4.0))) var model = KMeans.train(data, k = 1, maxIterations = 1) - assert(model.clusterCenters.head === center) + assert(model.clusterCenters.head ~== center absTol 1E-5) model = KMeans.train(data, k = 1, maxIterations = 2) - assert(model.clusterCenters.head === center) + assert(model.clusterCenters.head ~== center absTol 1E-5) model = KMeans.train(data, k = 1, maxIterations = 5) - assert(model.clusterCenters.head === center) + assert(model.clusterCenters.head ~== center absTol 1E-5) model = KMeans.train(data, k = 1, maxIterations = 1, runs = 5) - assert(model.clusterCenters.head === center) + assert(model.clusterCenters.head ~== center absTol 1E-5) model = KMeans.train(data, k = 1, maxIterations = 1, runs = 5) - assert(model.clusterCenters.head === center) + assert(model.clusterCenters.head ~== center absTol 1E-5) model = KMeans.train(data, k = 1, maxIterations = 1, runs = 1, initializationMode = RANDOM) - assert(model.clusterCenters.head === center) + assert(model.clusterCenters.head ~== center absTol 1E-5) model = KMeans.train(data, k = 1, maxIterations = 1, runs = 1, initializationMode = K_MEANS_PARALLEL) - assert(model.clusterCenters.head === center) + assert(model.clusterCenters.head ~== center absTol 1E-5) data.unpersist() } test("k-means|| initialization") { + + case class VectorWithCompare(x: Vector) extends Ordered[VectorWithCompare] { + @Override def compare(that: VectorWithCompare): Int = { + if(this.x.toArray.foldLeft[Double](0.0)((acc, x) => acc + x * x) > + that.x.toArray.foldLeft[Double](0.0)((acc, x) => acc + x * x)) -1 else 1 + } + } + val points = Seq( Vectors.dense(1.0, 2.0, 6.0), Vectors.dense(1.0, 3.0, 0.0), @@ -188,15 +197,19 @@ class KMeansSuite extends FunSuite with LocalSparkContext { // unselected point as long as it hasn't yet selected all of them var model = KMeans.train(rdd, k = 5, maxIterations = 1) - assert(Set(model.clusterCenters: _*) === Set(points: _*)) + + assert(model.clusterCenters.sortBy(VectorWithCompare(_)) + .zip(points.sortBy(VectorWithCompare(_))).forall(x => x._1 ~== (x._2) absTol 1E-5)) // Iterations of Lloyd's should not change the answer either model = KMeans.train(rdd, k = 5, maxIterations = 10) - assert(Set(model.clusterCenters: _*) === Set(points: _*)) + assert(model.clusterCenters.sortBy(VectorWithCompare(_)) + .zip(points.sortBy(VectorWithCompare(_))).forall(x => x._1 ~== (x._2) absTol 1E-5)) // Neither should more runs model = KMeans.train(rdd, k = 5, maxIterations = 10, runs = 5) - assert(Set(model.clusterCenters: _*) === Set(points: _*)) + assert(model.clusterCenters.sortBy(VectorWithCompare(_)) + .zip(points.sortBy(VectorWithCompare(_))).forall(x => x._1 ~== (x._2) absTol 1E-5)) } test("two clusters") { diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/AreaUnderCurveSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/AreaUnderCurveSuite.scala index 1c9844f289fe0..994e0feb8629e 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/AreaUnderCurveSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/AreaUnderCurveSuite.scala @@ -20,27 +20,28 @@ package org.apache.spark.mllib.evaluation import org.scalatest.FunSuite import org.apache.spark.mllib.util.LocalSparkContext +import org.apache.spark.mllib.util.TestingUtils._ class AreaUnderCurveSuite extends FunSuite with LocalSparkContext { test("auc computation") { val curve = Seq((0.0, 0.0), (1.0, 1.0), (2.0, 3.0), (3.0, 0.0)) val auc = 4.0 - assert(AreaUnderCurve.of(curve) === auc) + assert(AreaUnderCurve.of(curve) ~== auc absTol 1E-5) val rddCurve = sc.parallelize(curve, 2) - assert(AreaUnderCurve.of(rddCurve) == auc) + assert(AreaUnderCurve.of(rddCurve) ~== auc absTol 1E-5) } test("auc of an empty curve") { val curve = Seq.empty[(Double, Double)] - assert(AreaUnderCurve.of(curve) === 0.0) + assert(AreaUnderCurve.of(curve) ~== 0.0 absTol 1E-5) val rddCurve = sc.parallelize(curve, 2) - assert(AreaUnderCurve.of(rddCurve) === 0.0) + assert(AreaUnderCurve.of(rddCurve) ~== 0.0 absTol 1E-5) } test("auc of a curve with a single point") { val curve = Seq((1.0, 1.0)) - assert(AreaUnderCurve.of(curve) === 0.0) + assert(AreaUnderCurve.of(curve) ~== 0.0 absTol 1E-5) val rddCurve = sc.parallelize(curve, 2) - assert(AreaUnderCurve.of(rddCurve) === 0.0) + assert(AreaUnderCurve.of(rddCurve) ~== 0.0 absTol 1E-5) } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetricsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetricsSuite.scala index 94db1dc183230..a733f88b60b80 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetricsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetricsSuite.scala @@ -20,25 +20,14 @@ package org.apache.spark.mllib.evaluation import org.scalatest.FunSuite import org.apache.spark.mllib.util.LocalSparkContext -import org.apache.spark.mllib.util.TestingUtils.DoubleWithAlmostEquals +import org.apache.spark.mllib.util.TestingUtils._ class BinaryClassificationMetricsSuite extends FunSuite with LocalSparkContext { - // TODO: move utility functions to TestingUtils. + def cond1(x: (Double, Double)): Boolean = x._1 ~= (x._2) absTol 1E-5 - def elementsAlmostEqual(actual: Seq[Double], expected: Seq[Double]): Boolean = { - actual.zip(expected).forall { case (x1, x2) => - x1.almostEquals(x2) - } - } - - def elementsAlmostEqual( - actual: Seq[(Double, Double)], - expected: Seq[(Double, Double)])(implicit dummy: DummyImplicit): Boolean = { - actual.zip(expected).forall { case ((x1, y1), (x2, y2)) => - x1.almostEquals(x2) && y1.almostEquals(y2) - } - } + def cond2(x: ((Double, Double), (Double, Double))): Boolean = + (x._1._1 ~= x._2._1 absTol 1E-5) && (x._1._2 ~= x._2._2 absTol 1E-5) test("binary evaluation metrics") { val scoreAndLabels = sc.parallelize( @@ -57,16 +46,17 @@ class BinaryClassificationMetricsSuite extends FunSuite with LocalSparkContext { val rocCurve = Seq((0.0, 0.0)) ++ fpr.zip(recall) ++ Seq((1.0, 1.0)) val pr = recall.zip(precision) val prCurve = Seq((0.0, 1.0)) ++ pr - val f1 = pr.map { case (r, p) => 2.0 * (p * r) / (p + r) } + val f1 = pr.map { case (r, p) => 2.0 * (p * r) / (p + r)} val f2 = pr.map { case (r, p) => 5.0 * (p * r) / (4.0 * p + r)} - assert(elementsAlmostEqual(metrics.thresholds().collect(), threshold)) - assert(elementsAlmostEqual(metrics.roc().collect(), rocCurve)) - assert(metrics.areaUnderROC().almostEquals(AreaUnderCurve.of(rocCurve))) - assert(elementsAlmostEqual(metrics.pr().collect(), prCurve)) - assert(metrics.areaUnderPR().almostEquals(AreaUnderCurve.of(prCurve))) - assert(elementsAlmostEqual(metrics.fMeasureByThreshold().collect(), threshold.zip(f1))) - assert(elementsAlmostEqual(metrics.fMeasureByThreshold(2.0).collect(), threshold.zip(f2))) - assert(elementsAlmostEqual(metrics.precisionByThreshold().collect(), threshold.zip(precision))) - assert(elementsAlmostEqual(metrics.recallByThreshold().collect(), threshold.zip(recall))) + + assert(metrics.thresholds().collect().zip(threshold).forall(cond1)) + assert(metrics.roc().collect().zip(rocCurve).forall(cond2)) + assert(metrics.areaUnderROC() ~== AreaUnderCurve.of(rocCurve) absTol 1E-5) + assert(metrics.pr().collect().zip(prCurve).forall(cond2)) + assert(metrics.areaUnderPR() ~== AreaUnderCurve.of(prCurve) absTol 1E-5) + assert(metrics.fMeasureByThreshold().collect().zip(threshold.zip(f1)).forall(cond2)) + assert(metrics.fMeasureByThreshold(2.0).collect().zip(threshold.zip(f2)).forall(cond2)) + assert(metrics.precisionByThreshold().collect().zip(threshold.zip(precision)).forall(cond2)) + assert(metrics.recallByThreshold().collect().zip(threshold.zip(recall)).forall(cond2)) } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala index dfb2eb7f0d14e..bf040110e228b 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala @@ -25,6 +25,7 @@ import org.scalatest.{FunSuite, Matchers} import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression._ import org.apache.spark.mllib.util.{LocalClusterSparkContext, LocalSparkContext} +import org.apache.spark.mllib.util.TestingUtils._ object GradientDescentSuite { @@ -126,19 +127,14 @@ class GradientDescentSuite extends FunSuite with LocalSparkContext with Matchers val (newWeights1, loss1) = GradientDescent.runMiniBatchSGD( dataRDD, gradient, updater, 1, 1, regParam1, 1.0, initialWeightsWithIntercept) - def compareDouble(x: Double, y: Double, tol: Double = 1E-3): Boolean = { - math.abs(x - y) / (math.abs(y) + 1e-15) < tol - } - - assert(compareDouble( - loss1(0), - loss0(0) + (math.pow(initialWeightsWithIntercept(0), 2) + - math.pow(initialWeightsWithIntercept(1), 2)) / 2), + assert( + loss1(0) ~= (loss0(0) + (math.pow(initialWeightsWithIntercept(0), 2) + + math.pow(initialWeightsWithIntercept(1), 2)) / 2) absTol 1E-5, """For non-zero weights, the regVal should be \frac{1}{2}\sum_i w_i^2.""") assert( - compareDouble(newWeights1(0) , newWeights0(0) - initialWeightsWithIntercept(0)) && - compareDouble(newWeights1(1) , newWeights0(1) - initialWeightsWithIntercept(1)), + (newWeights1(0) ~= (newWeights0(0) - initialWeightsWithIntercept(0)) absTol 1E-5) && + (newWeights1(1) ~= (newWeights0(1) - initialWeightsWithIntercept(1)) absTol 1E-5), "The different between newWeights with/without regularization " + "should be initialWeightsWithIntercept.") } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala index ff414742e8393..5f4c24115ac80 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala @@ -24,6 +24,7 @@ import org.scalatest.{FunSuite, Matchers} import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.util.{LocalClusterSparkContext, LocalSparkContext} +import org.apache.spark.mllib.util.TestingUtils._ class LBFGSSuite extends FunSuite with LocalSparkContext with Matchers { @@ -49,10 +50,6 @@ class LBFGSSuite extends FunSuite with LocalSparkContext with Matchers { lazy val dataRDD = sc.parallelize(data, 2).cache() - def compareDouble(x: Double, y: Double, tol: Double = 1E-3): Boolean = { - math.abs(x - y) / (math.abs(y) + 1e-15) < tol - } - test("LBFGS loss should be decreasing and match the result of Gradient Descent.") { val regParam = 0 @@ -126,15 +123,15 @@ class LBFGSSuite extends FunSuite with LocalSparkContext with Matchers { miniBatchFrac, initialWeightsWithIntercept) - assert(compareDouble(lossGD(0), lossLBFGS(0)), + assert(lossGD(0) ~= lossLBFGS(0) absTol 1E-5, "The first losses of LBFGS and GD should be the same.") // The 2% difference here is based on observation, but is not theoretically guaranteed. - assert(compareDouble(lossGD.last, lossLBFGS.last, 0.02), + assert(lossGD.last ~= lossLBFGS.last relTol 0.02, "The last losses of LBFGS and GD should be within 2% difference.") - assert(compareDouble(weightLBFGS(0), weightGD(0), 0.02) && - compareDouble(weightLBFGS(1), weightGD(1), 0.02), + assert( + (weightLBFGS(0) ~= weightGD(0) relTol 0.02) && (weightLBFGS(1) ~= weightGD(1) relTol 0.02), "The weight differences between LBFGS and GD should be within 2%.") } @@ -226,8 +223,8 @@ class LBFGSSuite extends FunSuite with LocalSparkContext with Matchers { initialWeightsWithIntercept) // for class LBFGS and the optimize method, we only look at the weights - assert(compareDouble(weightLBFGS(0), weightGD(0), 0.02) && - compareDouble(weightLBFGS(1), weightGD(1), 0.02), + assert( + (weightLBFGS(0) ~= weightGD(0) relTol 0.02) && (weightLBFGS(1) ~= weightGD(1) relTol 0.02), "The weight differences between LBFGS and GD should be within 2%.") } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/optimization/NNLSSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/optimization/NNLSSuite.scala index bbf385229081a..b781a6aed9a8c 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/optimization/NNLSSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/optimization/NNLSSuite.scala @@ -21,7 +21,9 @@ import scala.util.Random import org.scalatest.FunSuite -import org.jblas.{DoubleMatrix, SimpleBlas, NativeBlas} +import org.jblas.{DoubleMatrix, SimpleBlas} + +import org.apache.spark.mllib.util.TestingUtils._ class NNLSSuite extends FunSuite { /** Generate an NNLS problem whose optimal solution is the all-ones vector. */ @@ -73,7 +75,7 @@ class NNLSSuite extends FunSuite { val ws = NNLS.createWorkspace(n) val x = NNLS.solve(ata, atb, ws) for (i <- 0 until n) { - assert(Math.abs(x(i) - goodx(i)) < 1e-3) + assert(x(i) ~== goodx(i) absTol 1E-3) assert(x(i) >= 0) } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala index 4b7b019d820b4..db13f142df517 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala @@ -89,15 +89,15 @@ class MultivariateOnlineSummarizerSuite extends FunSuite { .add(Vectors.dense(-1.0, 0.0, 6.0)) .add(Vectors.dense(3.0, -3.0, 0.0)) - assert(summarizer.mean.almostEquals(Vectors.dense(1.0, -1.5, 3.0)), "mean mismatch") + assert(summarizer.mean ~== Vectors.dense(1.0, -1.5, 3.0) absTol 1E-5, "mean mismatch") - assert(summarizer.min.almostEquals(Vectors.dense(-1.0, -3, 0.0)), "min mismatch") + assert(summarizer.min ~== Vectors.dense(-1.0, -3, 0.0) absTol 1E-5, "min mismatch") - assert(summarizer.max.almostEquals(Vectors.dense(3.0, 0.0, 6.0)), "max mismatch") + assert(summarizer.max ~== Vectors.dense(3.0, 0.0, 6.0) absTol 1E-5, "max mismatch") - assert(summarizer.numNonzeros.almostEquals(Vectors.dense(2, 1, 1)), "numNonzeros mismatch") + assert(summarizer.numNonzeros ~== Vectors.dense(2, 1, 1) absTol 1E-5, "numNonzeros mismatch") - assert(summarizer.variance.almostEquals(Vectors.dense(8.0, 4.5, 18.0)), "variance mismatch") + assert(summarizer.variance ~== Vectors.dense(8.0, 4.5, 18.0) absTol 1E-5, "variance mismatch") assert(summarizer.count === 2) } @@ -107,15 +107,15 @@ class MultivariateOnlineSummarizerSuite extends FunSuite { .add(Vectors.sparse(3, Seq((0, -1.0), (2, 6.0)))) .add(Vectors.sparse(3, Seq((0, 3.0), (1, -3.0)))) - assert(summarizer.mean.almostEquals(Vectors.dense(1.0, -1.5, 3.0)), "mean mismatch") + assert(summarizer.mean ~== Vectors.dense(1.0, -1.5, 3.0) absTol 1E-5, "mean mismatch") - assert(summarizer.min.almostEquals(Vectors.dense(-1.0, -3, 0.0)), "min mismatch") + assert(summarizer.min ~== Vectors.dense(-1.0, -3, 0.0) absTol 1E-5, "min mismatch") - assert(summarizer.max.almostEquals(Vectors.dense(3.0, 0.0, 6.0)), "max mismatch") + assert(summarizer.max ~== Vectors.dense(3.0, 0.0, 6.0) absTol 1E-5, "max mismatch") - assert(summarizer.numNonzeros.almostEquals(Vectors.dense(2, 1, 1)), "numNonzeros mismatch") + assert(summarizer.numNonzeros ~== Vectors.dense(2, 1, 1) absTol 1E-5, "numNonzeros mismatch") - assert(summarizer.variance.almostEquals(Vectors.dense(8.0, 4.5, 18.0)), "variance mismatch") + assert(summarizer.variance ~== Vectors.dense(8.0, 4.5, 18.0) absTol 1E-5, "variance mismatch") assert(summarizer.count === 2) } @@ -129,17 +129,17 @@ class MultivariateOnlineSummarizerSuite extends FunSuite { .add(Vectors.dense(1.7, -0.6, 0.0)) .add(Vectors.sparse(3, Seq((1, 1.9), (2, 0.0)))) - assert(summarizer.mean.almostEquals( - Vectors.dense(0.583333333333, -0.416666666666, -0.183333333333)), "mean mismatch") + assert(summarizer.mean ~== + Vectors.dense(0.583333333333, -0.416666666666, -0.183333333333) absTol 1E-5, "mean mismatch") - assert(summarizer.min.almostEquals(Vectors.dense(-2.0, -5.1, -3)), "min mismatch") + assert(summarizer.min ~== Vectors.dense(-2.0, -5.1, -3) absTol 1E-5, "min mismatch") - assert(summarizer.max.almostEquals(Vectors.dense(3.8, 2.3, 1.9)), "max mismatch") + assert(summarizer.max ~== Vectors.dense(3.8, 2.3, 1.9) absTol 1E-5, "max mismatch") - assert(summarizer.numNonzeros.almostEquals(Vectors.dense(3, 5, 2)), "numNonzeros mismatch") + assert(summarizer.numNonzeros ~== Vectors.dense(3, 5, 2) absTol 1E-5, "numNonzeros mismatch") - assert(summarizer.variance.almostEquals( - Vectors.dense(3.857666666666, 7.0456666666666, 2.48166666666666)), "variance mismatch") + assert(summarizer.variance ~== + Vectors.dense(3.857666666666, 7.0456666666666, 2.48166666666666) absTol 1E-5, "variance mismatch") assert(summarizer.count === 6) } @@ -157,17 +157,17 @@ class MultivariateOnlineSummarizerSuite extends FunSuite { val summarizer = summarizer1.merge(summarizer2) - assert(summarizer.mean.almostEquals( - Vectors.dense(0.583333333333, -0.416666666666, -0.183333333333)), "mean mismatch") + assert(summarizer.mean ~== + Vectors.dense(0.583333333333, -0.416666666666, -0.183333333333) absTol 1E-5, "mean mismatch") - assert(summarizer.min.almostEquals(Vectors.dense(-2.0, -5.1, -3)), "min mismatch") + assert(summarizer.min ~== Vectors.dense(-2.0, -5.1, -3) absTol 1E-5, "min mismatch") - assert(summarizer.max.almostEquals(Vectors.dense(3.8, 2.3, 1.9)), "max mismatch") + assert(summarizer.max ~== Vectors.dense(3.8, 2.3, 1.9) absTol 1E-5, "max mismatch") - assert(summarizer.numNonzeros.almostEquals(Vectors.dense(3, 5, 2)), "numNonzeros mismatch") + assert(summarizer.numNonzeros ~== Vectors.dense(3, 5, 2) absTol 1E-5, "numNonzeros mismatch") - assert(summarizer.variance.almostEquals( - Vectors.dense(3.857666666666, 7.0456666666666, 2.48166666666666)), "variance mismatch") + assert(summarizer.variance ~== + Vectors.dense(3.857666666666, 7.0456666666666, 2.48166666666666) absTol 1E-5, "variance mismatch") assert(summarizer.count === 6) } @@ -186,24 +186,24 @@ class MultivariateOnlineSummarizerSuite extends FunSuite { val summarizer3 = (new MultivariateOnlineSummarizer).merge(new MultivariateOnlineSummarizer) assert(summarizer3.count === 0) - assert(summarizer1.mean.almostEquals(Vectors.dense(0.0, -1.0, -3.0)), "mean mismatch") + assert(summarizer1.mean ~== Vectors.dense(0.0, -1.0, -3.0) absTol 1E-5, "mean mismatch") - assert(summarizer2.mean.almostEquals(Vectors.dense(0.0, -1.0, -3.0)), "mean mismatch") + assert(summarizer2.mean ~== Vectors.dense(0.0, -1.0, -3.0) absTol 1E-5, "mean mismatch") - assert(summarizer1.min.almostEquals(Vectors.dense(0.0, -1.0, -3.0)), "min mismatch") + assert(summarizer1.min ~== Vectors.dense(0.0, -1.0, -3.0) absTol 1E-5, "min mismatch") - assert(summarizer2.min.almostEquals(Vectors.dense(0.0, -1.0, -3.0)), "min mismatch") + assert(summarizer2.min ~== Vectors.dense(0.0, -1.0, -3.0) absTol 1E-5, "min mismatch") - assert(summarizer1.max.almostEquals(Vectors.dense(0.0, -1.0, -3.0)), "max mismatch") + assert(summarizer1.max ~== Vectors.dense(0.0, -1.0, -3.0) absTol 1E-5, "max mismatch") - assert(summarizer2.max.almostEquals(Vectors.dense(0.0, -1.0, -3.0)), "max mismatch") + assert(summarizer2.max ~== Vectors.dense(0.0, -1.0, -3.0) absTol 1E-5, "max mismatch") - assert(summarizer1.numNonzeros.almostEquals(Vectors.dense(0, 1, 1)), "numNonzeros mismatch") + assert(summarizer1.numNonzeros ~== Vectors.dense(0, 1, 1) absTol 1E-5, "numNonzeros mismatch") - assert(summarizer2.numNonzeros.almostEquals(Vectors.dense(0, 1, 1)), "numNonzeros mismatch") + assert(summarizer2.numNonzeros ~== Vectors.dense(0, 1, 1) absTol 1E-5, "numNonzeros mismatch") - assert(summarizer1.variance.almostEquals(Vectors.dense(0, 0, 0)), "variance mismatch") + assert(summarizer1.variance ~== Vectors.dense(0, 0, 0) absTol 1E-5, "variance mismatch") - assert(summarizer2.variance.almostEquals(Vectors.dense(0, 0, 0)), "variance mismatch") + assert(summarizer2.variance ~== Vectors.dense(0, 0, 0) absTol 1E-5, "variance mismatch") } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/TestingUtils.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/TestingUtils.scala index 64b1ba7527183..29cc42d8cbea7 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/util/TestingUtils.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/util/TestingUtils.scala @@ -18,28 +18,155 @@ package org.apache.spark.mllib.util import org.apache.spark.mllib.linalg.Vector +import org.scalatest.exceptions.TestFailedException object TestingUtils { + val ABS_TOL_MSG = " using absolute tolerance" + val REL_TOL_MSG = " using relative tolerance" + + /** + * Private helper function for comparing two values using relative tolerance. + * Note that if x or y is extremely close to zero, i.e., smaller than Double.MinPositiveValue, + * the relative tolerance is meaningless, so the exception will be raised to warn users. + */ + private def RelativeErrorComparison(x: Double, y: Double, eps: Double): Boolean = { + val absX = math.abs(x) + val absY = math.abs(y) + val diff = math.abs(x - y) + if (x == y) { + true + } else if (absX < Double.MinPositiveValue || absY < Double.MinPositiveValue) { + throw new TestFailedException( + s"$x or $y is extremely close to zero, so the relative tolerance is meaningless.", 0) + } else { + diff < eps * math.min(absX, absY) + } + } + + /** + * Private helper function for comparing two values using absolute tolerance. + */ + private def AbsoluteErrorComparison(x: Double, y: Double, eps: Double): Boolean = { + math.abs(x - y) < eps + } + + case class CompareDoubleRightSide( + fun: (Double, Double, Double) => Boolean, y: Double, eps: Double, method: String) + + /** + * Implicit class for comparing two double values using relative tolerance or absolute tolerance. + */ implicit class DoubleWithAlmostEquals(val x: Double) { - // An improved version of AlmostEquals would always divide by the larger number. - // This will avoid the problem of diving by zero. - def almostEquals(y: Double, epsilon: Double = 1E-10): Boolean = { - if(x == y) { - true - } else if(math.abs(x) > math.abs(y)) { - math.abs(x - y) / math.abs(x) < epsilon - } else { - math.abs(x - y) / math.abs(y) < epsilon + + /** + * When the difference of two values are within eps, returns true; otherwise, returns false. + */ + def ~=(r: CompareDoubleRightSide): Boolean = r.fun(x, r.y, r.eps) + + /** + * When the difference of two values are within eps, returns false; otherwise, returns true. + */ + def !~=(r: CompareDoubleRightSide): Boolean = !r.fun(x, r.y, r.eps) + + /** + * Throws exception when the difference of two values are NOT within eps; + * otherwise, returns true. + */ + def ~==(r: CompareDoubleRightSide): Boolean = { + if (!r.fun(x, r.y, r.eps)) { + throw new TestFailedException( + s"Expected $x and ${r.y} to be within ${r.eps}${r.method}.", 0) } + true } + + /** + * Throws exception when the difference of two values are within eps; otherwise, returns true. + */ + def !~==(r: CompareDoubleRightSide): Boolean = { + if (r.fun(x, r.y, r.eps)) { + throw new TestFailedException( + s"Did not expect $x and ${r.y} to be within ${r.eps}${r.method}.", 0) + } + true + } + + /** + * Comparison using absolute tolerance. + */ + def absTol(eps: Double): CompareDoubleRightSide = CompareDoubleRightSide(AbsoluteErrorComparison, + x, eps, ABS_TOL_MSG) + + /** + * Comparison using relative tolerance. + */ + def relTol(eps: Double): CompareDoubleRightSide = CompareDoubleRightSide(RelativeErrorComparison, + x, eps, REL_TOL_MSG) + + override def toString = x.toString } + case class CompareVectorRightSide( + fun: (Vector, Vector, Double) => Boolean, y: Vector, eps: Double, method: String) + + /** + * Implicit class for comparing two vectors using relative tolerance or absolute tolerance. + */ implicit class VectorWithAlmostEquals(val x: Vector) { - def almostEquals(y: Vector, epsilon: Double = 1E-10): Boolean = { - x.toArray.corresponds(y.toArray) { - _.almostEquals(_, epsilon) + + /** + * When the difference of two vectors are within eps, returns true; otherwise, returns false. + */ + def ~=(r: CompareVectorRightSide): Boolean = r.fun(x, r.y, r.eps) + + /** + * When the difference of two vectors are within eps, returns false; otherwise, returns true. + */ + def !~=(r: CompareVectorRightSide): Boolean = !r.fun(x, r.y, r.eps) + + /** + * Throws exception when the difference of two vectors are NOT within eps; + * otherwise, returns true. + */ + def ~==(r: CompareVectorRightSide): Boolean = { + if (!r.fun(x, r.y, r.eps)) { + throw new TestFailedException( + s"Expected $x and ${r.y} to be within ${r.eps}${r.method} for all elements.", 0) } + true } + + /** + * Throws exception when the difference of two vectors are within eps; otherwise, returns true. + */ + def !~==(r: CompareVectorRightSide): Boolean = { + if (r.fun(x, r.y, r.eps)) { + throw new TestFailedException( + s"Did not expect $x and ${r.y} to be within ${r.eps}${r.method} for all elements.", 0) + } + true + } + + /** + * Comparison using absolute tolerance. + */ + def absTol(eps: Double): CompareVectorRightSide = CompareVectorRightSide( + (x: Vector, y: Vector, eps: Double) => { + x.toArray.zip(y.toArray).forall(x => x._1 ~= x._2 absTol eps) + }, x, eps, ABS_TOL_MSG) + + /** + * Comparison using relative tolerance. Note that comparing against sparse vector + * with elements having value of zero will raise exception because it involves with + * comparing against zero. + */ + def relTol(eps: Double): CompareVectorRightSide = CompareVectorRightSide( + (x: Vector, y: Vector, eps: Double) => { + x.toArray.zip(y.toArray).forall(x => x._1 ~= x._2 relTol eps) + }, x, eps, REL_TOL_MSG) + + override def toString = x.toString } + } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/TestingUtilsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/TestingUtilsSuite.scala new file mode 100644 index 0000000000000..b0ecb33c28483 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/util/TestingUtilsSuite.scala @@ -0,0 +1,182 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.util + +import org.apache.spark.mllib.linalg.Vectors +import org.scalatest.FunSuite +import org.apache.spark.mllib.util.TestingUtils._ +import org.scalatest.exceptions.TestFailedException + +class TestingUtilsSuite extends FunSuite { + + test("Comparing doubles using relative error.") { + + assert(23.1 ~== 23.52 relTol 0.02) + assert(23.1 ~== 22.74 relTol 0.02) + assert(23.1 ~= 23.52 relTol 0.02) + assert(23.1 ~= 22.74 relTol 0.02) + assert(!(23.1 !~= 23.52 relTol 0.02)) + assert(!(23.1 !~= 22.74 relTol 0.02)) + + // Should throw exception with message when test fails. + intercept[TestFailedException](23.1 !~== 23.52 relTol 0.02) + intercept[TestFailedException](23.1 !~== 22.74 relTol 0.02) + intercept[TestFailedException](23.1 ~== 23.63 relTol 0.02) + intercept[TestFailedException](23.1 ~== 22.34 relTol 0.02) + + assert(23.1 !~== 23.63 relTol 0.02) + assert(23.1 !~== 22.34 relTol 0.02) + assert(23.1 !~= 23.63 relTol 0.02) + assert(23.1 !~= 22.34 relTol 0.02) + assert(!(23.1 ~= 23.63 relTol 0.02)) + assert(!(23.1 ~= 22.34 relTol 0.02)) + + // Comparing against zero should fail the test and throw exception with message + // saying that the relative error is meaningless in this situation. + intercept[TestFailedException](0.1 ~== 0.0 relTol 0.032) + intercept[TestFailedException](0.1 ~= 0.0 relTol 0.032) + intercept[TestFailedException](0.1 !~== 0.0 relTol 0.032) + intercept[TestFailedException](0.1 !~= 0.0 relTol 0.032) + intercept[TestFailedException](0.0 ~== 0.1 relTol 0.032) + intercept[TestFailedException](0.0 ~= 0.1 relTol 0.032) + intercept[TestFailedException](0.0 !~== 0.1 relTol 0.032) + intercept[TestFailedException](0.0 !~= 0.1 relTol 0.032) + + // Comparisons of numbers very close to zero. + assert(10 * Double.MinPositiveValue ~== 9.5 * Double.MinPositiveValue relTol 0.01) + assert(10 * Double.MinPositiveValue !~== 11 * Double.MinPositiveValue relTol 0.01) + + assert(-Double.MinPositiveValue ~== 1.18 * -Double.MinPositiveValue relTol 0.012) + assert(-Double.MinPositiveValue ~== 1.38 * -Double.MinPositiveValue relTol 0.012) + } + + test("Comparing doubles using absolute error.") { + + assert(17.8 ~== 17.99 absTol 0.2) + assert(17.8 ~== 17.61 absTol 0.2) + assert(17.8 ~= 17.99 absTol 0.2) + assert(17.8 ~= 17.61 absTol 0.2) + assert(!(17.8 !~= 17.99 absTol 0.2)) + assert(!(17.8 !~= 17.61 absTol 0.2)) + + // Should throw exception with message when test fails. + intercept[TestFailedException](17.8 !~== 17.99 absTol 0.2) + intercept[TestFailedException](17.8 !~== 17.61 absTol 0.2) + intercept[TestFailedException](17.8 ~== 18.01 absTol 0.2) + intercept[TestFailedException](17.8 ~== 17.59 absTol 0.2) + + assert(17.8 !~== 18.01 absTol 0.2) + assert(17.8 !~== 17.59 absTol 0.2) + assert(17.8 !~= 18.01 absTol 0.2) + assert(17.8 !~= 17.59 absTol 0.2) + assert(!(17.8 ~= 18.01 absTol 0.2)) + assert(!(17.8 ~= 17.59 absTol 0.2)) + + // Comparisons of numbers very close to zero, and both side of zeros + assert(Double.MinPositiveValue ~== 4 * Double.MinPositiveValue absTol 5 * Double.MinPositiveValue) + assert(Double.MinPositiveValue !~== 6 * Double.MinPositiveValue absTol 5 * Double.MinPositiveValue) + + assert(-Double.MinPositiveValue ~== 3 * Double.MinPositiveValue absTol 5 * Double.MinPositiveValue) + assert(Double.MinPositiveValue !~== -4 * Double.MinPositiveValue absTol 5 * Double.MinPositiveValue) + } + + test("Comparing vectors using relative error.") { + + //Comparisons of two dense vectors + assert(Vectors.dense(Array(3.1, 3.5)) ~== Vectors.dense(Array(3.130, 3.534)) relTol 0.01) + assert(Vectors.dense(Array(3.1, 3.5)) !~== Vectors.dense(Array(3.135, 3.534)) relTol 0.01) + assert(Vectors.dense(Array(3.1, 3.5)) ~= Vectors.dense(Array(3.130, 3.534)) relTol 0.01) + assert(Vectors.dense(Array(3.1, 3.5)) !~= Vectors.dense(Array(3.135, 3.534)) relTol 0.01) + assert(!(Vectors.dense(Array(3.1, 3.5)) !~= Vectors.dense(Array(3.130, 3.534)) relTol 0.01)) + assert(!(Vectors.dense(Array(3.1, 3.5)) ~= Vectors.dense(Array(3.135, 3.534)) relTol 0.01)) + + // Should throw exception with message when test fails. + intercept[TestFailedException]( + Vectors.dense(Array(3.1, 3.5)) !~== Vectors.dense(Array(3.130, 3.534)) relTol 0.01) + + intercept[TestFailedException]( + Vectors.dense(Array(3.1, 3.5)) ~== Vectors.dense(Array(3.135, 3.534)) relTol 0.01) + + // Comparing against zero should fail the test and throw exception with message + // saying that the relative error is meaningless in this situation. + intercept[TestFailedException]( + Vectors.dense(Array(3.1, 0.01)) ~== Vectors.dense(Array(3.13, 0.0)) relTol 0.01) + + intercept[TestFailedException]( + Vectors.dense(Array(3.1, 0.01)) ~== Vectors.sparse(2, Array(0), Array(3.13)) relTol 0.01) + + // Comparisons of two sparse vectors + assert(Vectors.dense(Array(3.1, 3.5)) ~== + Vectors.sparse(2, Array(0, 1), Array(3.130, 3.534)) relTol 0.01) + + assert(Vectors.dense(Array(3.1, 3.5)) !~== + Vectors.sparse(2, Array(0, 1), Array(3.135, 3.534)) relTol 0.01) + } + + test("Comparing vectors using absolute error.") { + + //Comparisons of two dense vectors + assert(Vectors.dense(Array(3.1, 3.5, 0.0)) ~== + Vectors.dense(Array(3.1 + 1E-8, 3.5 + 2E-7, 1E-8)) absTol 1E-6) + + assert(Vectors.dense(Array(3.1, 3.5, 0.0)) !~== + Vectors.dense(Array(3.1 + 1E-5, 3.5 + 2E-7, 1 + 1E-3)) absTol 1E-6) + + assert(Vectors.dense(Array(3.1, 3.5, 0.0)) ~= + Vectors.dense(Array(3.1 + 1E-8, 3.5 + 2E-7, 1E-8)) absTol 1E-6) + + assert(Vectors.dense(Array(3.1, 3.5, 0.0)) !~= + Vectors.dense(Array(3.1 + 1E-5, 3.5 + 2E-7, 1 + 1E-3)) absTol 1E-6) + + assert(!(Vectors.dense(Array(3.1, 3.5, 0.0)) !~= + Vectors.dense(Array(3.1 + 1E-8, 3.5 + 2E-7, 1E-8)) absTol 1E-6)) + + assert(!(Vectors.dense(Array(3.1, 3.5, 0.0)) ~= + Vectors.dense(Array(3.1 + 1E-5, 3.5 + 2E-7, 1 + 1E-3)) absTol 1E-6)) + + // Should throw exception with message when test fails. + intercept[TestFailedException](Vectors.dense(Array(3.1, 3.5, 0.0)) !~== + Vectors.dense(Array(3.1 + 1E-8, 3.5 + 2E-7, 1E-8)) absTol 1E-6) + + intercept[TestFailedException](Vectors.dense(Array(3.1, 3.5, 0.0)) ~== + Vectors.dense(Array(3.1 + 1E-5, 3.5 + 2E-7, 1 + 1E-3)) absTol 1E-6) + + // Comparisons of two sparse vectors + assert(Vectors.sparse(3, Array(0, 2), Array(3.1, 2.4)) ~== + Vectors.sparse(3, Array(0, 2), Array(3.1 + 1E-8, 2.4 + 1E-7)) absTol 1E-6) + + assert(Vectors.sparse(3, Array(0, 2), Array(3.1 + 1E-8, 2.4 + 1E-7)) ~== + Vectors.sparse(3, Array(0, 2), Array(3.1, 2.4)) absTol 1E-6) + + assert(Vectors.sparse(3, Array(0, 2), Array(3.1, 2.4)) !~== + Vectors.sparse(3, Array(0, 2), Array(3.1 + 1E-3, 2.4)) absTol 1E-6) + + assert(Vectors.sparse(3, Array(0, 2), Array(3.1 + 1E-3, 2.4)) !~== + Vectors.sparse(3, Array(0, 2), Array(3.1, 2.4)) absTol 1E-6) + + // Comparisons of a dense vector and a sparse vector + assert(Vectors.sparse(3, Array(0, 2), Array(3.1, 2.4)) ~== + Vectors.dense(Array(3.1 + 1E-8, 0, 2.4 + 1E-7)) absTol 1E-6) + + assert(Vectors.dense(Array(3.1 + 1E-8, 0, 2.4 + 1E-7)) ~== + Vectors.sparse(3, Array(0, 2), Array(3.1, 2.4)) absTol 1E-6) + + assert(Vectors.sparse(3, Array(0, 2), Array(3.1, 2.4)) !~== + Vectors.dense(Array(3.1, 1E-3, 2.4)) absTol 1E-6) + } +} From a7a9d14479ea6421513a962ff0f45cb969368bab Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Mon, 28 Jul 2014 12:07:30 -0700 Subject: [PATCH 003/170] [SPARK-2410][SQL] Merging Hive Thrift/JDBC server (with Maven profile fix) JIRA issue: [SPARK-2410](https://issues.apache.org/jira/browse/SPARK-2410) Another try for #1399 & #1600. Those two PR breaks Jenkins builds because we made a separate profile `hive-thriftserver` in sub-project `assembly`, but the `hive-thriftserver` module is defined outside the `hive-thriftserver` profile. Thus every time a pull request that doesn't touch SQL code will also execute test suites defined in `hive-thriftserver`, but tests fail because related .class files are not included in the assembly jar. In the most recent commit, module `hive-thriftserver` is moved into its own profile to fix this problem. All previous commits are squashed for clarity. Author: Cheng Lian Closes #1620 from liancheng/jdbc-with-maven-fix and squashes the following commits: 629988e [Cheng Lian] Moved hive-thriftserver module definition into its own profile ec3c7a7 [Cheng Lian] Cherry picked the Hive Thrift server --- .gitignore | 1 + assembly/pom.xml | 10 + bagel/pom.xml | 2 +- bin/beeline | 45 +++ bin/compute-classpath.sh | 1 + bin/spark-shell | 4 +- bin/spark-shell.cmd | 2 +- bin/spark-sql | 36 ++ core/pom.xml | 2 +- .../org/apache/spark/deploy/SparkSubmit.scala | 14 +- .../spark/deploy/SparkSubmitArguments.scala | 5 +- dev/create-release/create-release.sh | 10 +- dev/run-tests | 2 +- dev/scalastyle | 2 +- docs/sql-programming-guide.md | 201 +++++++++- examples/pom.xml | 2 +- external/flume/pom.xml | 2 +- external/kafka/pom.xml | 2 +- external/mqtt/pom.xml | 2 +- external/twitter/pom.xml | 2 +- external/zeromq/pom.xml | 2 +- graphx/pom.xml | 2 +- mllib/pom.xml | 2 +- pom.xml | 16 +- project/SparkBuild.scala | 14 +- sbin/start-thriftserver.sh | 36 ++ sql/catalyst/pom.xml | 2 +- .../sql/catalyst/plans/logical/commands.scala | 3 +- sql/core/pom.xml | 2 +- .../scala/org/apache/spark/sql/SQLConf.scala | 20 +- .../apache/spark/sql/execution/commands.scala | 42 ++- .../org/apache/spark/sql/SQLConfSuite.scala | 13 +- .../org/apache/spark/sql/SQLQuerySuite.scala | 10 +- sql/hive-thriftserver/pom.xml | 82 +++++ .../hive/thriftserver/HiveThriftServer2.scala | 97 +++++ .../hive/thriftserver/ReflectionUtils.scala | 58 +++ .../hive/thriftserver/SparkSQLCLIDriver.scala | 344 ++++++++++++++++++ .../thriftserver/SparkSQLCLIService.scala | 74 ++++ .../hive/thriftserver/SparkSQLDriver.scala | 93 +++++ .../sql/hive/thriftserver/SparkSQLEnv.scala | 58 +++ .../thriftserver/SparkSQLSessionManager.scala | 49 +++ .../server/SparkSQLOperationManager.scala | 151 ++++++++ .../test/resources/data/files/small_kv.txt | 5 + .../sql/hive/thriftserver/CliSuite.scala | 57 +++ .../thriftserver/HiveThriftServer2Suite.scala | 135 +++++++ .../sql/hive/thriftserver/TestUtils.scala | 108 ++++++ sql/hive/pom.xml | 2 +- .../apache/spark/sql/hive/HiveContext.scala | 2 +- .../sql/hive/execution/HiveQuerySuite.scala | 50 ++- streaming/pom.xml | 2 +- tools/pom.xml | 2 +- yarn/alpha/pom.xml | 2 +- yarn/pom.xml | 2 +- yarn/stable/pom.xml | 2 +- 54 files changed, 1790 insertions(+), 96 deletions(-) create mode 100755 bin/beeline create mode 100755 bin/spark-sql create mode 100755 sbin/start-thriftserver.sh create mode 100644 sql/hive-thriftserver/pom.xml create mode 100644 sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala create mode 100644 sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ReflectionUtils.scala create mode 100755 sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala create mode 100644 sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIService.scala create mode 100644 sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala create mode 100644 sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala create mode 100644 sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLSessionManager.scala create mode 100644 sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala create mode 100644 sql/hive-thriftserver/src/test/resources/data/files/small_kv.txt create mode 100644 sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala create mode 100644 sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suite.scala create mode 100644 sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/TestUtils.scala diff --git a/.gitignore b/.gitignore index a4ec12ca6b53f..7ec8d45e12c6b 100644 --- a/.gitignore +++ b/.gitignore @@ -58,3 +58,4 @@ metastore_db/ metastore/ warehouse/ TempStatsStore/ +sql/hive-thriftserver/test_warehouses diff --git a/assembly/pom.xml b/assembly/pom.xml index 567a8dd2a0d94..703f15925bc44 100644 --- a/assembly/pom.xml +++ b/assembly/pom.xml @@ -165,6 +165,16 @@ + + hive-thriftserver + + + org.apache.spark + spark-hive-thriftserver_${scala.binary.version} + ${project.version} + + + spark-ganglia-lgpl diff --git a/bagel/pom.xml b/bagel/pom.xml index 90c4b095bb611..bd51b112e26fa 100644 --- a/bagel/pom.xml +++ b/bagel/pom.xml @@ -28,7 +28,7 @@ org.apache.spark spark-bagel_2.10 - bagel + bagel jar Spark Project Bagel diff --git a/bin/beeline b/bin/beeline new file mode 100755 index 0000000000000..09fe366c609fa --- /dev/null +++ b/bin/beeline @@ -0,0 +1,45 @@ +#!/usr/bin/env bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Figure out where Spark is installed +FWDIR="$(cd `dirname $0`/..; pwd)" + +# Find the java binary +if [ -n "${JAVA_HOME}" ]; then + RUNNER="${JAVA_HOME}/bin/java" +else + if [ `command -v java` ]; then + RUNNER="java" + else + echo "JAVA_HOME is not set" >&2 + exit 1 + fi +fi + +# Compute classpath using external script +classpath_output=$($FWDIR/bin/compute-classpath.sh) +if [[ "$?" != "0" ]]; then + echo "$classpath_output" + exit 1 +else + CLASSPATH=$classpath_output +fi + +CLASS="org.apache.hive.beeline.BeeLine" +exec "$RUNNER" -cp "$CLASSPATH" $CLASS "$@" diff --git a/bin/compute-classpath.sh b/bin/compute-classpath.sh index e81e8c060cb98..16b794a1592e8 100755 --- a/bin/compute-classpath.sh +++ b/bin/compute-classpath.sh @@ -52,6 +52,7 @@ if [ -n "$SPARK_PREPEND_CLASSES" ]; then CLASSPATH="$CLASSPATH:$FWDIR/sql/catalyst/target/scala-$SCALA_VERSION/classes" CLASSPATH="$CLASSPATH:$FWDIR/sql/core/target/scala-$SCALA_VERSION/classes" CLASSPATH="$CLASSPATH:$FWDIR/sql/hive/target/scala-$SCALA_VERSION/classes" + CLASSPATH="$CLASSPATH:$FWDIR/sql/hive-thriftserver/target/scala-$SCALA_VERSION/classes" CLASSPATH="$CLASSPATH:$FWDIR/yarn/stable/target/scala-$SCALA_VERSION/classes" fi diff --git a/bin/spark-shell b/bin/spark-shell index 850e9507ec38f..756c8179d12b6 100755 --- a/bin/spark-shell +++ b/bin/spark-shell @@ -46,11 +46,11 @@ function main(){ # (see https://github.com/sbt/sbt/issues/562). stty -icanon min 1 -echo > /dev/null 2>&1 export SPARK_SUBMIT_OPTS="$SPARK_SUBMIT_OPTS -Djline.terminal=unix" - $FWDIR/bin/spark-submit spark-shell "$@" --class org.apache.spark.repl.Main + $FWDIR/bin/spark-submit --class org.apache.spark.repl.Main spark-shell "$@" stty icanon echo > /dev/null 2>&1 else export SPARK_SUBMIT_OPTS - $FWDIR/bin/spark-submit spark-shell "$@" --class org.apache.spark.repl.Main + $FWDIR/bin/spark-submit --class org.apache.spark.repl.Main spark-shell "$@" fi } diff --git a/bin/spark-shell.cmd b/bin/spark-shell.cmd index 4b9708a8c03f3..b56d69801171c 100755 --- a/bin/spark-shell.cmd +++ b/bin/spark-shell.cmd @@ -19,4 +19,4 @@ rem set SPARK_HOME=%~dp0.. -cmd /V /E /C %SPARK_HOME%\bin\spark-submit.cmd spark-shell %* --class org.apache.spark.repl.Main +cmd /V /E /C %SPARK_HOME%\bin\spark-submit.cmd spark-shell --class org.apache.spark.repl.Main %* diff --git a/bin/spark-sql b/bin/spark-sql new file mode 100755 index 0000000000000..bba7f897b19bc --- /dev/null +++ b/bin/spark-sql @@ -0,0 +1,36 @@ +#!/usr/bin/env bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# +# Shell script for starting the Spark SQL CLI + +# Enter posix mode for bash +set -o posix + +# Figure out where Spark is installed +FWDIR="$(cd `dirname $0`/..; pwd)" + +if [[ "$@" = *--help ]] || [[ "$@" = *-h ]]; then + echo "Usage: ./sbin/spark-sql [options]" + $FWDIR/bin/spark-submit --help 2>&1 | grep -v Usage 1>&2 + exit 0 +fi + +CLASS="org.apache.spark.sql.hive.thriftserver.SparkSQLCLIDriver" +exec "$FWDIR"/bin/spark-submit --class $CLASS spark-internal $@ diff --git a/core/pom.xml b/core/pom.xml index 1054cec4d77bb..a24743495b0e1 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -28,7 +28,7 @@ org.apache.spark spark-core_2.10 - core + core jar Spark Project Core diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index 3b5642b6caa36..c9cec33ebaa66 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -46,6 +46,10 @@ object SparkSubmit { private val CLUSTER = 2 private val ALL_DEPLOY_MODES = CLIENT | CLUSTER + // A special jar name that indicates the class being run is inside of Spark itself, and therefore + // no user jar is needed. + private val SPARK_INTERNAL = "spark-internal" + // Special primary resource names that represent shells rather than application jars. private val SPARK_SHELL = "spark-shell" private val PYSPARK_SHELL = "pyspark-shell" @@ -257,7 +261,9 @@ object SparkSubmit { // In yarn-cluster mode, use yarn.Client as a wrapper around the user class if (clusterManager == YARN && deployMode == CLUSTER) { childMainClass = "org.apache.spark.deploy.yarn.Client" - childArgs += ("--jar", args.primaryResource) + if (args.primaryResource != SPARK_INTERNAL) { + childArgs += ("--jar", args.primaryResource) + } childArgs += ("--class", args.mainClass) if (args.childArgs != null) { args.childArgs.foreach { arg => childArgs += ("--arg", arg) } @@ -332,7 +338,7 @@ object SparkSubmit { * Return whether the given primary resource represents a user jar. */ private def isUserJar(primaryResource: String): Boolean = { - !isShell(primaryResource) && !isPython(primaryResource) + !isShell(primaryResource) && !isPython(primaryResource) && !isInternal(primaryResource) } /** @@ -349,6 +355,10 @@ object SparkSubmit { primaryResource.endsWith(".py") || primaryResource == PYSPARK_SHELL } + private[spark] def isInternal(primaryResource: String): Boolean = { + primaryResource == SPARK_INTERNAL + } + /** * Merge a sequence of comma-separated file lists, some of which may be null to indicate * no files, into a single comma-separated string. diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala index 3ab67a43a3b55..01d0ae541a66b 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala @@ -204,8 +204,9 @@ private[spark] class SparkSubmitArguments(args: Seq[String]) { /** Fill in values by parsing user options. */ private def parseOpts(opts: Seq[String]): Unit = { - // Delineates parsing of Spark options from parsing of user options. var inSparkOpts = true + + // Delineates parsing of Spark options from parsing of user options. parse(opts) def parse(opts: Seq[String]): Unit = opts match { @@ -318,7 +319,7 @@ private[spark] class SparkSubmitArguments(args: Seq[String]) { SparkSubmit.printErrorAndExit(errMessage) case v => primaryResource = - if (!SparkSubmit.isShell(v)) { + if (!SparkSubmit.isShell(v) && !SparkSubmit.isInternal(v)) { Utils.resolveURI(v).toString } else { v diff --git a/dev/create-release/create-release.sh b/dev/create-release/create-release.sh index 38830103d1e8d..33de24d1ae6d7 100755 --- a/dev/create-release/create-release.sh +++ b/dev/create-release/create-release.sh @@ -53,7 +53,7 @@ if [[ ! "$@" =~ --package-only ]]; then -Dusername=$GIT_USERNAME -Dpassword=$GIT_PASSWORD \ -Dmaven.javadoc.skip=true \ -Dhadoop.version=2.2.0 -Dyarn.version=2.2.0 \ - -Pyarn -Phive -Phadoop-2.2 -Pspark-ganglia-lgpl\ + -Pyarn -Phive -Phive-thriftserver -Phadoop-2.2 -Pspark-ganglia-lgpl\ -Dtag=$GIT_TAG -DautoVersionSubmodules=true \ --batch-mode release:prepare @@ -61,7 +61,7 @@ if [[ ! "$@" =~ --package-only ]]; then -Darguments="-DskipTests=true -Dmaven.javadoc.skip=true -Dhadoop.version=2.2.0 -Dyarn.version=2.2.0 -Dgpg.passphrase=${GPG_PASSPHRASE}" \ -Dhadoop.version=2.2.0 -Dyarn.version=2.2.0 \ -Dmaven.javadoc.skip=true \ - -Pyarn -Phive -Phadoop-2.2 -Pspark-ganglia-lgpl\ + -Pyarn -Phive -Phive-thriftserver -Phadoop-2.2 -Pspark-ganglia-lgpl\ release:perform cd .. @@ -111,10 +111,10 @@ make_binary_release() { spark-$RELEASE_VERSION-bin-$NAME.tgz.sha } -make_binary_release "hadoop1" "-Phive -Dhadoop.version=1.0.4" -make_binary_release "cdh4" "-Phive -Dhadoop.version=2.0.0-mr1-cdh4.2.0" +make_binary_release "hadoop1" "-Phive -Phive-thriftserver -Dhadoop.version=1.0.4" +make_binary_release "cdh4" "-Phive -Phive-thriftserver -Dhadoop.version=2.0.0-mr1-cdh4.2.0" make_binary_release "hadoop2" \ - "-Phive -Pyarn -Phadoop-2.2 -Dhadoop.version=2.2.0 -Pyarn.version=2.2.0" + "-Phive -Phive-thriftserver -Pyarn -Phadoop-2.2 -Dhadoop.version=2.2.0 -Pyarn.version=2.2.0" # Copy data echo "Copying release tarballs" diff --git a/dev/run-tests b/dev/run-tests index 51e4def0f835a..98ec969dc1b37 100755 --- a/dev/run-tests +++ b/dev/run-tests @@ -65,7 +65,7 @@ echo "=========================================================================" # (either resolution or compilation) prompts the user for input either q, r, # etc to quit or retry. This echo is there to make it not block. if [ -n "$_RUN_SQL_TESTS" ]; then - echo -e "q\n" | SBT_MAVEN_PROFILES="$SBT_MAVEN_PROFILES -Phive" sbt/sbt clean package \ + echo -e "q\n" | SBT_MAVEN_PROFILES="$SBT_MAVEN_PROFILES -Phive -Phive-thriftserver" sbt/sbt clean package \ assembly/assembly test | grep -v -e "info.*Resolving" -e "warn.*Merging" -e "info.*Including" else echo -e "q\n" | sbt/sbt clean package assembly/assembly test | \ diff --git a/dev/scalastyle b/dev/scalastyle index a02d06912f238..d9f2b91a3a091 100755 --- a/dev/scalastyle +++ b/dev/scalastyle @@ -17,7 +17,7 @@ # limitations under the License. # -echo -e "q\n" | sbt/sbt -Phive scalastyle > scalastyle.txt +echo -e "q\n" | sbt/sbt -Phive -Phive-thriftserver scalastyle > scalastyle.txt # Check style with YARN alpha built too echo -e "q\n" | sbt/sbt -Pyarn -Phadoop-0.23 -Dhadoop.version=0.23.9 yarn-alpha/scalastyle \ >> scalastyle.txt diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 38728534a46e0..156e0aebdebe6 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -136,7 +136,7 @@ val sqlContext = new org.apache.spark.sql.SQLContext(sc) import sqlContext.createSchemaRDD // Define the schema using a case class. -// Note: Case classes in Scala 2.10 can support only up to 22 fields. To work around this limit, +// Note: Case classes in Scala 2.10 can support only up to 22 fields. To work around this limit, // you can use custom classes that implement the Product interface. case class Person(name: String, age: Int) @@ -548,7 +548,6 @@ results = hiveContext.hql("FROM src SELECT key, value").collect() - # Writing Language-Integrated Relational Queries **Language-Integrated queries are currently only supported in Scala.** @@ -573,4 +572,200 @@ prefixed with a tick (`'`). Implicit conversions turn these symbols into expres evaluated by the SQL execution engine. A full list of the functions supported can be found in the [ScalaDoc](api/scala/index.html#org.apache.spark.sql.SchemaRDD). - \ No newline at end of file + + +## Running the Thrift JDBC server + +The Thrift JDBC server implemented here corresponds to the [`HiveServer2`] +(https://cwiki.apache.org/confluence/display/Hive/Setting+Up+HiveServer2) in Hive 0.12. You can test +the JDBC server with the beeline script comes with either Spark or Hive 0.12. In order to use Hive +you must first run '`sbt/sbt -Phive-thriftserver assembly/assembly`' (or use `-Phive-thriftserver` +for maven). + +To start the JDBC server, run the following in the Spark directory: + + ./sbin/start-thriftserver.sh + +The default port the server listens on is 10000. To listen on customized host and port, please set +the `HIVE_SERVER2_THRIFT_PORT` and `HIVE_SERVER2_THRIFT_BIND_HOST` environment variables. You may +run `./sbin/start-thriftserver.sh --help` for a complete list of all available options. Now you can +use beeline to test the Thrift JDBC server: + + ./bin/beeline + +Connect to the JDBC server in beeline with: + + beeline> !connect jdbc:hive2://localhost:10000 + +Beeline will ask you for a username and password. In non-secure mode, simply enter the username on +your machine and a blank password. For secure mode, please follow the instructions given in the +[beeline documentation](https://cwiki.apache.org/confluence/display/Hive/HiveServer2+Clients) + +Configuration of Hive is done by placing your `hive-site.xml` file in `conf/`. + +You may also use the beeline script comes with Hive. + +### Migration Guide for Shark Users + +#### Reducer number + +In Shark, default reducer number is 1 and is controlled by the property `mapred.reduce.tasks`. Spark +SQL deprecates this property by a new property `spark.sql.shuffle.partitions`, whose default value +is 200. Users may customize this property via `SET`: + +``` +SET spark.sql.shuffle.partitions=10; +SELECT page, count(*) c FROM logs_last_month_cached +GROUP BY page ORDER BY c DESC LIMIT 10; +``` + +You may also put this property in `hive-site.xml` to override the default value. + +For now, the `mapred.reduce.tasks` property is still recognized, and is converted to +`spark.sql.shuffle.partitions` automatically. + +#### Caching + +The `shark.cache` table property no longer exists, and tables whose name end with `_cached` are no +longer automcatically cached. Instead, we provide `CACHE TABLE` and `UNCACHE TABLE` statements to +let user control table caching explicitly: + +``` +CACHE TABLE logs_last_month; +UNCACHE TABLE logs_last_month; +``` + +**NOTE** `CACHE TABLE tbl` is lazy, it only marks table `tbl` as "need to by cached if necessary", +but doesn't actually cache it until a query that touches `tbl` is executed. To force the table to be +cached, you may simply count the table immediately after executing `CACHE TABLE`: + +``` +CACHE TABLE logs_last_month; +SELECT COUNT(1) FROM logs_last_month; +``` + +Several caching related features are not supported yet: + +* User defined partition level cache eviction policy +* RDD reloading +* In-memory cache write through policy + +### Compatibility with Apache Hive + +#### Deploying in Exising Hive Warehouses + +Spark SQL Thrift JDBC server is designed to be "out of the box" compatible with existing Hive +installations. You do not need to modify your existing Hive Metastore or change the data placement +or partitioning of your tables. + +#### Supported Hive Features + +Spark SQL supports the vast majority of Hive features, such as: + +* Hive query statements, including: + * `SELECT` + * `GROUP BY + * `ORDER BY` + * `CLUSTER BY` + * `SORT BY` +* All Hive operators, including: + * Relational operators (`=`, `⇔`, `==`, `<>`, `<`, `>`, `>=`, `<=`, etc) + * Arthimatic operators (`+`, `-`, `*`, `/`, `%`, etc) + * Logical operators (`AND`, `&&`, `OR`, `||`, etc) + * Complex type constructors + * Mathemtatical functions (`sign`, `ln`, `cos`, etc) + * String functions (`instr`, `length`, `printf`, etc) +* User defined functions (UDF) +* User defined aggregation functions (UDAF) +* User defined serialization formats (SerDe's) +* Joins + * `JOIN` + * `{LEFT|RIGHT|FULL} OUTER JOIN` + * `LEFT SEMI JOIN` + * `CROSS JOIN` +* Unions +* Sub queries + * `SELECT col FROM ( SELECT a + b AS col from t1) t2` +* Sampling +* Explain +* Partitioned tables +* All Hive DDL Functions, including: + * `CREATE TABLE` + * `CREATE TABLE AS SELECT` + * `ALTER TABLE` +* Most Hive Data types, including: + * `TINYINT` + * `SMALLINT` + * `INT` + * `BIGINT` + * `BOOLEAN` + * `FLOAT` + * `DOUBLE` + * `STRING` + * `BINARY` + * `TIMESTAMP` + * `ARRAY<>` + * `MAP<>` + * `STRUCT<>` + +#### Unsupported Hive Functionality + +Below is a list of Hive features that we don't support yet. Most of these features are rarely used +in Hive deployments. + +**Major Hive Features** + +* Tables with buckets: bucket is the hash partitioning within a Hive table partition. Spark SQL + doesn't support buckets yet. + +**Esoteric Hive Features** + +* Tables with partitions using different input formats: In Spark SQL, all table partitions need to + have the same input format. +* Non-equi outer join: For the uncommon use case of using outer joins with non-equi join conditions + (e.g. condition "`key < 10`"), Spark SQL will output wrong result for the `NULL` tuple. +* `UNIONTYPE` +* Unique join +* Single query multi insert +* Column statistics collecting: Spark SQL does not piggyback scans to collect column statistics at + the moment. + +**Hive Input/Output Formats** + +* File format for CLI: For results showing back to the CLI, Spark SQL only supports TextOutputFormat. +* Hadoop archive + +**Hive Optimizations** + +A handful of Hive optimizations are not yet included in Spark. Some of these (such as indexes) are +not necessary due to Spark SQL's in-memory computational model. Others are slotted for future +releases of Spark SQL. + +* Block level bitmap indexes and virtual columns (used to build indexes) +* Automatically convert a join to map join: For joining a large table with multiple small tables, + Hive automatically converts the join into a map join. We are adding this auto conversion in the + next release. +* Automatically determine the number of reducers for joins and groupbys: Currently in Spark SQL, you + need to control the degree of parallelism post-shuffle using "SET + spark.sql.shuffle.partitions=[num_tasks];". We are going to add auto-setting of parallelism in the + next release. +* Meta-data only query: For queries that can be answered by using only meta data, Spark SQL still + launches tasks to compute the result. +* Skew data flag: Spark SQL does not follow the skew data flags in Hive. +* `STREAMTABLE` hint in join: Spark SQL does not follow the `STREAMTABLE` hint. +* Merge multiple small files for query results: if the result output contains multiple small files, + Hive can optionally merge the small files into fewer large files to avoid overflowing the HDFS + metadata. Spark SQL does not support that. + +## Running the Spark SQL CLI + +The Spark SQL CLI is a convenient tool to run the Hive metastore service in local mode and execute +queries input from command line. Note: the Spark SQL CLI cannot talk to the Thrift JDBC server. + +To start the Spark SQL CLI, run the following in the Spark directory: + + ./bin/spark-sql + +Configuration of Hive is done by placing your `hive-site.xml` file in `conf/`. +You may run `./bin/spark-sql --help` for a complete list of all available +options. diff --git a/examples/pom.xml b/examples/pom.xml index bd1c387c2eb91..c4ed0f5a6a02b 100644 --- a/examples/pom.xml +++ b/examples/pom.xml @@ -28,7 +28,7 @@ org.apache.spark spark-examples_2.10 - examples + examples jar Spark Project Examples diff --git a/external/flume/pom.xml b/external/flume/pom.xml index 61a6aff543aed..874b8a7959bb6 100644 --- a/external/flume/pom.xml +++ b/external/flume/pom.xml @@ -28,7 +28,7 @@ org.apache.spark spark-streaming-flume_2.10 - streaming-flume + streaming-flume jar Spark Project External Flume diff --git a/external/kafka/pom.xml b/external/kafka/pom.xml index 4762c50685a93..25a5c0a4d7d77 100644 --- a/external/kafka/pom.xml +++ b/external/kafka/pom.xml @@ -28,7 +28,7 @@ org.apache.spark spark-streaming-kafka_2.10 - streaming-kafka + streaming-kafka jar Spark Project External Kafka diff --git a/external/mqtt/pom.xml b/external/mqtt/pom.xml index 32c530e600ce0..f31ed655f6779 100644 --- a/external/mqtt/pom.xml +++ b/external/mqtt/pom.xml @@ -28,7 +28,7 @@ org.apache.spark spark-streaming-mqtt_2.10 - streaming-mqtt + streaming-mqtt jar Spark Project External MQTT diff --git a/external/twitter/pom.xml b/external/twitter/pom.xml index 637adb0f00da0..56bb24c2a072e 100644 --- a/external/twitter/pom.xml +++ b/external/twitter/pom.xml @@ -28,7 +28,7 @@ org.apache.spark spark-streaming-twitter_2.10 - streaming-twitter + streaming-twitter jar Spark Project External Twitter diff --git a/external/zeromq/pom.xml b/external/zeromq/pom.xml index e4d758a04a4cd..54b0242c54e78 100644 --- a/external/zeromq/pom.xml +++ b/external/zeromq/pom.xml @@ -28,7 +28,7 @@ org.apache.spark spark-streaming-zeromq_2.10 - streaming-zeromq + streaming-zeromq jar Spark Project External ZeroMQ diff --git a/graphx/pom.xml b/graphx/pom.xml index 7e3bcf29dcfbc..6dd52fc618b1e 100644 --- a/graphx/pom.xml +++ b/graphx/pom.xml @@ -28,7 +28,7 @@ org.apache.spark spark-graphx_2.10 - graphx + graphx jar Spark Project GraphX diff --git a/mllib/pom.xml b/mllib/pom.xml index 92b07e2357db1..f27cf520dc9fa 100644 --- a/mllib/pom.xml +++ b/mllib/pom.xml @@ -28,7 +28,7 @@ org.apache.spark spark-mllib_2.10 - mllib + mllib jar Spark Project ML Library diff --git a/pom.xml b/pom.xml index d2e6b3c0ed5a4..93ef3b91b5bce 100644 --- a/pom.xml +++ b/pom.xml @@ -252,9 +252,9 @@ 3.3.2 - commons-codec - commons-codec - 1.5 + commons-codec + commons-codec + 1.5 com.google.code.findbugs @@ -1139,5 +1139,15 @@ + + hive-thriftserver + + false + + + sql/hive-thriftserver + + + diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 62576f84dd031..1629bc2cba8ba 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -30,11 +30,11 @@ object BuildCommons { private val buildLocation = file(".").getAbsoluteFile.getParentFile - val allProjects@Seq(bagel, catalyst, core, graphx, hive, mllib, repl, spark, sql, streaming, - streamingFlume, streamingKafka, streamingMqtt, streamingTwitter, streamingZeromq) = - Seq("bagel", "catalyst", "core", "graphx", "hive", "mllib", "repl", "spark", "sql", - "streaming", "streaming-flume", "streaming-kafka", "streaming-mqtt", "streaming-twitter", - "streaming-zeromq").map(ProjectRef(buildLocation, _)) + val allProjects@Seq(bagel, catalyst, core, graphx, hive, hiveThriftServer, mllib, repl, spark, sql, + streaming, streamingFlume, streamingKafka, streamingMqtt, streamingTwitter, streamingZeromq) = + Seq("bagel", "catalyst", "core", "graphx", "hive", "hive-thriftserver", "mllib", "repl", + "spark", "sql", "streaming", "streaming-flume", "streaming-kafka", "streaming-mqtt", + "streaming-twitter", "streaming-zeromq").map(ProjectRef(buildLocation, _)) val optionallyEnabledProjects@Seq(yarn, yarnStable, yarnAlpha, java8Tests, sparkGangliaLgpl) = Seq("yarn", "yarn-stable", "yarn-alpha", "java8-tests", "ganglia-lgpl") @@ -100,7 +100,7 @@ object SparkBuild extends PomBuild { Properties.envOrNone("SBT_MAVEN_PROPERTIES") match { case Some(v) => v.split("(\\s+|,)").filterNot(_.isEmpty).map(_.split("=")).foreach(x => System.setProperty(x(0), x(1))) - case _ => + case _ => } override val userPropertiesMap = System.getProperties.toMap @@ -158,7 +158,7 @@ object SparkBuild extends PomBuild { /* Enable Mima for all projects except spark, hive, catalyst, sql and repl */ // TODO: Add Sql to mima checks - allProjects.filterNot(y => Seq(spark, sql, hive, catalyst, repl).exists(x => x == y)). + allProjects.filterNot(x => Seq(spark, sql, hive, hiveThriftServer, catalyst, repl).contains(x)). foreach (x => enable(MimaBuild.mimaSettings(sparkHome, x))(x)) /* Enable Assembly for all assembly projects */ diff --git a/sbin/start-thriftserver.sh b/sbin/start-thriftserver.sh new file mode 100755 index 0000000000000..8398e6f19b511 --- /dev/null +++ b/sbin/start-thriftserver.sh @@ -0,0 +1,36 @@ +#!/usr/bin/env bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# +# Shell script for starting the Spark SQL Thrift server + +# Enter posix mode for bash +set -o posix + +# Figure out where Spark is installed +FWDIR="$(cd `dirname $0`/..; pwd)" + +if [[ "$@" = *--help ]] || [[ "$@" = *-h ]]; then + echo "Usage: ./sbin/start-thriftserver [options]" + $FWDIR/bin/spark-submit --help 2>&1 | grep -v Usage 1>&2 + exit 0 +fi + +CLASS="org.apache.spark.sql.hive.thriftserver.HiveThriftServer2" +exec "$FWDIR"/bin/spark-submit --class $CLASS spark-internal $@ diff --git a/sql/catalyst/pom.xml b/sql/catalyst/pom.xml index 6decde3fcd62d..531bfddbf237b 100644 --- a/sql/catalyst/pom.xml +++ b/sql/catalyst/pom.xml @@ -32,7 +32,7 @@ Spark Project Catalyst http://spark.apache.org/ - catalyst + catalyst diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/commands.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/commands.scala index 1d5f033f0d274..a357c6ffb8977 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/commands.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/commands.scala @@ -43,8 +43,7 @@ case class NativeCommand(cmd: String) extends Command { */ case class SetCommand(key: Option[String], value: Option[String]) extends Command { override def output = Seq( - BoundReference(0, AttributeReference("key", StringType, nullable = false)()), - BoundReference(1, AttributeReference("value", StringType, nullable = false)())) + BoundReference(1, AttributeReference("", StringType, nullable = false)())) } /** diff --git a/sql/core/pom.xml b/sql/core/pom.xml index c309c43804d97..3a038a2db6173 100644 --- a/sql/core/pom.xml +++ b/sql/core/pom.xml @@ -32,7 +32,7 @@ Spark Project SQL http://spark.apache.org/ - sql + sql diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala index 2b787e14f3f15..41920c00b5a2c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala @@ -30,12 +30,13 @@ import scala.collection.JavaConverters._ * SQLConf is thread-safe (internally synchronized so safe to be used in multiple threads). */ trait SQLConf { + import SQLConf._ /** ************************ Spark SQL Params/Hints ******************* */ // TODO: refactor so that these hints accessors don't pollute the name space of SQLContext? /** Number of partitions to use for shuffle operators. */ - private[spark] def numShufflePartitions: Int = get("spark.sql.shuffle.partitions", "200").toInt + private[spark] def numShufflePartitions: Int = get(SHUFFLE_PARTITIONS, "200").toInt /** * Upper bound on the sizes (in bytes) of the tables qualified for the auto conversion to @@ -43,11 +44,10 @@ trait SQLConf { * effectively disables auto conversion. * Hive setting: hive.auto.convert.join.noconditionaltask.size. */ - private[spark] def autoConvertJoinSize: Int = - get("spark.sql.auto.convert.join.size", "10000").toInt + private[spark] def autoConvertJoinSize: Int = get(AUTO_CONVERT_JOIN_SIZE, "10000").toInt /** A comma-separated list of table names marked to be broadcasted during joins. */ - private[spark] def joinBroadcastTables: String = get("spark.sql.join.broadcastTables", "") + private[spark] def joinBroadcastTables: String = get(JOIN_BROADCAST_TABLES, "") /** ********************** SQLConf functionality methods ************ */ @@ -61,7 +61,7 @@ trait SQLConf { def set(key: String, value: String): Unit = { require(key != null, "key cannot be null") - require(value != null, s"value cannot be null for ${key}") + require(value != null, s"value cannot be null for $key") settings.put(key, value) } @@ -90,3 +90,13 @@ trait SQLConf { } } + +object SQLConf { + val AUTO_CONVERT_JOIN_SIZE = "spark.sql.auto.convert.join.size" + val SHUFFLE_PARTITIONS = "spark.sql.shuffle.partitions" + val JOIN_BROADCAST_TABLES = "spark.sql.join.broadcastTables" + + object Deprecated { + val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks" + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala index 98d2f89c8ae71..9293239131d52 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala @@ -17,12 +17,13 @@ package org.apache.spark.sql.execution +import org.apache.spark.Logging import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.catalyst.expressions.{Attribute, GenericRow} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.{Row, SQLContext} +import org.apache.spark.sql.{Row, SQLConf, SQLContext} trait Command { /** @@ -44,28 +45,53 @@ trait Command { case class SetCommand( key: Option[String], value: Option[String], output: Seq[Attribute])( @transient context: SQLContext) - extends LeafNode with Command { + extends LeafNode with Command with Logging { - override protected[sql] lazy val sideEffectResult: Seq[(String, String)] = (key, value) match { + override protected[sql] lazy val sideEffectResult: Seq[String] = (key, value) match { // Set value for key k. case (Some(k), Some(v)) => - context.set(k, v) - Array(k -> v) + if (k == SQLConf.Deprecated.MAPRED_REDUCE_TASKS) { + logWarning(s"Property ${SQLConf.Deprecated.MAPRED_REDUCE_TASKS} is deprecated, " + + s"automatically converted to ${SQLConf.SHUFFLE_PARTITIONS} instead.") + context.set(SQLConf.SHUFFLE_PARTITIONS, v) + Array(s"${SQLConf.SHUFFLE_PARTITIONS}=$v") + } else { + context.set(k, v) + Array(s"$k=$v") + } // Query the value bound to key k. case (Some(k), _) => - Array(k -> context.getOption(k).getOrElse("")) + // TODO (lian) This is just a workaround to make the Simba ODBC driver work. + // Should remove this once we get the ODBC driver updated. + if (k == "-v") { + val hiveJars = Seq( + "hive-exec-0.12.0.jar", + "hive-service-0.12.0.jar", + "hive-common-0.12.0.jar", + "hive-hwi-0.12.0.jar", + "hive-0.12.0.jar").mkString(":") + + Array( + "system:java.class.path=" + hiveJars, + "system:sun.java.command=shark.SharkServer2") + } + else { + Array(s"$k=${context.getOption(k).getOrElse("")}") + } // Query all key-value pairs that are set in the SQLConf of the context. case (None, None) => - context.getAll + context.getAll.map { case (k, v) => + s"$k=$v" + } case _ => throw new IllegalArgumentException() } def execute(): RDD[Row] = { - val rows = sideEffectResult.map { case (k, v) => new GenericRow(Array[Any](k, v)) } + val rows = sideEffectResult.map { line => new GenericRow(Array[Any](line)) } context.sparkContext.parallelize(rows, 1) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala index 08293f7f0ca30..1a58d73d9e7f4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala @@ -54,10 +54,10 @@ class SQLConfSuite extends QueryTest { assert(get(testKey, testVal + "_") == testVal) assert(TestSQLContext.get(testKey, testVal + "_") == testVal) - sql("set mapred.reduce.tasks=20") - assert(get("mapred.reduce.tasks", "0") == "20") - sql("set mapred.reduce.tasks = 40") - assert(get("mapred.reduce.tasks", "0") == "40") + sql("set some.property=20") + assert(get("some.property", "0") == "20") + sql("set some.property = 40") + assert(get("some.property", "0") == "40") val key = "spark.sql.key" val vs = "val0,val_1,val2.3,my_table" @@ -70,4 +70,9 @@ class SQLConfSuite extends QueryTest { clear() } + test("deprecated property") { + clear() + sql(s"set ${SQLConf.Deprecated.MAPRED_REDUCE_TASKS}=10") + assert(get(SQLConf.SHUFFLE_PARTITIONS) == "10") + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 6736189c96d4b..de9e8aa4f62ed 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -424,25 +424,25 @@ class SQLQuerySuite extends QueryTest { sql(s"SET $testKey=$testVal") checkAnswer( sql("SET"), - Seq(Seq(testKey, testVal)) + Seq(Seq(s"$testKey=$testVal")) ) sql(s"SET ${testKey + testKey}=${testVal + testVal}") checkAnswer( sql("set"), Seq( - Seq(testKey, testVal), - Seq(testKey + testKey, testVal + testVal)) + Seq(s"$testKey=$testVal"), + Seq(s"${testKey + testKey}=${testVal + testVal}")) ) // "set key" checkAnswer( sql(s"SET $testKey"), - Seq(Seq(testKey, testVal)) + Seq(Seq(s"$testKey=$testVal")) ) checkAnswer( sql(s"SET $nonexistentKey"), - Seq(Seq(nonexistentKey, "")) + Seq(Seq(s"$nonexistentKey=")) ) clear() } diff --git a/sql/hive-thriftserver/pom.xml b/sql/hive-thriftserver/pom.xml new file mode 100644 index 0000000000000..7fac90fdc596d --- /dev/null +++ b/sql/hive-thriftserver/pom.xml @@ -0,0 +1,82 @@ + + + + + 4.0.0 + + org.apache.spark + spark-parent + 1.1.0-SNAPSHOT + ../../pom.xml + + + org.apache.spark + spark-hive-thriftserver_2.10 + jar + Spark Project Hive + http://spark.apache.org/ + + hive-thriftserver + + + + + org.apache.spark + spark-hive_${scala.binary.version} + ${project.version} + + + org.spark-project.hive + hive-cli + ${hive.version} + + + org.spark-project.hive + hive-jdbc + ${hive.version} + + + org.spark-project.hive + hive-beeline + ${hive.version} + + + org.scalatest + scalatest_${scala.binary.version} + test + + + + target/scala-${scala.binary.version}/classes + target/scala-${scala.binary.version}/test-classes + + + org.scalatest + scalatest-maven-plugin + + + org.apache.maven.plugins + maven-deploy-plugin + + true + + + + + diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala new file mode 100644 index 0000000000000..ddbc2a79fb512 --- /dev/null +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala @@ -0,0 +1,97 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.thriftserver + +import scala.collection.JavaConversions._ + +import org.apache.commons.logging.LogFactory +import org.apache.hadoop.hive.conf.HiveConf +import org.apache.hadoop.hive.ql.session.SessionState +import org.apache.hive.service.cli.thrift.ThriftBinaryCLIService +import org.apache.hive.service.server.{HiveServer2, ServerOptionsProcessor} + +import org.apache.spark.sql.Logging +import org.apache.spark.sql.hive.HiveContext +import org.apache.spark.sql.hive.thriftserver.ReflectionUtils._ + +/** + * The main entry point for the Spark SQL port of HiveServer2. Starts up a `SparkSQLContext` and a + * `HiveThriftServer2` thrift server. + */ +private[hive] object HiveThriftServer2 extends Logging { + var LOG = LogFactory.getLog(classOf[HiveServer2]) + + def main(args: Array[String]) { + val optionsProcessor = new ServerOptionsProcessor("HiveThriftServer2") + + if (!optionsProcessor.process(args)) { + logger.warn("Error starting HiveThriftServer2 with given arguments") + System.exit(-1) + } + + val ss = new SessionState(new HiveConf(classOf[SessionState])) + + // Set all properties specified via command line. + val hiveConf: HiveConf = ss.getConf + hiveConf.getAllProperties.toSeq.sortBy(_._1).foreach { case (k, v) => + logger.debug(s"HiveConf var: $k=$v") + } + + SessionState.start(ss) + + logger.info("Starting SparkContext") + SparkSQLEnv.init() + SessionState.start(ss) + + Runtime.getRuntime.addShutdownHook( + new Thread() { + override def run() { + SparkSQLEnv.sparkContext.stop() + } + } + ) + + try { + val server = new HiveThriftServer2(SparkSQLEnv.hiveContext) + server.init(hiveConf) + server.start() + logger.info("HiveThriftServer2 started") + } catch { + case e: Exception => + logger.error("Error starting HiveThriftServer2", e) + System.exit(-1) + } + } +} + +private[hive] class HiveThriftServer2(hiveContext: HiveContext) + extends HiveServer2 + with ReflectedCompositeService { + + override def init(hiveConf: HiveConf) { + val sparkSqlCliService = new SparkSQLCLIService(hiveContext) + setSuperField(this, "cliService", sparkSqlCliService) + addService(sparkSqlCliService) + + val thriftCliService = new ThriftBinaryCLIService(sparkSqlCliService) + setSuperField(this, "thriftCLIService", thriftCliService) + addService(thriftCliService) + + initCompositeService(hiveConf) + } +} diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ReflectionUtils.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ReflectionUtils.scala new file mode 100644 index 0000000000000..599294dfbb7d7 --- /dev/null +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ReflectionUtils.scala @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.thriftserver + +private[hive] object ReflectionUtils { + def setSuperField(obj : Object, fieldName: String, fieldValue: Object) { + setAncestorField(obj, 1, fieldName, fieldValue) + } + + def setAncestorField(obj: AnyRef, level: Int, fieldName: String, fieldValue: AnyRef) { + val ancestor = Iterator.iterate[Class[_]](obj.getClass)(_.getSuperclass).drop(level).next() + val field = ancestor.getDeclaredField(fieldName) + field.setAccessible(true) + field.set(obj, fieldValue) + } + + def getSuperField[T](obj: AnyRef, fieldName: String): T = { + getAncestorField[T](obj, 1, fieldName) + } + + def getAncestorField[T](clazz: Object, level: Int, fieldName: String): T = { + val ancestor = Iterator.iterate[Class[_]](clazz.getClass)(_.getSuperclass).drop(level).next() + val field = ancestor.getDeclaredField(fieldName) + field.setAccessible(true) + field.get(clazz).asInstanceOf[T] + } + + def invokeStatic(clazz: Class[_], methodName: String, args: (Class[_], AnyRef)*): AnyRef = { + invoke(clazz, null, methodName, args: _*) + } + + def invoke( + clazz: Class[_], + obj: AnyRef, + methodName: String, + args: (Class[_], AnyRef)*): AnyRef = { + + val (types, values) = args.unzip + val method = clazz.getDeclaredMethod(methodName, types: _*) + method.setAccessible(true) + method.invoke(obj, values.toSeq: _*) + } +} diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala new file mode 100755 index 0000000000000..27268ecb923e9 --- /dev/null +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala @@ -0,0 +1,344 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.thriftserver + +import scala.collection.JavaConversions._ + +import java.io._ +import java.util.{ArrayList => JArrayList} + +import jline.{ConsoleReader, History} +import org.apache.commons.lang.StringUtils +import org.apache.commons.logging.LogFactory +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.hive.cli.{CliDriver, CliSessionState, OptionsProcessor} +import org.apache.hadoop.hive.common.LogUtils.LogInitializationException +import org.apache.hadoop.hive.common.{HiveInterruptCallback, HiveInterruptUtils, LogUtils} +import org.apache.hadoop.hive.conf.HiveConf +import org.apache.hadoop.hive.ql.Driver +import org.apache.hadoop.hive.ql.exec.Utilities +import org.apache.hadoop.hive.ql.processors.{CommandProcessor, CommandProcessorFactory} +import org.apache.hadoop.hive.ql.session.SessionState +import org.apache.hadoop.hive.shims.ShimLoader +import org.apache.thrift.transport.TSocket + +import org.apache.spark.sql.Logging + +private[hive] object SparkSQLCLIDriver { + private var prompt = "spark-sql" + private var continuedPrompt = "".padTo(prompt.length, ' ') + private var transport:TSocket = _ + + installSignalHandler() + + /** + * Install an interrupt callback to cancel all Spark jobs. In Hive's CliDriver#processLine(), + * a signal handler will invoke this registered callback if a Ctrl+C signal is detected while + * a command is being processed by the current thread. + */ + def installSignalHandler() { + HiveInterruptUtils.add(new HiveInterruptCallback { + override def interrupt() { + // Handle remote execution mode + if (SparkSQLEnv.sparkContext != null) { + SparkSQLEnv.sparkContext.cancelAllJobs() + } else { + if (transport != null) { + // Force closing of TCP connection upon session termination + transport.getSocket.close() + } + } + } + }) + } + + def main(args: Array[String]) { + val oproc = new OptionsProcessor() + if (!oproc.process_stage1(args)) { + System.exit(1) + } + + // NOTE: It is critical to do this here so that log4j is reinitialized + // before any of the other core hive classes are loaded + var logInitFailed = false + var logInitDetailMessage: String = null + try { + logInitDetailMessage = LogUtils.initHiveLog4j() + } catch { + case e: LogInitializationException => + logInitFailed = true + logInitDetailMessage = e.getMessage + } + + val sessionState = new CliSessionState(new HiveConf(classOf[SessionState])) + + sessionState.in = System.in + try { + sessionState.out = new PrintStream(System.out, true, "UTF-8") + sessionState.info = new PrintStream(System.err, true, "UTF-8") + sessionState.err = new PrintStream(System.err, true, "UTF-8") + } catch { + case e: UnsupportedEncodingException => System.exit(3) + } + + if (!oproc.process_stage2(sessionState)) { + System.exit(2) + } + + if (!sessionState.getIsSilent) { + if (logInitFailed) System.err.println(logInitDetailMessage) + else SessionState.getConsole.printInfo(logInitDetailMessage) + } + + // Set all properties specified via command line. + val conf: HiveConf = sessionState.getConf + sessionState.cmdProperties.entrySet().foreach { item: java.util.Map.Entry[Object, Object] => + conf.set(item.getKey.asInstanceOf[String], item.getValue.asInstanceOf[String]) + sessionState.getOverriddenConfigurations.put( + item.getKey.asInstanceOf[String], item.getValue.asInstanceOf[String]) + } + + SessionState.start(sessionState) + + // Clean up after we exit + Runtime.getRuntime.addShutdownHook( + new Thread() { + override def run() { + SparkSQLEnv.stop() + } + } + ) + + // "-h" option has been passed, so connect to Hive thrift server. + if (sessionState.getHost != null) { + sessionState.connect() + if (sessionState.isRemoteMode) { + prompt = s"[${sessionState.getHost}:${sessionState.getPort}]" + prompt + continuedPrompt = "".padTo(prompt.length, ' ') + } + } + + if (!sessionState.isRemoteMode && !ShimLoader.getHadoopShims.usesJobShell()) { + // Hadoop-20 and above - we need to augment classpath using hiveconf + // components. + // See also: code in ExecDriver.java + var loader = conf.getClassLoader + val auxJars = HiveConf.getVar(conf, HiveConf.ConfVars.HIVEAUXJARS) + if (StringUtils.isNotBlank(auxJars)) { + loader = Utilities.addToClassPath(loader, StringUtils.split(auxJars, ",")) + } + conf.setClassLoader(loader) + Thread.currentThread().setContextClassLoader(loader) + } + + val cli = new SparkSQLCLIDriver + cli.setHiveVariables(oproc.getHiveVariables) + + // TODO work around for set the log output to console, because the HiveContext + // will set the output into an invalid buffer. + sessionState.in = System.in + try { + sessionState.out = new PrintStream(System.out, true, "UTF-8") + sessionState.info = new PrintStream(System.err, true, "UTF-8") + sessionState.err = new PrintStream(System.err, true, "UTF-8") + } catch { + case e: UnsupportedEncodingException => System.exit(3) + } + + // Execute -i init files (always in silent mode) + cli.processInitFiles(sessionState) + + if (sessionState.execString != null) { + System.exit(cli.processLine(sessionState.execString)) + } + + try { + if (sessionState.fileName != null) { + System.exit(cli.processFile(sessionState.fileName)) + } + } catch { + case e: FileNotFoundException => + System.err.println(s"Could not open input file for reading. (${e.getMessage})") + System.exit(3) + } + + val reader = new ConsoleReader() + reader.setBellEnabled(false) + // reader.setDebug(new PrintWriter(new FileWriter("writer.debug", true))) + CliDriver.getCommandCompletor.foreach((e) => reader.addCompletor(e)) + + val historyDirectory = System.getProperty("user.home") + + try { + if (new File(historyDirectory).exists()) { + val historyFile = historyDirectory + File.separator + ".hivehistory" + reader.setHistory(new History(new File(historyFile))) + } else { + System.err.println("WARNING: Directory for Hive history file: " + historyDirectory + + " does not exist. History will not be available during this session.") + } + } catch { + case e: Exception => + System.err.println("WARNING: Encountered an error while trying to initialize Hive's " + + "history file. History will not be available during this session.") + System.err.println(e.getMessage) + } + + val clientTransportTSocketField = classOf[CliSessionState].getDeclaredField("transport") + clientTransportTSocketField.setAccessible(true) + + transport = clientTransportTSocketField.get(sessionState).asInstanceOf[TSocket] + + var ret = 0 + var prefix = "" + val currentDB = ReflectionUtils.invokeStatic(classOf[CliDriver], "getFormattedDb", + classOf[HiveConf] -> conf, classOf[CliSessionState] -> sessionState) + + def promptWithCurrentDB = s"$prompt$currentDB" + def continuedPromptWithDBSpaces = continuedPrompt + ReflectionUtils.invokeStatic( + classOf[CliDriver], "spacesForString", classOf[String] -> currentDB) + + var currentPrompt = promptWithCurrentDB + var line = reader.readLine(currentPrompt + "> ") + + while (line != null) { + if (prefix.nonEmpty) { + prefix += '\n' + } + + if (line.trim().endsWith(";") && !line.trim().endsWith("\\;")) { + line = prefix + line + ret = cli.processLine(line, true) + prefix = "" + currentPrompt = promptWithCurrentDB + } else { + prefix = prefix + line + currentPrompt = continuedPromptWithDBSpaces + } + + line = reader.readLine(currentPrompt + "> ") + } + + sessionState.close() + + System.exit(ret) + } +} + +private[hive] class SparkSQLCLIDriver extends CliDriver with Logging { + private val sessionState = SessionState.get().asInstanceOf[CliSessionState] + + private val LOG = LogFactory.getLog("CliDriver") + + private val console = new SessionState.LogHelper(LOG) + + private val conf: Configuration = + if (sessionState != null) sessionState.getConf else new Configuration() + + // Force initializing SparkSQLEnv. This is put here but not object SparkSQLCliDriver + // because the Hive unit tests do not go through the main() code path. + if (!sessionState.isRemoteMode) { + SparkSQLEnv.init() + } + + override def processCmd(cmd: String): Int = { + val cmd_trimmed: String = cmd.trim() + val tokens: Array[String] = cmd_trimmed.split("\\s+") + val cmd_1: String = cmd_trimmed.substring(tokens(0).length()).trim() + if (cmd_trimmed.toLowerCase.equals("quit") || + cmd_trimmed.toLowerCase.equals("exit") || + tokens(0).equalsIgnoreCase("source") || + cmd_trimmed.startsWith("!") || + tokens(0).toLowerCase.equals("list") || + sessionState.isRemoteMode) { + val start = System.currentTimeMillis() + super.processCmd(cmd) + val end = System.currentTimeMillis() + val timeTaken: Double = (end - start) / 1000.0 + console.printInfo(s"Time taken: $timeTaken seconds") + 0 + } else { + var ret = 0 + val hconf = conf.asInstanceOf[HiveConf] + val proc: CommandProcessor = CommandProcessorFactory.get(tokens(0), hconf) + + if (proc != null) { + if (proc.isInstanceOf[Driver]) { + val driver = new SparkSQLDriver + + driver.init() + val out = sessionState.out + val start:Long = System.currentTimeMillis() + if (sessionState.getIsVerbose) { + out.println(cmd) + } + + ret = driver.run(cmd).getResponseCode + if (ret != 0) { + driver.close() + return ret + } + + val res = new JArrayList[String]() + + if (HiveConf.getBoolVar(conf, HiveConf.ConfVars.HIVE_CLI_PRINT_HEADER)) { + // Print the column names. + Option(driver.getSchema.getFieldSchemas).map { fields => + out.println(fields.map(_.getName).mkString("\t")) + } + } + + try { + while (!out.checkError() && driver.getResults(res)) { + res.foreach(out.println) + res.clear() + } + } catch { + case e:IOException => + console.printError( + s"""Failed with exception ${e.getClass.getName}: ${e.getMessage} + |${org.apache.hadoop.util.StringUtils.stringifyException(e)} + """.stripMargin) + ret = 1 + } + + val cret = driver.close() + if (ret == 0) { + ret = cret + } + + val end = System.currentTimeMillis() + if (end > start) { + val timeTaken:Double = (end - start) / 1000.0 + console.printInfo(s"Time taken: $timeTaken seconds", null) + } + + // Destroy the driver to release all the locks. + driver.destroy() + } else { + if (sessionState.getIsVerbose) { + sessionState.out.println(tokens(0) + " " + cmd_1) + } + ret = proc.run(cmd_1).getResponseCode + } + } + ret + } + } +} + diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIService.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIService.scala new file mode 100644 index 0000000000000..42cbf363b274f --- /dev/null +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIService.scala @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.thriftserver + +import scala.collection.JavaConversions._ + +import java.io.IOException +import java.util.{List => JList} +import javax.security.auth.login.LoginException + +import org.apache.commons.logging.Log +import org.apache.hadoop.hive.conf.HiveConf +import org.apache.hadoop.hive.shims.ShimLoader +import org.apache.hive.service.Service.STATE +import org.apache.hive.service.auth.HiveAuthFactory +import org.apache.hive.service.cli.CLIService +import org.apache.hive.service.{AbstractService, Service, ServiceException} + +import org.apache.spark.sql.hive.HiveContext +import org.apache.spark.sql.hive.thriftserver.ReflectionUtils._ + +private[hive] class SparkSQLCLIService(hiveContext: HiveContext) + extends CLIService + with ReflectedCompositeService { + + override def init(hiveConf: HiveConf) { + setSuperField(this, "hiveConf", hiveConf) + + val sparkSqlSessionManager = new SparkSQLSessionManager(hiveContext) + setSuperField(this, "sessionManager", sparkSqlSessionManager) + addService(sparkSqlSessionManager) + + try { + HiveAuthFactory.loginFromKeytab(hiveConf) + val serverUserName = ShimLoader.getHadoopShims + .getShortUserName(ShimLoader.getHadoopShims.getUGIForConf(hiveConf)) + setSuperField(this, "serverUserName", serverUserName) + } catch { + case e @ (_: IOException | _: LoginException) => + throw new ServiceException("Unable to login to kerberos with given principal/keytab", e) + } + + initCompositeService(hiveConf) + } +} + +private[thriftserver] trait ReflectedCompositeService { this: AbstractService => + def initCompositeService(hiveConf: HiveConf) { + // Emulating `CompositeService.init(hiveConf)` + val serviceList = getAncestorField[JList[Service]](this, 2, "serviceList") + serviceList.foreach(_.init(hiveConf)) + + // Emulating `AbstractService.init(hiveConf)` + invoke(classOf[AbstractService], this, "ensureCurrentState", classOf[STATE] -> STATE.NOTINITED) + setAncestorField(this, 3, "hiveConf", hiveConf) + invoke(classOf[AbstractService], this, "changeState", classOf[STATE] -> STATE.INITED) + getAncestorField[Log](this, 3, "LOG").info(s"Service: $getName is inited.") + } +} diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala new file mode 100644 index 0000000000000..5202aa9903e03 --- /dev/null +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.thriftserver + +import scala.collection.JavaConversions._ + +import java.util.{ArrayList => JArrayList} + +import org.apache.commons.lang.exception.ExceptionUtils +import org.apache.hadoop.hive.metastore.api.{FieldSchema, Schema} +import org.apache.hadoop.hive.ql.Driver +import org.apache.hadoop.hive.ql.processors.CommandProcessorResponse + +import org.apache.spark.sql.Logging +import org.apache.spark.sql.hive.{HiveContext, HiveMetastoreTypes} + +private[hive] class SparkSQLDriver(val context: HiveContext = SparkSQLEnv.hiveContext) + extends Driver with Logging { + + private var tableSchema: Schema = _ + private var hiveResponse: Seq[String] = _ + + override def init(): Unit = { + } + + private def getResultSetSchema(query: context.QueryExecution): Schema = { + val analyzed = query.analyzed + logger.debug(s"Result Schema: ${analyzed.output}") + if (analyzed.output.size == 0) { + new Schema(new FieldSchema("Response code", "string", "") :: Nil, null) + } else { + val fieldSchemas = analyzed.output.map { attr => + new FieldSchema(attr.name, HiveMetastoreTypes.toMetastoreType(attr.dataType), "") + } + + new Schema(fieldSchemas, null) + } + } + + override def run(command: String): CommandProcessorResponse = { + val execution = context.executePlan(context.hql(command).logicalPlan) + + // TODO unify the error code + try { + hiveResponse = execution.stringResult() + tableSchema = getResultSetSchema(execution) + new CommandProcessorResponse(0) + } catch { + case cause: Throwable => + logger.error(s"Failed in [$command]", cause) + new CommandProcessorResponse(-3, ExceptionUtils.getFullStackTrace(cause), null) + } + } + + override def close(): Int = { + hiveResponse = null + tableSchema = null + 0 + } + + override def getSchema: Schema = tableSchema + + override def getResults(res: JArrayList[String]): Boolean = { + if (hiveResponse == null) { + false + } else { + res.addAll(hiveResponse) + hiveResponse = null + true + } + } + + override def destroy() { + super.destroy() + hiveResponse = null + tableSchema = null + } +} diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala new file mode 100644 index 0000000000000..451c3bd7b9352 --- /dev/null +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.thriftserver + +import org.apache.hadoop.hive.ql.session.SessionState + +import org.apache.spark.scheduler.{SplitInfo, StatsReportListener} +import org.apache.spark.sql.Logging +import org.apache.spark.sql.hive.HiveContext +import org.apache.spark.{SparkConf, SparkContext} + +/** A singleton object for the master program. The slaves should not access this. */ +private[hive] object SparkSQLEnv extends Logging { + logger.debug("Initializing SparkSQLEnv") + + var hiveContext: HiveContext = _ + var sparkContext: SparkContext = _ + + def init() { + if (hiveContext == null) { + sparkContext = new SparkContext(new SparkConf() + .setAppName(s"SparkSQL::${java.net.InetAddress.getLocalHost.getHostName}")) + + sparkContext.addSparkListener(new StatsReportListener()) + + hiveContext = new HiveContext(sparkContext) { + @transient override lazy val sessionState = SessionState.get() + @transient override lazy val hiveconf = sessionState.getConf + } + } + } + + /** Cleans up and shuts down the Spark SQL environments. */ + def stop() { + logger.debug("Shutting down Spark SQL Environment") + // Stop the SparkContext + if (SparkSQLEnv.sparkContext != null) { + sparkContext.stop() + sparkContext = null + hiveContext = null + } + } +} diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLSessionManager.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLSessionManager.scala new file mode 100644 index 0000000000000..6b3275b4eaf04 --- /dev/null +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLSessionManager.scala @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.thriftserver + +import java.util.concurrent.Executors + +import org.apache.commons.logging.Log +import org.apache.hadoop.hive.conf.HiveConf +import org.apache.hadoop.hive.conf.HiveConf.ConfVars +import org.apache.hive.service.cli.session.SessionManager + +import org.apache.spark.sql.hive.HiveContext +import org.apache.spark.sql.hive.thriftserver.ReflectionUtils._ +import org.apache.spark.sql.hive.thriftserver.server.SparkSQLOperationManager + +private[hive] class SparkSQLSessionManager(hiveContext: HiveContext) + extends SessionManager + with ReflectedCompositeService { + + override def init(hiveConf: HiveConf) { + setSuperField(this, "hiveConf", hiveConf) + + val backgroundPoolSize = hiveConf.getIntVar(ConfVars.HIVE_SERVER2_ASYNC_EXEC_THREADS) + setSuperField(this, "backgroundOperationPool", Executors.newFixedThreadPool(backgroundPoolSize)) + getAncestorField[Log](this, 3, "LOG").info( + s"HiveServer2: Async execution pool size $backgroundPoolSize") + + val sparkSqlOperationManager = new SparkSQLOperationManager(hiveContext) + setSuperField(this, "operationManager", sparkSqlOperationManager) + addService(sparkSqlOperationManager) + + initCompositeService(hiveConf) + } +} diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala new file mode 100644 index 0000000000000..a4e1f3e762e89 --- /dev/null +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala @@ -0,0 +1,151 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.thriftserver.server + +import scala.collection.JavaConversions._ +import scala.collection.mutable.ArrayBuffer +import scala.math.{random, round} + +import java.sql.Timestamp +import java.util.{Map => JMap} + +import org.apache.hadoop.hive.common.`type`.HiveDecimal +import org.apache.hadoop.hive.metastore.api.FieldSchema +import org.apache.hive.service.cli._ +import org.apache.hive.service.cli.operation.{ExecuteStatementOperation, Operation, OperationManager} +import org.apache.hive.service.cli.session.HiveSession + +import org.apache.spark.sql.catalyst.types._ +import org.apache.spark.sql.hive.thriftserver.ReflectionUtils +import org.apache.spark.sql.hive.{HiveContext, HiveMetastoreTypes} +import org.apache.spark.sql.{Logging, SchemaRDD, Row => SparkRow} + +/** + * Executes queries using Spark SQL, and maintains a list of handles to active queries. + */ +class SparkSQLOperationManager(hiveContext: HiveContext) extends OperationManager with Logging { + val handleToOperation = ReflectionUtils + .getSuperField[JMap[OperationHandle, Operation]](this, "handleToOperation") + + override def newExecuteStatementOperation( + parentSession: HiveSession, + statement: String, + confOverlay: JMap[String, String], + async: Boolean): ExecuteStatementOperation = synchronized { + + val operation = new ExecuteStatementOperation(parentSession, statement, confOverlay) { + private var result: SchemaRDD = _ + private var iter: Iterator[SparkRow] = _ + private var dataTypes: Array[DataType] = _ + + def close(): Unit = { + // RDDs will be cleaned automatically upon garbage collection. + logger.debug("CLOSING") + } + + def getNextRowSet(order: FetchOrientation, maxRowsL: Long): RowSet = { + if (!iter.hasNext) { + new RowSet() + } else { + val maxRows = maxRowsL.toInt // Do you really want a row batch larger than Int Max? No. + var curRow = 0 + var rowSet = new ArrayBuffer[Row](maxRows) + + while (curRow < maxRows && iter.hasNext) { + val sparkRow = iter.next() + val row = new Row() + var curCol = 0 + + while (curCol < sparkRow.length) { + dataTypes(curCol) match { + case StringType => + row.addString(sparkRow(curCol).asInstanceOf[String]) + case IntegerType => + row.addColumnValue(ColumnValue.intValue(sparkRow.getInt(curCol))) + case BooleanType => + row.addColumnValue(ColumnValue.booleanValue(sparkRow.getBoolean(curCol))) + case DoubleType => + row.addColumnValue(ColumnValue.doubleValue(sparkRow.getDouble(curCol))) + case FloatType => + row.addColumnValue(ColumnValue.floatValue(sparkRow.getFloat(curCol))) + case DecimalType => + val hiveDecimal = sparkRow.get(curCol).asInstanceOf[BigDecimal].bigDecimal + row.addColumnValue(ColumnValue.stringValue(new HiveDecimal(hiveDecimal))) + case LongType => + row.addColumnValue(ColumnValue.longValue(sparkRow.getLong(curCol))) + case ByteType => + row.addColumnValue(ColumnValue.byteValue(sparkRow.getByte(curCol))) + case ShortType => + row.addColumnValue(ColumnValue.intValue(sparkRow.getShort(curCol))) + case TimestampType => + row.addColumnValue( + ColumnValue.timestampValue(sparkRow.get(curCol).asInstanceOf[Timestamp])) + case BinaryType | _: ArrayType | _: StructType | _: MapType => + val hiveString = result + .queryExecution + .asInstanceOf[HiveContext#QueryExecution] + .toHiveString((sparkRow.get(curCol), dataTypes(curCol))) + row.addColumnValue(ColumnValue.stringValue(hiveString)) + } + curCol += 1 + } + rowSet += row + curRow += 1 + } + new RowSet(rowSet, 0) + } + } + + def getResultSetSchema: TableSchema = { + logger.warn(s"Result Schema: ${result.queryExecution.analyzed.output}") + if (result.queryExecution.analyzed.output.size == 0) { + new TableSchema(new FieldSchema("Result", "string", "") :: Nil) + } else { + val schema = result.queryExecution.analyzed.output.map { attr => + new FieldSchema(attr.name, HiveMetastoreTypes.toMetastoreType(attr.dataType), "") + } + new TableSchema(schema) + } + } + + def run(): Unit = { + logger.info(s"Running query '$statement'") + setState(OperationState.RUNNING) + try { + result = hiveContext.hql(statement) + logger.debug(result.queryExecution.toString()) + val groupId = round(random * 1000000).toString + hiveContext.sparkContext.setJobGroup(groupId, statement) + iter = result.queryExecution.toRdd.toLocalIterator + dataTypes = result.queryExecution.analyzed.output.map(_.dataType).toArray + setHasResultSet(true) + } catch { + // Actually do need to catch Throwable as some failures don't inherit from Exception and + // HiveServer will silently swallow them. + case e: Throwable => + logger.error("Error executing query:",e) + throw new HiveSQLException(e.toString) + } + setState(OperationState.FINISHED) + } + } + + handleToOperation.put(operation.getHandle, operation) + operation + } +} diff --git a/sql/hive-thriftserver/src/test/resources/data/files/small_kv.txt b/sql/hive-thriftserver/src/test/resources/data/files/small_kv.txt new file mode 100644 index 0000000000000..850f8014b6f05 --- /dev/null +++ b/sql/hive-thriftserver/src/test/resources/data/files/small_kv.txt @@ -0,0 +1,5 @@ +238val_238 +86val_86 +311val_311 +27val_27 +165val_165 diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala new file mode 100644 index 0000000000000..69f19f826a802 --- /dev/null +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.thriftserver + +import java.io.{BufferedReader, InputStreamReader, PrintWriter} + +import org.scalatest.{BeforeAndAfterAll, FunSuite} + +class CliSuite extends FunSuite with BeforeAndAfterAll with TestUtils { + val WAREHOUSE_PATH = TestUtils.getWarehousePath("cli") + val METASTORE_PATH = TestUtils.getMetastorePath("cli") + + override def beforeAll() { + val pb = new ProcessBuilder( + "../../bin/spark-sql", + "--master", + "local", + "--hiveconf", + s"javax.jdo.option.ConnectionURL=jdbc:derby:;databaseName=$METASTORE_PATH;create=true", + "--hiveconf", + "hive.metastore.warehouse.dir=" + WAREHOUSE_PATH) + + process = pb.start() + outputWriter = new PrintWriter(process.getOutputStream, true) + inputReader = new BufferedReader(new InputStreamReader(process.getInputStream)) + errorReader = new BufferedReader(new InputStreamReader(process.getErrorStream)) + waitForOutput(inputReader, "spark-sql>") + } + + override def afterAll() { + process.destroy() + process.waitFor() + } + + test("simple commands") { + val dataFilePath = getDataFile("data/files/small_kv.txt") + executeQuery("create table hive_test1(key int, val string);") + executeQuery("load data local inpath '" + dataFilePath+ "' overwrite into table hive_test1;") + executeQuery("cache table hive_test1", "Time taken") + } +} diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suite.scala new file mode 100644 index 0000000000000..fe3403b3292ec --- /dev/null +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suite.scala @@ -0,0 +1,135 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.thriftserver + +import scala.collection.JavaConversions._ +import scala.concurrent.ExecutionContext.Implicits.global +import scala.concurrent._ + +import java.io.{BufferedReader, InputStreamReader} +import java.net.ServerSocket +import java.sql.{Connection, DriverManager, Statement} + +import org.scalatest.{BeforeAndAfterAll, FunSuite} + +import org.apache.spark.sql.Logging +import org.apache.spark.sql.catalyst.util.getTempFilePath + +/** + * Test for the HiveThriftServer2 using JDBC. + */ +class HiveThriftServer2Suite extends FunSuite with BeforeAndAfterAll with TestUtils with Logging { + + val WAREHOUSE_PATH = getTempFilePath("warehouse") + val METASTORE_PATH = getTempFilePath("metastore") + + val DRIVER_NAME = "org.apache.hive.jdbc.HiveDriver" + val TABLE = "test" + val HOST = "localhost" + val PORT = { + // Let the system to choose a random available port to avoid collision with other parallel + // builds. + val socket = new ServerSocket(0) + val port = socket.getLocalPort + socket.close() + port + } + + // If verbose is true, the test program will print all outputs coming from the Hive Thrift server. + val VERBOSE = Option(System.getenv("SPARK_SQL_TEST_VERBOSE")).getOrElse("false").toBoolean + + Class.forName(DRIVER_NAME) + + override def beforeAll() { launchServer() } + + override def afterAll() { stopServer() } + + private def launchServer(args: Seq[String] = Seq.empty) { + // Forking a new process to start the Hive Thrift server. The reason to do this is it is + // hard to clean up Hive resources entirely, so we just start a new process and kill + // that process for cleanup. + val defaultArgs = Seq( + "../../sbin/start-thriftserver.sh", + "--master local", + "--hiveconf", + "hive.root.logger=INFO,console", + "--hiveconf", + s"javax.jdo.option.ConnectionURL=jdbc:derby:;databaseName=$METASTORE_PATH;create=true", + "--hiveconf", + s"hive.metastore.warehouse.dir=$WAREHOUSE_PATH") + val pb = new ProcessBuilder(defaultArgs ++ args) + val environment = pb.environment() + environment.put("HIVE_SERVER2_THRIFT_PORT", PORT.toString) + environment.put("HIVE_SERVER2_THRIFT_BIND_HOST", HOST) + process = pb.start() + inputReader = new BufferedReader(new InputStreamReader(process.getInputStream)) + errorReader = new BufferedReader(new InputStreamReader(process.getErrorStream)) + waitForOutput(inputReader, "ThriftBinaryCLIService listening on") + + // Spawn a thread to read the output from the forked process. + // Note that this is necessary since in some configurations, log4j could be blocked + // if its output to stderr are not read, and eventually blocking the entire test suite. + future { + while (true) { + val stdout = readFrom(inputReader) + val stderr = readFrom(errorReader) + if (VERBOSE && stdout.length > 0) { + println(stdout) + } + if (VERBOSE && stderr.length > 0) { + println(stderr) + } + Thread.sleep(50) + } + } + } + + private def stopServer() { + process.destroy() + process.waitFor() + } + + test("test query execution against a Hive Thrift server") { + Thread.sleep(5 * 1000) + val dataFilePath = getDataFile("data/files/small_kv.txt") + val stmt = createStatement() + stmt.execute("DROP TABLE IF EXISTS test") + stmt.execute("DROP TABLE IF EXISTS test_cached") + stmt.execute("CREATE TABLE test(key int, val string)") + stmt.execute(s"LOAD DATA LOCAL INPATH '$dataFilePath' OVERWRITE INTO TABLE test") + stmt.execute("CREATE TABLE test_cached as select * from test limit 4") + stmt.execute("CACHE TABLE test_cached") + + var rs = stmt.executeQuery("select count(*) from test") + rs.next() + assert(rs.getInt(1) === 5) + + rs = stmt.executeQuery("select count(*) from test_cached") + rs.next() + assert(rs.getInt(1) === 4) + + stmt.close() + } + + def getConnection: Connection = { + val connectURI = s"jdbc:hive2://localhost:$PORT/" + DriverManager.getConnection(connectURI, System.getProperty("user.name"), "") + } + + def createStatement(): Statement = getConnection.createStatement() +} diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/TestUtils.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/TestUtils.scala new file mode 100644 index 0000000000000..bb2242618fbef --- /dev/null +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/TestUtils.scala @@ -0,0 +1,108 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.thriftserver + +import java.io.{BufferedReader, PrintWriter} +import java.text.SimpleDateFormat +import java.util.Date + +import org.apache.hadoop.hive.common.LogUtils +import org.apache.hadoop.hive.common.LogUtils.LogInitializationException + +object TestUtils { + val timestamp = new SimpleDateFormat("yyyyMMdd-HHmmss") + + def getWarehousePath(prefix: String): String = { + System.getProperty("user.dir") + "/test_warehouses/" + prefix + "-warehouse-" + + timestamp.format(new Date) + } + + def getMetastorePath(prefix: String): String = { + System.getProperty("user.dir") + "/test_warehouses/" + prefix + "-metastore-" + + timestamp.format(new Date) + } + + // Dummy function for initialize the log4j properties. + def init() { } + + // initialize log4j + try { + LogUtils.initHiveLog4j() + } catch { + case e: LogInitializationException => // Ignore the error. + } +} + +trait TestUtils { + var process : Process = null + var outputWriter : PrintWriter = null + var inputReader : BufferedReader = null + var errorReader : BufferedReader = null + + def executeQuery( + cmd: String, outputMessage: String = "OK", timeout: Long = 15000): String = { + println("Executing: " + cmd + ", expecting output: " + outputMessage) + outputWriter.write(cmd + "\n") + outputWriter.flush() + waitForQuery(timeout, outputMessage) + } + + protected def waitForQuery(timeout: Long, message: String): String = { + if (waitForOutput(errorReader, message, timeout)) { + Thread.sleep(500) + readOutput() + } else { + assert(false, "Didn't find \"" + message + "\" in the output:\n" + readOutput()) + null + } + } + + // Wait for the specified str to appear in the output. + protected def waitForOutput( + reader: BufferedReader, str: String, timeout: Long = 10000): Boolean = { + val startTime = System.currentTimeMillis + var out = "" + while (!out.contains(str) && System.currentTimeMillis < (startTime + timeout)) { + out += readFrom(reader) + } + out.contains(str) + } + + // Read stdout output and filter out garbage collection messages. + protected def readOutput(): String = { + val output = readFrom(inputReader) + // Remove GC Messages + val filteredOutput = output.lines.filterNot(x => x.contains("[GC") || x.contains("[Full GC")) + .mkString("\n") + filteredOutput + } + + protected def readFrom(reader: BufferedReader): String = { + var out = "" + var c = 0 + while (reader.ready) { + c = reader.read() + out += c.asInstanceOf[Char] + } + out + } + + protected def getDataFile(name: String) = { + Thread.currentThread().getContextClassLoader.getResource(name) + } +} diff --git a/sql/hive/pom.xml b/sql/hive/pom.xml index 1699ffe06ce15..93d00f7c37c9b 100644 --- a/sql/hive/pom.xml +++ b/sql/hive/pom.xml @@ -32,7 +32,7 @@ Spark Project Hive http://spark.apache.org/ - hive + hive diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index 201c85f3d501e..84d43eaeea51d 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -255,7 +255,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { Seq(StringType, IntegerType, LongType, DoubleType, FloatType, BooleanType, ByteType, ShortType, DecimalType, TimestampType, BinaryType) - protected def toHiveString(a: (Any, DataType)): String = a match { + protected[sql] def toHiveString(a: (Any, DataType)): String = a match { case (struct: Row, StructType(fields)) => struct.zip(fields).map { case (v, t) => s""""${t.name}":${toHiveStructString(v, t.dataType)}""" diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index a8623b64c656f..a022a1e2dc70e 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -419,10 +419,10 @@ class HiveQuerySuite extends HiveComparisonTest { hql(s"set $testKey=$testVal") assert(get(testKey, testVal + "_") == testVal) - hql("set mapred.reduce.tasks=20") - assert(get("mapred.reduce.tasks", "0") == "20") - hql("set mapred.reduce.tasks = 40") - assert(get("mapred.reduce.tasks", "0") == "40") + hql("set some.property=20") + assert(get("some.property", "0") == "20") + hql("set some.property = 40") + assert(get("some.property", "0") == "40") hql(s"set $testKey=$testVal") assert(get(testKey, "0") == testVal) @@ -436,63 +436,61 @@ class HiveQuerySuite extends HiveComparisonTest { val testKey = "spark.sql.key.usedfortestonly" val testVal = "test.val.0" val nonexistentKey = "nonexistent" - def collectResults(rdd: SchemaRDD): Set[(String, String)] = - rdd.collect().map { case Row(key: String, value: String) => key -> value }.toSet clear() // "set" itself returns all config variables currently specified in SQLConf. assert(hql("SET").collect().size == 0) - assertResult(Set(testKey -> testVal)) { - collectResults(hql(s"SET $testKey=$testVal")) + assertResult(Array(s"$testKey=$testVal")) { + hql(s"SET $testKey=$testVal").collect().map(_.getString(0)) } assert(hiveconf.get(testKey, "") == testVal) - assertResult(Set(testKey -> testVal)) { - collectResults(hql("SET")) + assertResult(Array(s"$testKey=$testVal")) { + hql(s"SET $testKey=$testVal").collect().map(_.getString(0)) } hql(s"SET ${testKey + testKey}=${testVal + testVal}") assert(hiveconf.get(testKey + testKey, "") == testVal + testVal) - assertResult(Set(testKey -> testVal, (testKey + testKey) -> (testVal + testVal))) { - collectResults(hql("SET")) + assertResult(Array(s"$testKey=$testVal", s"${testKey + testKey}=${testVal + testVal}")) { + hql(s"SET").collect().map(_.getString(0)) } // "set key" - assertResult(Set(testKey -> testVal)) { - collectResults(hql(s"SET $testKey")) + assertResult(Array(s"$testKey=$testVal")) { + hql(s"SET $testKey").collect().map(_.getString(0)) } - assertResult(Set(nonexistentKey -> "")) { - collectResults(hql(s"SET $nonexistentKey")) + assertResult(Array(s"$nonexistentKey=")) { + hql(s"SET $nonexistentKey").collect().map(_.getString(0)) } // Assert that sql() should have the same effects as hql() by repeating the above using sql(). clear() assert(sql("SET").collect().size == 0) - assertResult(Set(testKey -> testVal)) { - collectResults(sql(s"SET $testKey=$testVal")) + assertResult(Array(s"$testKey=$testVal")) { + sql(s"SET $testKey=$testVal").collect().map(_.getString(0)) } assert(hiveconf.get(testKey, "") == testVal) - assertResult(Set(testKey -> testVal)) { - collectResults(sql("SET")) + assertResult(Array(s"$testKey=$testVal")) { + sql("SET").collect().map(_.getString(0)) } sql(s"SET ${testKey + testKey}=${testVal + testVal}") assert(hiveconf.get(testKey + testKey, "") == testVal + testVal) - assertResult(Set(testKey -> testVal, (testKey + testKey) -> (testVal + testVal))) { - collectResults(sql("SET")) + assertResult(Array(s"$testKey=$testVal", s"${testKey + testKey}=${testVal + testVal}")) { + sql("SET").collect().map(_.getString(0)) } - assertResult(Set(testKey -> testVal)) { - collectResults(sql(s"SET $testKey")) + assertResult(Array(s"$testKey=$testVal")) { + sql(s"SET $testKey").collect().map(_.getString(0)) } - assertResult(Set(nonexistentKey -> "")) { - collectResults(sql(s"SET $nonexistentKey")) + assertResult(Array(s"$nonexistentKey=")) { + sql(s"SET $nonexistentKey").collect().map(_.getString(0)) } clear() diff --git a/streaming/pom.xml b/streaming/pom.xml index f60697ce745b7..b99f306b8f2cc 100644 --- a/streaming/pom.xml +++ b/streaming/pom.xml @@ -28,7 +28,7 @@ org.apache.spark spark-streaming_2.10 - streaming + streaming jar Spark Project Streaming diff --git a/tools/pom.xml b/tools/pom.xml index c0ee8faa7a615..97abb6b2b63e0 100644 --- a/tools/pom.xml +++ b/tools/pom.xml @@ -27,7 +27,7 @@ org.apache.spark spark-tools_2.10 - tools + tools jar Spark Project Tools diff --git a/yarn/alpha/pom.xml b/yarn/alpha/pom.xml index 5b13a1f002d6e..51744ece0412d 100644 --- a/yarn/alpha/pom.xml +++ b/yarn/alpha/pom.xml @@ -24,7 +24,7 @@ ../pom.xml - yarn-alpha + yarn-alpha org.apache.spark diff --git a/yarn/pom.xml b/yarn/pom.xml index efb473aa1b261..3faaf053634d6 100644 --- a/yarn/pom.xml +++ b/yarn/pom.xml @@ -29,7 +29,7 @@ pom Spark Project YARN Parent POM - yarn + yarn diff --git a/yarn/stable/pom.xml b/yarn/stable/pom.xml index ceaf9f9d71001..b6c8456d06684 100644 --- a/yarn/stable/pom.xml +++ b/yarn/stable/pom.xml @@ -24,7 +24,7 @@ ../pom.xml - yarn-stable + yarn-stable org.apache.spark From 39ab87b924ad65b6b9b7aa6831f3e9ddc2b76dd7 Mon Sep 17 00:00:00 2001 From: Aaron Davidson Date: Mon, 28 Jul 2014 13:37:44 -0700 Subject: [PATCH 004/170] Use commons-lang3 in SignalLogger rather than commons-lang Spark only transitively depends on the latter, based on the Hadoop version. Author: Aaron Davidson Closes #1621 from aarondav/lang3 and squashes the following commits: 93c93bf [Aaron Davidson] Use commons-lang3 in SignalLogger rather than commons-lang --- core/src/main/scala/org/apache/spark/util/SignalLogger.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/util/SignalLogger.scala b/core/src/main/scala/org/apache/spark/util/SignalLogger.scala index d769b54fa2fae..f77488ef3d449 100644 --- a/core/src/main/scala/org/apache/spark/util/SignalLogger.scala +++ b/core/src/main/scala/org/apache/spark/util/SignalLogger.scala @@ -17,7 +17,7 @@ package org.apache.spark.util -import org.apache.commons.lang.SystemUtils +import org.apache.commons.lang3.SystemUtils import org.slf4j.Logger import sun.misc.{Signal, SignalHandler} From 16ef4d110f15dfe66852802fdadfe2ed7574ddc2 Mon Sep 17 00:00:00 2001 From: Yadong Qi Date: Mon, 28 Jul 2014 21:39:02 -0700 Subject: [PATCH 005/170] Excess judgment Author: Yadong Qi Closes #1629 from watermen/bug-fix2 and squashes the following commits: 59b7237 [Yadong Qi] Update HiveQl.scala --- sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index e6ab68b563f8d..d18ccf8167487 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -610,7 +610,7 @@ private[hive] object HiveQl { // TOK_DESTINATION means to overwrite the table. val resultDestination = (intoClause orElse destClause).getOrElse(sys.error("No destination found.")) - val overwrite = if (intoClause.isEmpty) true else false + val overwrite = intoClause.isEmpty nodeToDest( resultDestination, withLimit, From ccd5ab5f82812abc2eb518448832cc20fb903345 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Tue, 29 Jul 2014 00:15:45 -0700 Subject: [PATCH 006/170] [SPARK-2580] [PySpark] keep silent in worker if JVM close the socket During rdd.take(n), JVM will close the socket if it had got enough data, the Python worker should keep silent in this case. In the same time, the worker should not print the trackback into stderr if it send the traceback to JVM successfully. Author: Davies Liu Closes #1625 from davies/error and squashes the following commits: 4fbcc6d [Davies Liu] disable log4j during testing when exception is expected. cc14202 [Davies Liu] keep silent in worker if JVM close the socket --- python/pyspark/tests.py | 6 ++++++ python/pyspark/worker.py | 21 +++++++++++++-------- 2 files changed, 19 insertions(+), 8 deletions(-) diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 63cc5e9ad96fa..6dee7dc66cee6 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -165,11 +165,17 @@ class TestAddFile(PySparkTestCase): def test_add_py_file(self): # To ensure that we're actually testing addPyFile's effects, check that # this job fails due to `userlibrary` not being on the Python path: + # disable logging in log4j temporarily + log4j = self.sc._jvm.org.apache.log4j + old_level = log4j.LogManager.getRootLogger().getLevel() + log4j.LogManager.getRootLogger().setLevel(log4j.Level.FATAL) def func(x): from userlibrary import UserClass return UserClass().hello() self.assertRaises(Exception, self.sc.parallelize(range(2)).map(func).first) + log4j.LogManager.getRootLogger().setLevel(old_level) + # Add the file, so the job should now succeed: path = os.path.join(SPARK_HOME, "python/test_support/userlibrary.py") self.sc.addPyFile(path) diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 24d41b12d1b1a..2770f63059853 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -75,14 +75,19 @@ def main(infile, outfile): init_time = time.time() iterator = deserializer.load_stream(infile) serializer.dump_stream(func(split_index, iterator), outfile) - except Exception as e: - # Write the error to stderr in addition to trying to pass it back to - # Java, in case it happened while serializing a record - print >> sys.stderr, "PySpark worker failed with exception:" - print >> sys.stderr, traceback.format_exc() - write_int(SpecialLengths.PYTHON_EXCEPTION_THROWN, outfile) - write_with_length(traceback.format_exc(), outfile) - sys.exit(-1) + except Exception: + try: + write_int(SpecialLengths.PYTHON_EXCEPTION_THROWN, outfile) + write_with_length(traceback.format_exc(), outfile) + outfile.flush() + except IOError: + # JVM close the socket + pass + except Exception: + # Write the error to stderr if it happened while serializing + print >> sys.stderr, "PySpark worker failed with exception:" + print >> sys.stderr, traceback.format_exc() + exit(-1) finish_time = time.time() report_times(outfile, boot_time, init_time, finish_time) # Mark the beginning of the accumulators section of the output From 92ef02626e793ea853cced4cbfee316f0b748ed7 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Tue, 29 Jul 2014 01:02:18 -0700 Subject: [PATCH 007/170] [SPARK-791] [PySpark] fix pickle itemgetter with cloudpickle fix the problem with pickle operator.itemgetter with multiple index. Author: Davies Liu Closes #1627 from davies/itemgetter and squashes the following commits: aabd7fa [Davies Liu] fix pickle itemgetter with cloudpickle --- python/pyspark/cloudpickle.py | 5 +++-- python/pyspark/tests.py | 6 ++++++ 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/python/pyspark/cloudpickle.py b/python/pyspark/cloudpickle.py index 4fda2a9b950b8..68062483dedaa 100644 --- a/python/pyspark/cloudpickle.py +++ b/python/pyspark/cloudpickle.py @@ -560,8 +560,9 @@ class ItemGetterType(ctypes.Structure): ] - itemgetter_obj = ctypes.cast(ctypes.c_void_p(id(obj)), ctypes.POINTER(ItemGetterType)).contents - return self.save_reduce(operator.itemgetter, (itemgetter_obj.item,)) + obj = ctypes.cast(ctypes.c_void_p(id(obj)), ctypes.POINTER(ItemGetterType)).contents + return self.save_reduce(operator.itemgetter, + obj.item if obj.nitems > 1 else (obj.item,)) if PyObject_HEAD: dispatch[operator.itemgetter] = save_itemgetter diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 6dee7dc66cee6..8486c8595b5a4 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -284,6 +284,12 @@ def combOp(x, y): self.assertEqual(set([2]), sets[3]) self.assertEqual(set([1, 3]), sets[5]) + def test_itemgetter(self): + rdd = self.sc.parallelize([range(10)]) + from operator import itemgetter + self.assertEqual([1], rdd.map(itemgetter(1)).collect()) + self.assertEqual([(2, 3)], rdd.map(itemgetter(2, 3)).collect()) + class TestIO(PySparkTestCase): From 96ba04bbf917bcb971dd0d8cd1e1766dbe9366e8 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Tue, 29 Jul 2014 01:12:44 -0700 Subject: [PATCH 008/170] [SPARK-2726] and [SPARK-2727] Remove SortOrder and do in-place sort. The pull request includes two changes: 1. Removes SortOrder introduced by SPARK-2125. The key ordering already includes the SortOrder information since an Ordering can be reverse. This is similar to Java's Comparator interface. Rarely does an API accept both a Comparator as well as a SortOrder. 2. Replaces the sortWith call in HashShuffleReader with an in-place quick sort. Author: Reynold Xin Closes #1631 from rxin/sortOrder and squashes the following commits: c9d37e1 [Reynold Xin] [SPARK-2726] and [SPARK-2727] Remove SortOrder and do in-place sort. --- .../scala/org/apache/spark/Dependency.scala | 4 +-- .../spark/rdd/OrderedRDDFunctions.scala | 8 +----- .../org/apache/spark/rdd/ShuffledRDD.scala | 12 +-------- .../shuffle/hash/HashShuffleReader.scala | 25 +++++++++++-------- 4 files changed, 18 insertions(+), 31 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/Dependency.scala b/core/src/main/scala/org/apache/spark/Dependency.scala index f010c03223ef4..09a60571238ea 100644 --- a/core/src/main/scala/org/apache/spark/Dependency.scala +++ b/core/src/main/scala/org/apache/spark/Dependency.scala @@ -19,7 +19,6 @@ package org.apache.spark import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD -import org.apache.spark.rdd.SortOrder.SortOrder import org.apache.spark.serializer.Serializer import org.apache.spark.shuffle.ShuffleHandle @@ -63,8 +62,7 @@ class ShuffleDependency[K, V, C]( val serializer: Option[Serializer] = None, val keyOrdering: Option[Ordering[K]] = None, val aggregator: Option[Aggregator[K, V, C]] = None, - val mapSideCombine: Boolean = false, - val sortOrder: Option[SortOrder] = None) + val mapSideCombine: Boolean = false) extends Dependency(rdd.asInstanceOf[RDD[Product2[K, V]]]) { val shuffleId: Int = rdd.context.newShuffleId() diff --git a/core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala index afd7075f686b9..d85f962783931 100644 --- a/core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala @@ -58,12 +58,6 @@ class OrderedRDDFunctions[K : Ordering : ClassTag, def sortByKey(ascending: Boolean = true, numPartitions: Int = self.partitions.size): RDD[P] = { val part = new RangePartitioner(numPartitions, self, ascending) new ShuffledRDD[K, V, V, P](self, part) - .setKeyOrdering(ordering) - .setSortOrder(if (ascending) SortOrder.ASCENDING else SortOrder.DESCENDING) + .setKeyOrdering(if (ascending) ordering else ordering.reverse) } } - -private[spark] object SortOrder extends Enumeration { - type SortOrder = Value - val ASCENDING, DESCENDING = Value -} diff --git a/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala index da4a8c3dc22b1..bf02f68d0d3d3 100644 --- a/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala @@ -21,7 +21,6 @@ import scala.reflect.ClassTag import org.apache.spark._ import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.rdd.SortOrder.SortOrder import org.apache.spark.serializer.Serializer private[spark] class ShuffledRDDPartition(val idx: Int) extends Partition { @@ -52,8 +51,6 @@ class ShuffledRDD[K, V, C, P <: Product2[K, C] : ClassTag]( private var mapSideCombine: Boolean = false - private var sortOrder: Option[SortOrder] = None - /** Set a serializer for this RDD's shuffle, or null to use the default (spark.serializer) */ def setSerializer(serializer: Serializer): ShuffledRDD[K, V, C, P] = { this.serializer = Option(serializer) @@ -78,15 +75,8 @@ class ShuffledRDD[K, V, C, P <: Product2[K, C] : ClassTag]( this } - /** Set sort order for RDD's sorting. */ - def setSortOrder(sortOrder: SortOrder): ShuffledRDD[K, V, C, P] = { - this.sortOrder = Option(sortOrder) - this - } - override def getDependencies: Seq[Dependency[_]] = { - List(new ShuffleDependency(prev, part, serializer, - keyOrdering, aggregator, mapSideCombine, sortOrder)) + List(new ShuffleDependency(prev, part, serializer, keyOrdering, aggregator, mapSideCombine)) } override val partitioner = Some(part) diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala index 76cdb8f4f8e8a..c8059496a1bdf 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala @@ -18,7 +18,6 @@ package org.apache.spark.shuffle.hash import org.apache.spark.{InterruptibleIterator, TaskContext} -import org.apache.spark.rdd.SortOrder import org.apache.spark.serializer.Serializer import org.apache.spark.shuffle.{BaseShuffleHandle, ShuffleReader} @@ -51,16 +50,22 @@ class HashShuffleReader[K, C]( iter } - val sortedIter = for (sortOrder <- dep.sortOrder; ordering <- dep.keyOrdering) yield { - val buf = aggregatedIter.toArray - if (sortOrder == SortOrder.ASCENDING) { - buf.sortWith((x, y) => ordering.lt(x._1, y._1)).iterator - } else { - buf.sortWith((x, y) => ordering.gt(x._1, y._1)).iterator - } + // Sort the output if there is a sort ordering defined. + dep.keyOrdering match { + case Some(keyOrd: Ordering[K]) => + // Define a Comparator for the whole record based on the key Ordering. + val cmp = new Ordering[Product2[K, C]] { + override def compare(o1: Product2[K, C], o2: Product2[K, C]): Int = { + keyOrd.compare(o1._1, o2._1) + } + } + val sortBuffer: Array[Product2[K, C]] = aggregatedIter.toArray + // TODO: do external sort. + scala.util.Sorting.quickSort(sortBuffer)(cmp) + sortBuffer.iterator + case None => + aggregatedIter } - - sortedIter.getOrElse(aggregatedIter) } /** Close this reader */ From 20424dad30e6c89ba42b07eb329070bdcb3494cb Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Tue, 29 Jul 2014 01:16:41 -0700 Subject: [PATCH 009/170] [SPARK-2174][MLLIB] treeReduce and treeAggregate In `reduce` and `aggregate`, the driver node spends linear time on the number of partitions. It becomes a bottleneck when there are many partitions and the data from each partition is big. SPARK-1485 (#506) tracks the progress of implementing AllReduce on Spark. I did several implementations including butterfly, reduce + broadcast, and treeReduce + broadcast. treeReduce + BT broadcast seems to be right way to go for Spark. Using binary tree may introduce some overhead in communication, because the driver still need to coordinate on data shuffling. In my experiments, n -> sqrt(n) -> 1 gives the best performance in general, which is why I set "depth = 2" in MLlib algorithms. But it certainly needs more testing. I left `treeReduce` and `treeAggregate` public for easy testing. Some numbers from a test on 32-node m3.2xlarge cluster. code: ~~~ import breeze.linalg._ import org.apache.log4j._ Logger.getRootLogger.setLevel(Level.OFF) for (n <- Seq(1, 10, 100, 1000, 10000, 100000, 1000000)) { val vv = sc.parallelize(0 until 1024, 1024).map(i => DenseVector.zeros[Double](n)) var start = System.nanoTime(); vv.treeReduce(_ + _, 2); println((System.nanoTime() - start) / 1e9) start = System.nanoTime(); vv.reduce(_ + _); println((System.nanoTime() - start) / 1e9) } ~~~ out: | n | treeReduce(,2) | reduce | |---|---------------------|-----------| | 10 | 0.215538731 | 0.204206899 | | 100 | 0.278405907 | 0.205732582 | | 1000 | 0.208972182 | 0.214298272 | | 10000 | 0.194792071 | 0.349353687 | | 100000 | 0.347683285 | 6.086671892 | | 1000000 | 2.589350682 | 66.572906702 | CC: @pwendell This is clearly more scalable than the default implementation. My question is whether we should use this implementation in `reduce` and `aggregate` or put them as separate methods. The concern is that users may use `reduce` and `aggregate` as collect, where having multiple stages doesn't reduce the data size. However, in this case, `collect` is more appropriate. Author: Xiangrui Meng Closes #1110 from mengxr/tree and squashes the following commits: c6cd267 [Xiangrui Meng] make depth default to 2 b04b96a [Xiangrui Meng] address comments 9bcc5d3 [Xiangrui Meng] add depth for readability 7495681 [Xiangrui Meng] fix compile error 142a857 [Xiangrui Meng] merge master d58a087 [Xiangrui Meng] move treeReduce and treeAggregate to mllib 8a2a59c [Xiangrui Meng] Merge branch 'master' into tree be6a88a [Xiangrui Meng] use treeAggregate in mllib 0f94490 [Xiangrui Meng] add docs eb71c33 [Xiangrui Meng] add treeReduce fe42a5e [Xiangrui Meng] add treeAggregate --- .../mllib/linalg/distributed/RowMatrix.scala | 23 +++---- .../mllib/optimization/GradientDescent.scala | 3 +- .../spark/mllib/optimization/LBFGS.scala | 3 +- .../apache/spark/mllib/rdd/RDDFunctions.scala | 66 +++++++++++++++++++ .../spark/mllib/rdd/RDDFunctionsSuite.scala | 18 +++++ 5 files changed, 98 insertions(+), 15 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala index 8c2b044ea73f2..58c1322757a43 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala @@ -28,6 +28,7 @@ import org.apache.spark.annotation.Experimental import org.apache.spark.mllib.linalg._ import org.apache.spark.rdd.RDD import org.apache.spark.Logging +import org.apache.spark.mllib.rdd.RDDFunctions._ import org.apache.spark.mllib.stat.{MultivariateOnlineSummarizer, MultivariateStatisticalSummary} /** @@ -79,7 +80,7 @@ class RowMatrix( private[mllib] def multiplyGramianMatrixBy(v: BDV[Double]): BDV[Double] = { val n = numCols().toInt val vbr = rows.context.broadcast(v) - rows.aggregate(BDV.zeros[Double](n))( + rows.treeAggregate(BDV.zeros[Double](n))( seqOp = (U, r) => { val rBrz = r.toBreeze val a = rBrz.dot(vbr.value) @@ -91,9 +92,7 @@ class RowMatrix( s"Do not support vector operation from type ${rBrz.getClass.getName}.") } U - }, - combOp = (U1, U2) => U1 += U2 - ) + }, combOp = (U1, U2) => U1 += U2) } /** @@ -104,13 +103,11 @@ class RowMatrix( val nt: Int = n * (n + 1) / 2 // Compute the upper triangular part of the gram matrix. - val GU = rows.aggregate(new BDV[Double](new Array[Double](nt)))( + val GU = rows.treeAggregate(new BDV[Double](new Array[Double](nt)))( seqOp = (U, v) => { RowMatrix.dspr(1.0, v, U.data) U - }, - combOp = (U1, U2) => U1 += U2 - ) + }, combOp = (U1, U2) => U1 += U2) RowMatrix.triuToFull(n, GU.data) } @@ -290,9 +287,10 @@ class RowMatrix( s"We need at least $mem bytes of memory.") } - val (m, mean) = rows.aggregate[(Long, BDV[Double])]((0L, BDV.zeros[Double](n)))( + val (m, mean) = rows.treeAggregate[(Long, BDV[Double])]((0L, BDV.zeros[Double](n)))( seqOp = (s: (Long, BDV[Double]), v: Vector) => (s._1 + 1L, s._2 += v.toBreeze), - combOp = (s1: (Long, BDV[Double]), s2: (Long, BDV[Double])) => (s1._1 + s2._1, s1._2 += s2._2) + combOp = (s1: (Long, BDV[Double]), s2: (Long, BDV[Double])) => + (s1._1 + s2._1, s1._2 += s2._2) ) updateNumRows(m) @@ -353,10 +351,9 @@ class RowMatrix( * Computes column-wise summary statistics. */ def computeColumnSummaryStatistics(): MultivariateStatisticalSummary = { - val summary = rows.aggregate[MultivariateOnlineSummarizer](new MultivariateOnlineSummarizer)( + val summary = rows.treeAggregate(new MultivariateOnlineSummarizer)( (aggregator, data) => aggregator.add(data), - (aggregator1, aggregator2) => aggregator1.merge(aggregator2) - ) + (aggregator1, aggregator2) => aggregator1.merge(aggregator2)) updateNumRows(summary.count) summary } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala index 9fd760bf78083..356aa949afcf5 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala @@ -25,6 +25,7 @@ import org.apache.spark.annotation.{Experimental, DeveloperApi} import org.apache.spark.Logging import org.apache.spark.rdd.RDD import org.apache.spark.mllib.linalg.{Vectors, Vector} +import org.apache.spark.mllib.rdd.RDDFunctions._ /** * Class used to solve an optimization problem using Gradient Descent. @@ -177,7 +178,7 @@ object GradientDescent extends Logging { // Sample a subset (fraction miniBatchFraction) of the total data // compute and sum up the subgradients on this subset (this is one map-reduce) val (gradientSum, lossSum) = data.sample(false, miniBatchFraction, 42 + i) - .aggregate((BDV.zeros[Double](n), 0.0))( + .treeAggregate((BDV.zeros[Double](n), 0.0))( seqOp = (c, v) => (c, v) match { case ((grad, loss), (label, features)) => val l = gradient.compute(features, label, bcWeights.value, Vectors.fromBreeze(grad)) (grad, loss + l) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala index 179cd4a3f1625..26a2b62e76ed0 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala @@ -26,6 +26,7 @@ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.Logging import org.apache.spark.rdd.RDD import org.apache.spark.mllib.linalg.{Vectors, Vector} +import org.apache.spark.mllib.rdd.RDDFunctions._ /** * :: DeveloperApi :: @@ -199,7 +200,7 @@ object LBFGS extends Logging { val n = weights.length val bcWeights = data.context.broadcast(weights) - val (gradientSum, lossSum) = data.aggregate((BDV.zeros[Double](n), 0.0))( + val (gradientSum, lossSum) = data.treeAggregate((BDV.zeros[Double](n), 0.0))( seqOp = (c, v) => (c, v) match { case ((grad, loss), (label, features)) => val l = localGradient.compute( features, label, Vectors.fromBreeze(bcWeights.value), Vectors.fromBreeze(grad)) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/rdd/RDDFunctions.scala b/mllib/src/main/scala/org/apache/spark/mllib/rdd/RDDFunctions.scala index 365b5e75d7f75..b5e403bc8c14d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/rdd/RDDFunctions.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/rdd/RDDFunctions.scala @@ -20,7 +20,10 @@ package org.apache.spark.mllib.rdd import scala.language.implicitConversions import scala.reflect.ClassTag +import org.apache.spark.HashPartitioner +import org.apache.spark.SparkContext._ import org.apache.spark.rdd.RDD +import org.apache.spark.util.Utils /** * Machine learning specific RDD functions. @@ -44,6 +47,69 @@ class RDDFunctions[T: ClassTag](self: RDD[T]) { new SlidingRDD[T](self, windowSize) } } + + /** + * Reduces the elements of this RDD in a multi-level tree pattern. + * + * @param depth suggested depth of the tree (default: 2) + * @see [[org.apache.spark.rdd.RDD#reduce]] + */ + def treeReduce(f: (T, T) => T, depth: Int = 2): T = { + require(depth >= 1, s"Depth must be greater than or equal to 1 but got $depth.") + val cleanF = self.context.clean(f) + val reducePartition: Iterator[T] => Option[T] = iter => { + if (iter.hasNext) { + Some(iter.reduceLeft(cleanF)) + } else { + None + } + } + val partiallyReduced = self.mapPartitions(it => Iterator(reducePartition(it))) + val op: (Option[T], Option[T]) => Option[T] = (c, x) => { + if (c.isDefined && x.isDefined) { + Some(cleanF(c.get, x.get)) + } else if (c.isDefined) { + c + } else if (x.isDefined) { + x + } else { + None + } + } + RDDFunctions.fromRDD(partiallyReduced).treeAggregate(Option.empty[T])(op, op, depth) + .getOrElse(throw new UnsupportedOperationException("empty collection")) + } + + /** + * Aggregates the elements of this RDD in a multi-level tree pattern. + * + * @param depth suggested depth of the tree (default: 2) + * @see [[org.apache.spark.rdd.RDD#aggregate]] + */ + def treeAggregate[U: ClassTag](zeroValue: U)( + seqOp: (U, T) => U, + combOp: (U, U) => U, + depth: Int = 2): U = { + require(depth >= 1, s"Depth must be greater than or equal to 1 but got $depth.") + if (self.partitions.size == 0) { + return Utils.clone(zeroValue, self.context.env.closureSerializer.newInstance()) + } + val cleanSeqOp = self.context.clean(seqOp) + val cleanCombOp = self.context.clean(combOp) + val aggregatePartition = (it: Iterator[T]) => it.aggregate(zeroValue)(cleanSeqOp, cleanCombOp) + var partiallyAggregated = self.mapPartitions(it => Iterator(aggregatePartition(it))) + var numPartitions = partiallyAggregated.partitions.size + val scale = math.max(math.ceil(math.pow(numPartitions, 1.0 / depth)).toInt, 2) + // If creating an extra level doesn't help reduce the wall-clock time, we stop tree aggregation. + while (numPartitions > scale + numPartitions / scale) { + numPartitions /= scale + val curNumPartitions = numPartitions + partiallyAggregated = partiallyAggregated.mapPartitionsWithIndex { (i, iter) => + iter.map((i % curNumPartitions, _)) + }.reduceByKey(new HashPartitioner(curNumPartitions), cleanCombOp).values + } + partiallyAggregated.reduce(cleanCombOp) + } } private[mllib] diff --git a/mllib/src/test/scala/org/apache/spark/mllib/rdd/RDDFunctionsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/rdd/RDDFunctionsSuite.scala index 3f3b10dfff35e..27a19f793242b 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/rdd/RDDFunctionsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/rdd/RDDFunctionsSuite.scala @@ -46,4 +46,22 @@ class RDDFunctionsSuite extends FunSuite with LocalSparkContext { val expected = data.flatMap(x => x).sliding(3).toList assert(sliding.collect().toList === expected) } + + test("treeAggregate") { + val rdd = sc.makeRDD(-1000 until 1000, 10) + def seqOp = (c: Long, x: Int) => c + x + def combOp = (c1: Long, c2: Long) => c1 + c2 + for (depth <- 1 until 10) { + val sum = rdd.treeAggregate(0L)(seqOp, combOp, depth) + assert(sum === -1000L) + } + } + + test("treeReduce") { + val rdd = sc.makeRDD(-1000 until 1000, 10) + for (depth <- 1 until 10) { + val sum = rdd.treeReduce(_ + _, depth) + assert(sum === -1000) + } + } } From fc4d05700026f4ee9cc5544cf493d900039c38f3 Mon Sep 17 00:00:00 2001 From: Aaron Staple Date: Tue, 29 Jul 2014 01:35:26 -0700 Subject: [PATCH 010/170] Minor indentation and comment typo fixes. Author: Aaron Staple Closes #1630 from staple/minor and squashes the following commits: 6f295a2 [Aaron Staple] Fix typos in comment about ExprId. 8566467 [Aaron Staple] Fix off by one column indentation in SqlParser. --- .../apache/spark/sql/catalyst/SqlParser.scala | 22 +++++++++---------- .../expressions/namedExpressions.scala | 4 ++-- 2 files changed, 13 insertions(+), 13 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala index a34b236c8ac6a..2c73a80f64ebf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala @@ -210,21 +210,21 @@ class SqlParser extends StandardTokenParsers with PackratParsers { } | "(" ~> query ~ ")" ~ opt(AS) ~ ident ^^ { case s ~ _ ~ _ ~ a => Subquery(a, s) } - protected lazy val joinedRelation: Parser[LogicalPlan] = - relationFactor ~ opt(joinType) ~ JOIN ~ relationFactor ~ opt(joinConditions) ^^ { + protected lazy val joinedRelation: Parser[LogicalPlan] = + relationFactor ~ opt(joinType) ~ JOIN ~ relationFactor ~ opt(joinConditions) ^^ { case r1 ~ jt ~ _ ~ r2 ~ cond => Join(r1, r2, joinType = jt.getOrElse(Inner), cond) - } + } - protected lazy val joinConditions: Parser[Expression] = - ON ~> expression + protected lazy val joinConditions: Parser[Expression] = + ON ~> expression - protected lazy val joinType: Parser[JoinType] = - INNER ^^^ Inner | - LEFT ~ SEMI ^^^ LeftSemi | - LEFT ~ opt(OUTER) ^^^ LeftOuter | - RIGHT ~ opt(OUTER) ^^^ RightOuter | - FULL ~ opt(OUTER) ^^^ FullOuter + protected lazy val joinType: Parser[JoinType] = + INNER ^^^ Inner | + LEFT ~ SEMI ^^^ LeftSemi | + LEFT ~ opt(OUTER) ^^^ LeftOuter | + RIGHT ~ opt(OUTER) ^^^ RightOuter | + FULL ~ opt(OUTER) ^^^ FullOuter protected lazy val filter: Parser[Expression] = WHERE ~ expression ^^ { case _ ~ e => e } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala index 934bad8c27294..ed69928ae9eb8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala @@ -28,8 +28,8 @@ object NamedExpression { } /** - * A globally (within this JVM) id for a given named expression. - * Used to identify with attribute output by a relation is being + * A globally unique (within this JVM) id for a given named expression. + * Used to identify which attribute output by a relation is being * referenced in a subsequent computation. */ case class ExprId(id: Long) From 800ecff4b1127d9042d5a8a746348fb4d45aa34b Mon Sep 17 00:00:00 2001 From: Hari Shreedharan Date: Tue, 29 Jul 2014 11:11:29 -0700 Subject: [PATCH 011/170] [STREAMING] SPARK-1729. Make Flume pull data from source, rather than the current pu... ...sh model Currently Spark uses Flume's internal Avro Protocol to ingest data from Flume. If the executor running the receiver fails, it currently has to be restarted on the same node to be able to receive data. This commit adds a new Sink which can be deployed to a Flume agent. This sink can be polled by a new DStream that is also included in this commit. This model ensures that data can be pulled into Spark from Flume even if the receiver is restarted on a new node. This also allows the receiver to receive data on multiple threads for better performance. Author: Hari Shreedharan Author: Hari Shreedharan Author: Tathagata Das Author: harishreedharan Closes #807 from harishreedharan/master and squashes the following commits: e7f70a3 [Hari Shreedharan] Merge remote-tracking branch 'asf-git/master' 96cfb6f [Hari Shreedharan] Merge remote-tracking branch 'asf/master' e48d785 [Hari Shreedharan] Documenting flume-sink being ignored for Mima checks. 5f212ce [Hari Shreedharan] Ignore Spark Sink from mima. 981bf62 [Hari Shreedharan] Merge remote-tracking branch 'asf/master' 7a1bc6e [Hari Shreedharan] Fix SparkBuild.scala a082eb3 [Hari Shreedharan] Merge remote-tracking branch 'asf/master' 1f47364 [Hari Shreedharan] Minor fixes. 73d6f6d [Hari Shreedharan] Cleaned up tests a bit. Added some docs in multiple places. 65b76b4 [Hari Shreedharan] Fixing the unit test. e59cc20 [Hari Shreedharan] Use SparkFlumeEvent instead of the new type. Also, Flume Polling Receiver now uses the store(ArrayBuffer) method. f3c99d1 [Hari Shreedharan] Merge remote-tracking branch 'asf/master' 3572180 [Hari Shreedharan] Adding a license header, making Jenkins happy. 799509f [Hari Shreedharan] Fix a compile issue. 3c5194c [Hari Shreedharan] Merge remote-tracking branch 'asf/master' d248d22 [harishreedharan] Merge pull request #1 from tdas/flume-polling 10b6214 [Tathagata Das] Changed public API, changed sink package, and added java unit test to make sure Java API is callable from Java. 1edc806 [Hari Shreedharan] SPARK-1729. Update logging in Spark Sink. 8c00289 [Hari Shreedharan] More debug messages 393bd94 [Hari Shreedharan] SPARK-1729. Use LinkedBlockingQueue instead of ArrayBuffer to keep track of connections. 120e2a1 [Hari Shreedharan] SPARK-1729. Some test changes and changes to utils classes. 9fd0da7 [Hari Shreedharan] SPARK-1729. Use foreach instead of map for all Options. 8136aa6 [Hari Shreedharan] Adding TransactionProcessor to map on returning batch of data 86aa274 [Hari Shreedharan] Merge remote-tracking branch 'asf/master' 205034d [Hari Shreedharan] Merging master in 4b0c7fc [Hari Shreedharan] FLUME-1729. New Flume-Spark integration. bda01fc [Hari Shreedharan] FLUME-1729. Flume-Spark integration. 0d69604 [Hari Shreedharan] FLUME-1729. Better Flume-Spark integration. 3c23c18 [Hari Shreedharan] SPARK-1729. New Spark-Flume integration. 70bcc2a [Hari Shreedharan] SPARK-1729. New Flume-Spark integration. d6fa3aa [Hari Shreedharan] SPARK-1729. New Flume-Spark integration. e7da512 [Hari Shreedharan] SPARK-1729. Fixing import order 9741683 [Hari Shreedharan] SPARK-1729. Fixes based on review. c604a3c [Hari Shreedharan] SPARK-1729. Optimize imports. 0f10788 [Hari Shreedharan] SPARK-1729. Make Flume pull data from source, rather than the current push model 87775aa [Hari Shreedharan] SPARK-1729. Make Flume pull data from source, rather than the current push model 8df37e4 [Hari Shreedharan] SPARK-1729. Make Flume pull data from source, rather than the current push model 03d6c1c [Hari Shreedharan] SPARK-1729. Make Flume pull data from source, rather than the current push model 08176ad [Hari Shreedharan] SPARK-1729. Make Flume pull data from source, rather than the current push model d24d9d4 [Hari Shreedharan] SPARK-1729. Make Flume pull data from source, rather than the current push model 6d6776a [Hari Shreedharan] SPARK-1729. Make Flume pull data from source, rather than the current push model --- .../streaming/FlumePollingEventCount.scala | 67 +++++ external/flume-sink/pom.xml | 100 ++++++++ .../flume-sink/src/main/avro/sparkflume.avdl | 40 +++ .../spark/streaming/flume/sink/Logging.scala | 125 ++++++++++ .../flume/sink/SparkAvroCallbackHandler.scala | 131 ++++++++++ .../streaming/flume/sink/SparkSink.scala | 154 ++++++++++++ .../streaming/flume/sink/SparkSinkUtils.scala | 28 +++ .../flume/sink/TransactionProcessor.scala | 228 ++++++++++++++++++ external/flume/pom.xml | 5 + .../streaming/flume/EventTransformer.scala | 72 ++++++ .../streaming/flume/FlumeInputDStream.scala | 3 - .../flume/FlumePollingInputDStream.scala | 178 ++++++++++++++ .../spark/streaming/flume/FlumeUtils.scala | 144 ++++++++++- .../flume/JavaFlumePollingStreamSuite.java | 44 ++++ .../flume/FlumePollingStreamSuite.scala | 195 +++++++++++++++ pom.xml | 1 + project/SparkBuild.scala | 20 +- project/plugins.sbt | 2 + 18 files changed, 1524 insertions(+), 13 deletions(-) create mode 100644 examples/src/main/scala/org/apache/spark/examples/streaming/FlumePollingEventCount.scala create mode 100644 external/flume-sink/pom.xml create mode 100644 external/flume-sink/src/main/avro/sparkflume.avdl create mode 100644 external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/Logging.scala create mode 100644 external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkAvroCallbackHandler.scala create mode 100644 external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkSink.scala create mode 100644 external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkSinkUtils.scala create mode 100644 external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/TransactionProcessor.scala create mode 100644 external/flume/src/main/scala/org/apache/spark/streaming/flume/EventTransformer.scala create mode 100644 external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumePollingInputDStream.scala create mode 100644 external/flume/src/test/java/org/apache/spark/streaming/flume/JavaFlumePollingStreamSuite.java create mode 100644 external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/FlumePollingEventCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/FlumePollingEventCount.scala new file mode 100644 index 0000000000000..1cc8c8d5c23b6 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/FlumePollingEventCount.scala @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.streaming + +import org.apache.spark.SparkConf +import org.apache.spark.storage.StorageLevel +import org.apache.spark.streaming._ +import org.apache.spark.streaming.flume._ +import org.apache.spark.util.IntParam +import java.net.InetSocketAddress + +/** + * Produces a count of events received from Flume. + * + * This should be used in conjunction with the Spark Sink running in a Flume agent. See + * the Spark Streaming programming guide for more details. + * + * Usage: FlumePollingEventCount + * `host` is the host on which the Spark Sink is running. + * `port` is the port at which the Spark Sink is listening. + * + * To run this example: + * `$ bin/run-example org.apache.spark.examples.streaming.FlumePollingEventCount [host] [port] ` + */ +object FlumePollingEventCount { + def main(args: Array[String]) { + if (args.length < 2) { + System.err.println( + "Usage: FlumePollingEventCount ") + System.exit(1) + } + + StreamingExamples.setStreamingLogLevels() + + val Array(host, IntParam(port)) = args + + val batchInterval = Milliseconds(2000) + + // Create the context and set the batch size + val sparkConf = new SparkConf().setAppName("FlumePollingEventCount") + val ssc = new StreamingContext(sparkConf, batchInterval) + + // Create a flume stream that polls the Spark Sink running in a Flume agent + val stream = FlumeUtils.createPollingStream(ssc, host, port) + + // Print out the count of events received from this server in each batch + stream.count().map(cnt => "Received " + cnt + " flume events." ).print() + + ssc.start() + ssc.awaitTermination() + } +} diff --git a/external/flume-sink/pom.xml b/external/flume-sink/pom.xml new file mode 100644 index 0000000000000..d11129ce8d89d --- /dev/null +++ b/external/flume-sink/pom.xml @@ -0,0 +1,100 @@ + + + + + 4.0.0 + + org.apache.spark + spark-parent + 1.1.0-SNAPSHOT + ../../pom.xml + + + spark-streaming-flume-sink_2.10 + + streaming-flume-sink + + + jar + Spark Project External Flume Sink + http://spark.apache.org/ + + + org.apache.flume + flume-ng-sdk + 1.4.0 + + + io.netty + netty + + + org.apache.thrift + libthrift + + + + + org.apache.flume + flume-ng-core + 1.4.0 + + + io.netty + netty + + + org.apache.thrift + libthrift + + + + + org.scala-lang + scala-library + 2.10.4 + + + + target/scala-${scala.binary.version}/classes + target/scala-${scala.binary.version}/test-classes + + + org.scalatest + scalatest-maven-plugin + + + org.apache.avro + avro-maven-plugin + 1.7.3 + + + ${project.basedir}/target/scala-${scala.binary.version}/src_managed/main/compiled_avro + + + + generate-sources + + idl-protocol + + + + + + + diff --git a/external/flume-sink/src/main/avro/sparkflume.avdl b/external/flume-sink/src/main/avro/sparkflume.avdl new file mode 100644 index 0000000000000..8806e863ac7c6 --- /dev/null +++ b/external/flume-sink/src/main/avro/sparkflume.avdl @@ -0,0 +1,40 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +@namespace("org.apache.spark.streaming.flume.sink") + +protocol SparkFlumeProtocol { + + record SparkSinkEvent { + map headers; + bytes body; + } + + record EventBatch { + string errorMsg = ""; // If this is empty it is a valid message, else it represents an error + string sequenceNumber; + array events; + } + + EventBatch getEventBatch (int n); + + void ack (string sequenceNumber); + + void nack (string sequenceNumber); +} diff --git a/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/Logging.scala b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/Logging.scala new file mode 100644 index 0000000000000..17cbc6707b5ea --- /dev/null +++ b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/Logging.scala @@ -0,0 +1,125 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.streaming.flume.sink + +import org.slf4j.{Logger, LoggerFactory} + +/** + * Copy of the org.apache.spark.Logging for being used in the Spark Sink. + * The org.apache.spark.Logging is not used so that all of Spark is not brought + * in as a dependency. + */ +private[sink] trait Logging { + // Make the log field transient so that objects with Logging can + // be serialized and used on another machine + @transient private var log_ : Logger = null + + // Method to get or create the logger for this object + protected def log: Logger = { + if (log_ == null) { + initializeIfNecessary() + var className = this.getClass.getName + // Ignore trailing $'s in the class names for Scala objects + if (className.endsWith("$")) { + className = className.substring(0, className.length - 1) + } + log_ = LoggerFactory.getLogger(className) + } + log_ + } + + // Log methods that take only a String + protected def logInfo(msg: => String) { + if (log.isInfoEnabled) log.info(msg) + } + + protected def logDebug(msg: => String) { + if (log.isDebugEnabled) log.debug(msg) + } + + protected def logTrace(msg: => String) { + if (log.isTraceEnabled) log.trace(msg) + } + + protected def logWarning(msg: => String) { + if (log.isWarnEnabled) log.warn(msg) + } + + protected def logError(msg: => String) { + if (log.isErrorEnabled) log.error(msg) + } + + // Log methods that take Throwables (Exceptions/Errors) too + protected def logInfo(msg: => String, throwable: Throwable) { + if (log.isInfoEnabled) log.info(msg, throwable) + } + + protected def logDebug(msg: => String, throwable: Throwable) { + if (log.isDebugEnabled) log.debug(msg, throwable) + } + + protected def logTrace(msg: => String, throwable: Throwable) { + if (log.isTraceEnabled) log.trace(msg, throwable) + } + + protected def logWarning(msg: => String, throwable: Throwable) { + if (log.isWarnEnabled) log.warn(msg, throwable) + } + + protected def logError(msg: => String, throwable: Throwable) { + if (log.isErrorEnabled) log.error(msg, throwable) + } + + protected def isTraceEnabled(): Boolean = { + log.isTraceEnabled + } + + private def initializeIfNecessary() { + if (!Logging.initialized) { + Logging.initLock.synchronized { + if (!Logging.initialized) { + initializeLogging() + } + } + } + } + + private def initializeLogging() { + Logging.initialized = true + + // Force a call into slf4j to initialize it. Avoids this happening from mutliple threads + // and triggering this: http://mailman.qos.ch/pipermail/slf4j-dev/2010-April/002956.html + log + } +} + +private[sink] object Logging { + @volatile private var initialized = false + val initLock = new Object() + try { + // We use reflection here to handle the case where users remove the + // slf4j-to-jul bridge order to route their logs to JUL. + val bridgeClass = Class.forName("org.slf4j.bridge.SLF4JBridgeHandler") + bridgeClass.getMethod("removeHandlersForRootLogger").invoke(null) + val installed = bridgeClass.getMethod("isInstalled").invoke(null).asInstanceOf[Boolean] + if (!installed) { + bridgeClass.getMethod("install").invoke(null) + } + } catch { + case e: ClassNotFoundException => // can't log anything yet so just fail silently + } +} diff --git a/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkAvroCallbackHandler.scala b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkAvroCallbackHandler.scala new file mode 100644 index 0000000000000..7da8eb3e35912 --- /dev/null +++ b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkAvroCallbackHandler.scala @@ -0,0 +1,131 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.streaming.flume.sink + +import java.util.concurrent.{ConcurrentHashMap, Executors} +import java.util.concurrent.atomic.AtomicLong + +import org.apache.flume.Channel +import org.apache.commons.lang.RandomStringUtils +import com.google.common.util.concurrent.ThreadFactoryBuilder + +/** + * Class that implements the SparkFlumeProtocol, that is used by the Avro Netty Server to process + * requests. Each getEvents, ack and nack call is forwarded to an instance of this class. + * @param threads Number of threads to use to process requests. + * @param channel The channel that the sink pulls events from + * @param transactionTimeout Timeout in millis after which the transaction if not acked by Spark + * is rolled back. + */ +// Flume forces transactions to be thread-local. So each transaction *must* be committed, or +// rolled back from the thread it was originally created in. So each getEvents call from Spark +// creates a TransactionProcessor which runs in a new thread, in which the transaction is created +// and events are pulled off the channel. Once the events are sent to spark, +// that thread is blocked and the TransactionProcessor is saved in a map, +// until an ACK or NACK comes back or the transaction times out (after the specified timeout). +// When the response comes or a timeout is hit, the TransactionProcessor is retrieved and then +// unblocked, at which point the transaction is committed or rolled back. + +private[flume] class SparkAvroCallbackHandler(val threads: Int, val channel: Channel, + val transactionTimeout: Int, val backOffInterval: Int) extends SparkFlumeProtocol with Logging { + val transactionExecutorOpt = Option(Executors.newFixedThreadPool(threads, + new ThreadFactoryBuilder().setDaemon(true) + .setNameFormat("Spark Sink Processor Thread - %d").build())) + private val processorMap = new ConcurrentHashMap[CharSequence, TransactionProcessor]() + // This sink will not persist sequence numbers and reuses them if it gets restarted. + // So it is possible to commit a transaction which may have been meant for the sink before the + // restart. + // Since the new txn may not have the same sequence number we must guard against accidentally + // committing a new transaction. To reduce the probability of that happening a random string is + // prepended to the sequence number. Does not change for life of sink + private val seqBase = RandomStringUtils.randomAlphanumeric(8) + private val seqCounter = new AtomicLong(0) + + /** + * Returns a bunch of events to Spark over Avro RPC. + * @param n Maximum number of events to return in a batch + * @return [[EventBatch]] instance that has a sequence number and an array of at most n events + */ + override def getEventBatch(n: Int): EventBatch = { + logDebug("Got getEventBatch call from Spark.") + val sequenceNumber = seqBase + seqCounter.incrementAndGet() + val processor = new TransactionProcessor(channel, sequenceNumber, + n, transactionTimeout, backOffInterval, this) + transactionExecutorOpt.foreach(executor => { + executor.submit(processor) + }) + // Wait until a batch is available - will be an error if error message is non-empty + val batch = processor.getEventBatch + if (!SparkSinkUtils.isErrorBatch(batch)) { + processorMap.put(sequenceNumber.toString, processor) + logDebug("Sending event batch with sequence number: " + sequenceNumber) + } + batch + } + + /** + * Called by Spark to indicate successful commit of a batch + * @param sequenceNumber The sequence number of the event batch that was successful + */ + override def ack(sequenceNumber: CharSequence): Void = { + logDebug("Received Ack for batch with sequence number: " + sequenceNumber) + completeTransaction(sequenceNumber, success = true) + null + } + + /** + * Called by Spark to indicate failed commit of a batch + * @param sequenceNumber The sequence number of the event batch that failed + * @return + */ + override def nack(sequenceNumber: CharSequence): Void = { + completeTransaction(sequenceNumber, success = false) + logInfo("Spark failed to commit transaction. Will reattempt events.") + null + } + + /** + * Helper method to commit or rollback a transaction. + * @param sequenceNumber The sequence number of the batch that was completed + * @param success Whether the batch was successful or not. + */ + private def completeTransaction(sequenceNumber: CharSequence, success: Boolean) { + Option(removeAndGetProcessor(sequenceNumber)).foreach(processor => { + processor.batchProcessed(success) + }) + } + + /** + * Helper method to remove the TxnProcessor for a Sequence Number. Can be used to avoid a leak. + * @param sequenceNumber + * @return The transaction processor for the corresponding batch. Note that this instance is no + * longer tracked and the caller is responsible for that txn processor. + */ + private[sink] def removeAndGetProcessor(sequenceNumber: CharSequence): TransactionProcessor = { + processorMap.remove(sequenceNumber.toString) // The toString is required! + } + + /** + * Shuts down the executor used to process transactions. + */ + def shutdown() { + logInfo("Shutting down Spark Avro Callback Handler") + transactionExecutorOpt.foreach(executor => { + executor.shutdownNow() + }) + } +} diff --git a/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkSink.scala b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkSink.scala new file mode 100644 index 0000000000000..7b735133e3d14 --- /dev/null +++ b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkSink.scala @@ -0,0 +1,154 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.streaming.flume.sink + +import java.net.InetSocketAddress +import java.util.concurrent._ + +import org.apache.avro.ipc.NettyServer +import org.apache.avro.ipc.specific.SpecificResponder +import org.apache.flume.Context +import org.apache.flume.Sink.Status +import org.apache.flume.conf.{Configurable, ConfigurationException} +import org.apache.flume.sink.AbstractSink + +/** + * A sink that uses Avro RPC to run a server that can be polled by Spark's + * FlumePollingInputDStream. This sink has the following configuration parameters: + * + * hostname - The hostname to bind to. Default: 0.0.0.0 + * port - The port to bind to. (No default - mandatory) + * timeout - Time in seconds after which a transaction is rolled back, + * if an ACK is not received from Spark within that time + * threads - Number of threads to use to receive requests from Spark (Default: 10) + * + * This sink is unlike other Flume sinks in the sense that it does not push data, + * instead the process method in this sink simply blocks the SinkRunner the first time it is + * called. This sink starts up an Avro IPC server that uses the SparkFlumeProtocol. + * + * Each time a getEventBatch call comes, creates a transaction and reads events + * from the channel. When enough events are read, the events are sent to the Spark receiver and + * the thread itself is blocked and a reference to it saved off. + * + * When the ack for that batch is received, + * the thread which created the transaction is is retrieved and it commits the transaction with the + * channel from the same thread it was originally created in (since Flume transactions are + * thread local). If a nack is received instead, the sink rolls back the transaction. If no ack + * is received within the specified timeout, the transaction is rolled back too. If an ack comes + * after that, it is simply ignored and the events get re-sent. + * + */ + +private[flume] +class SparkSink extends AbstractSink with Logging with Configurable { + + // Size of the pool to use for holding transaction processors. + private var poolSize: Integer = SparkSinkConfig.DEFAULT_THREADS + + // Timeout for each transaction. If spark does not respond in this much time, + // rollback the transaction + private var transactionTimeout = SparkSinkConfig.DEFAULT_TRANSACTION_TIMEOUT + + // Address info to bind on + private var hostname: String = SparkSinkConfig.DEFAULT_HOSTNAME + private var port: Int = 0 + + private var backOffInterval: Int = 200 + + // Handle to the server + private var serverOpt: Option[NettyServer] = None + + // The handler that handles the callback from Avro + private var handler: Option[SparkAvroCallbackHandler] = None + + // Latch that blocks off the Flume framework from wasting 1 thread. + private val blockingLatch = new CountDownLatch(1) + + override def start() { + logInfo("Starting Spark Sink: " + getName + " on port: " + port + " and interface: " + + hostname + " with " + "pool size: " + poolSize + " and transaction timeout: " + + transactionTimeout + ".") + handler = Option(new SparkAvroCallbackHandler(poolSize, getChannel, transactionTimeout, + backOffInterval)) + val responder = new SpecificResponder(classOf[SparkFlumeProtocol], handler.get) + // Using the constructor that takes specific thread-pools requires bringing in netty + // dependencies which are being excluded in the build. In practice, + // Netty dependencies are already available on the JVM as Flume would have pulled them in. + serverOpt = Option(new NettyServer(responder, new InetSocketAddress(hostname, port))) + serverOpt.foreach(server => { + logInfo("Starting Avro server for sink: " + getName) + server.start() + }) + super.start() + } + + override def stop() { + logInfo("Stopping Spark Sink: " + getName) + handler.foreach(callbackHandler => { + callbackHandler.shutdown() + }) + serverOpt.foreach(server => { + logInfo("Stopping Avro Server for sink: " + getName) + server.close() + server.join() + }) + blockingLatch.countDown() + super.stop() + } + + override def configure(ctx: Context) { + import SparkSinkConfig._ + hostname = ctx.getString(CONF_HOSTNAME, DEFAULT_HOSTNAME) + port = Option(ctx.getInteger(CONF_PORT)). + getOrElse(throw new ConfigurationException("The port to bind to must be specified")) + poolSize = ctx.getInteger(THREADS, DEFAULT_THREADS) + transactionTimeout = ctx.getInteger(CONF_TRANSACTION_TIMEOUT, DEFAULT_TRANSACTION_TIMEOUT) + backOffInterval = ctx.getInteger(CONF_BACKOFF_INTERVAL, DEFAULT_BACKOFF_INTERVAL) + logInfo("Configured Spark Sink with hostname: " + hostname + ", port: " + port + ", " + + "poolSize: " + poolSize + ", transactionTimeout: " + transactionTimeout + ", " + + "backoffInterval: " + backOffInterval) + } + + override def process(): Status = { + // This method is called in a loop by the Flume framework - block it until the sink is + // stopped to save CPU resources. The sink runner will interrupt this thread when the sink is + // being shut down. + logInfo("Blocking Sink Runner, sink will continue to run..") + blockingLatch.await() + Status.BACKOFF + } +} + +/** + * Configuration parameters and their defaults. + */ +private[flume] +object SparkSinkConfig { + val THREADS = "threads" + val DEFAULT_THREADS = 10 + + val CONF_TRANSACTION_TIMEOUT = "timeout" + val DEFAULT_TRANSACTION_TIMEOUT = 60 + + val CONF_HOSTNAME = "hostname" + val DEFAULT_HOSTNAME = "0.0.0.0" + + val CONF_PORT = "port" + + val CONF_BACKOFF_INTERVAL = "backoffInterval" + val DEFAULT_BACKOFF_INTERVAL = 200 +} diff --git a/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkSinkUtils.scala b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkSinkUtils.scala new file mode 100644 index 0000000000000..47c0e294d6b52 --- /dev/null +++ b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkSinkUtils.scala @@ -0,0 +1,28 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.streaming.flume.sink + +private[flume] object SparkSinkUtils { + /** + * This method determines if this batch represents an error or not. + * @param batch - The batch to check + * @return - true if the batch represents an error + */ + def isErrorBatch(batch: EventBatch): Boolean = { + !batch.getErrorMsg.toString.equals("") // If there is an error message, it is an error batch. + } +} diff --git a/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/TransactionProcessor.scala b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/TransactionProcessor.scala new file mode 100644 index 0000000000000..b9e3c786ebb3b --- /dev/null +++ b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/TransactionProcessor.scala @@ -0,0 +1,228 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.streaming.flume.sink + +import java.nio.ByteBuffer +import java.util +import java.util.concurrent.{Callable, CountDownLatch, TimeUnit} + +import scala.util.control.Breaks + +import org.apache.flume.{Transaction, Channel} + +// Flume forces transactions to be thread-local (horrible, I know!) +// So the sink basically spawns a new thread to pull the events out within a transaction. +// The thread fills in the event batch object that is set before the thread is scheduled. +// After filling it in, the thread waits on a condition - which is released only +// when the success message comes back for the specific sequence number for that event batch. +/** + * This class represents a transaction on the Flume channel. This class runs a separate thread + * which owns the transaction. The thread is blocked until the success call for that transaction + * comes back with an ACK or NACK. + * @param channel The channel from which to pull events + * @param seqNum The sequence number to use for the transaction. Must be unique + * @param maxBatchSize The maximum number of events to process per batch + * @param transactionTimeout Time in seconds after which a transaction must be rolled back + * without waiting for an ACK from Spark + * @param parent The parent [[SparkAvroCallbackHandler]] instance, for reporting timeouts + */ +private class TransactionProcessor(val channel: Channel, val seqNum: String, + var maxBatchSize: Int, val transactionTimeout: Int, val backOffInterval: Int, + val parent: SparkAvroCallbackHandler) extends Callable[Void] with Logging { + + // If a real batch is not returned, we always have to return an error batch. + @volatile private var eventBatch: EventBatch = new EventBatch("Unknown Error", "", + util.Collections.emptyList()) + + // Synchronization primitives + val batchGeneratedLatch = new CountDownLatch(1) + val batchAckLatch = new CountDownLatch(1) + + // Sanity check to ensure we don't loop like crazy + val totalAttemptsToRemoveFromChannel = Int.MaxValue / 2 + + // OK to use volatile, since the change would only make this true (otherwise it will be + // changed to false - we never apply a negation operation to this) - which means the transaction + // succeeded. + @volatile private var batchSuccess = false + + // The transaction that this processor would handle + var txOpt: Option[Transaction] = None + + /** + * Get an event batch from the channel. This method will block until a batch of events is + * available from the channel. If no events are available after a large number of attempts of + * polling the channel, this method will return an [[EventBatch]] with a non-empty error message + * + * @return An [[EventBatch]] instance with sequence number set to seqNum, filled with a + * maximum of maxBatchSize events + */ + def getEventBatch: EventBatch = { + batchGeneratedLatch.await() + eventBatch + } + + /** + * This method is to be called by the sink when it receives an ACK or NACK from Spark. This + * method is a no-op if it is called after transactionTimeout has expired since + * getEventBatch returned a batch of events. + * @param success True if an ACK was received and the transaction should be committed, else false. + */ + def batchProcessed(success: Boolean) { + logDebug("Batch processed for sequence number: " + seqNum) + batchSuccess = success + batchAckLatch.countDown() + } + + /** + * Populates events into the event batch. If the batch cannot be populated, + * this method will not set the events into the event batch, but it sets an error message. + */ + private def populateEvents() { + try { + txOpt = Option(channel.getTransaction) + if(txOpt.isEmpty) { + eventBatch.setErrorMsg("Something went wrong. Channel was " + + "unable to create a transaction!") + } + txOpt.foreach(tx => { + tx.begin() + val events = new util.ArrayList[SparkSinkEvent](maxBatchSize) + val loop = new Breaks + var gotEventsInThisTxn = false + var loopCounter: Int = 0 + loop.breakable { + while (events.size() < maxBatchSize + && loopCounter < totalAttemptsToRemoveFromChannel) { + loopCounter += 1 + Option(channel.take()) match { + case Some(event) => + events.add(new SparkSinkEvent(toCharSequenceMap(event.getHeaders), + ByteBuffer.wrap(event.getBody))) + gotEventsInThisTxn = true + case None => + if (!gotEventsInThisTxn) { + logDebug("Sleeping for " + backOffInterval + " millis as no events were read in" + + " the current transaction") + TimeUnit.MILLISECONDS.sleep(backOffInterval) + } else { + loop.break() + } + } + } + } + if (!gotEventsInThisTxn) { + val msg = "Tried several times, " + + "but did not get any events from the channel!" + logWarning(msg) + eventBatch.setErrorMsg(msg) + } else { + // At this point, the events are available, so fill them into the event batch + eventBatch = new EventBatch("",seqNum, events) + } + }) + } catch { + case e: Exception => + logWarning("Error while processing transaction.", e) + eventBatch.setErrorMsg(e.getMessage) + try { + txOpt.foreach(tx => { + rollbackAndClose(tx, close = true) + }) + } finally { + txOpt = None + } + } finally { + batchGeneratedLatch.countDown() + } + } + + /** + * Waits for upto transactionTimeout seconds for an ACK. If an ACK comes in + * this method commits the transaction with the channel. If the ACK does not come in within + * that time or a NACK comes in, this method rolls back the transaction. + */ + private def processAckOrNack() { + batchAckLatch.await(transactionTimeout, TimeUnit.SECONDS) + txOpt.foreach(tx => { + if (batchSuccess) { + try { + logDebug("Committing transaction") + tx.commit() + } catch { + case e: Exception => + logWarning("Error while attempting to commit transaction. Transaction will be rolled " + + "back", e) + rollbackAndClose(tx, close = false) // tx will be closed later anyway + } finally { + tx.close() + } + } else { + logWarning("Spark could not commit transaction, NACK received. Rolling back transaction.") + rollbackAndClose(tx, close = true) + // This might have been due to timeout or a NACK. Either way the following call does not + // cause issues. This is required to ensure the TransactionProcessor instance is not leaked + parent.removeAndGetProcessor(seqNum) + } + }) + } + + /** + * Helper method to rollback and optionally close a transaction + * @param tx The transaction to rollback + * @param close Whether the transaction should be closed or not after rolling back + */ + private def rollbackAndClose(tx: Transaction, close: Boolean) { + try { + logWarning("Spark was unable to successfully process the events. Transaction is being " + + "rolled back.") + tx.rollback() + } catch { + case e: Exception => + logError("Error rolling back transaction. Rollback may have failed!", e) + } finally { + if (close) { + tx.close() + } + } + } + + /** + * Helper method to convert a Map[String, String] to Map[CharSequence, CharSequence] + * @param inMap The map to be converted + * @return The converted map + */ + private def toCharSequenceMap(inMap: java.util.Map[String, String]): java.util.Map[CharSequence, + CharSequence] = { + val charSeqMap = new util.HashMap[CharSequence, CharSequence](inMap.size()) + charSeqMap.putAll(inMap) + charSeqMap + } + + /** + * When the thread is started it sets as many events as the batch size or less (if enough + * events aren't available) into the eventBatch and object and lets any threads waiting on the + * [[getEventBatch]] method to proceed. Then this thread waits for acks or nacks to come in, + * or for a specified timeout and commits or rolls back the transaction. + * @return + */ + override def call(): Void = { + populateEvents() + processAckOrNack() + null + } +} diff --git a/external/flume/pom.xml b/external/flume/pom.xml index 874b8a7959bb6..9f680b27c3308 100644 --- a/external/flume/pom.xml +++ b/external/flume/pom.xml @@ -77,6 +77,11 @@ junit-interface test + + org.apache.spark + spark-streaming-flume-sink_2.10 + ${project.version} + target/scala-${scala.binary.version}/classes diff --git a/external/flume/src/main/scala/org/apache/spark/streaming/flume/EventTransformer.scala b/external/flume/src/main/scala/org/apache/spark/streaming/flume/EventTransformer.scala new file mode 100644 index 0000000000000..dc629df4f4ac2 --- /dev/null +++ b/external/flume/src/main/scala/org/apache/spark/streaming/flume/EventTransformer.scala @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.flume + +import java.io.{ObjectOutput, ObjectInput} + +import scala.collection.JavaConversions._ + +import org.apache.spark.util.Utils +import org.apache.spark.Logging + +/** + * A simple object that provides the implementation of readExternal and writeExternal for both + * the wrapper classes for Flume-style Events. + */ +private[streaming] object EventTransformer extends Logging { + def readExternal(in: ObjectInput): (java.util.HashMap[CharSequence, CharSequence], + Array[Byte]) = { + val bodyLength = in.readInt() + val bodyBuff = new Array[Byte](bodyLength) + in.readFully(bodyBuff) + + val numHeaders = in.readInt() + val headers = new java.util.HashMap[CharSequence, CharSequence] + + for (i <- 0 until numHeaders) { + val keyLength = in.readInt() + val keyBuff = new Array[Byte](keyLength) + in.readFully(keyBuff) + val key: String = Utils.deserialize(keyBuff) + + val valLength = in.readInt() + val valBuff = new Array[Byte](valLength) + in.readFully(valBuff) + val value: String = Utils.deserialize(valBuff) + + headers.put(key, value) + } + (headers, bodyBuff) + } + + def writeExternal(out: ObjectOutput, headers: java.util.Map[CharSequence, CharSequence], + body: Array[Byte]) { + out.writeInt(body.length) + out.write(body) + val numHeaders = headers.size() + out.writeInt(numHeaders) + for ((k,v) <- headers) { + val keyBuff = Utils.serialize(k.toString) + out.writeInt(keyBuff.length) + out.write(keyBuff) + val valBuff = Utils.serialize(v.toString) + out.writeInt(valBuff.length) + out.write(valBuff) + } + } +} diff --git a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala index 56d2886b26878..4b2ea45fb81d0 100644 --- a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala +++ b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala @@ -39,11 +39,8 @@ import org.apache.spark.streaming.receiver.Receiver import org.jboss.netty.channel.ChannelPipelineFactory import org.jboss.netty.channel.Channels -import org.jboss.netty.channel.ChannelPipeline -import org.jboss.netty.channel.ChannelFactory import org.jboss.netty.channel.socket.nio.NioServerSocketChannelFactory import org.jboss.netty.handler.codec.compression._ -import org.jboss.netty.handler.execution.ExecutionHandler private[streaming] class FlumeInputDStream[T: ClassTag]( diff --git a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumePollingInputDStream.scala b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumePollingInputDStream.scala new file mode 100644 index 0000000000000..148262bb6771e --- /dev/null +++ b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumePollingInputDStream.scala @@ -0,0 +1,178 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.streaming.flume + + +import java.net.InetSocketAddress +import java.util.concurrent.{LinkedBlockingQueue, TimeUnit, Executors} + +import scala.collection.JavaConversions._ +import scala.collection.mutable.ArrayBuffer +import scala.reflect.ClassTag + +import com.google.common.util.concurrent.ThreadFactoryBuilder +import org.apache.avro.ipc.NettyTransceiver +import org.apache.avro.ipc.specific.SpecificRequestor +import org.jboss.netty.channel.socket.nio.NioClientSocketChannelFactory + +import org.apache.spark.Logging +import org.apache.spark.storage.StorageLevel +import org.apache.spark.streaming.StreamingContext +import org.apache.spark.streaming.dstream.ReceiverInputDStream +import org.apache.spark.streaming.receiver.Receiver +import org.apache.spark.streaming.flume.sink._ + +/** + * A [[ReceiverInputDStream]] that can be used to read data from several Flume agents running + * [[org.apache.spark.streaming.flume.sink.SparkSink]]s. + * @param _ssc Streaming context that will execute this input stream + * @param addresses List of addresses at which SparkSinks are listening + * @param maxBatchSize Maximum size of a batch + * @param parallelism Number of parallel connections to open + * @param storageLevel The storage level to use. + * @tparam T Class type of the object of this stream + */ +private[streaming] class FlumePollingInputDStream[T: ClassTag]( + @transient _ssc: StreamingContext, + val addresses: Seq[InetSocketAddress], + val maxBatchSize: Int, + val parallelism: Int, + storageLevel: StorageLevel + ) extends ReceiverInputDStream[SparkFlumeEvent](_ssc) { + + override def getReceiver(): Receiver[SparkFlumeEvent] = { + new FlumePollingReceiver(addresses, maxBatchSize, parallelism, storageLevel) + } +} + +private[streaming] class FlumePollingReceiver( + addresses: Seq[InetSocketAddress], + maxBatchSize: Int, + parallelism: Int, + storageLevel: StorageLevel + ) extends Receiver[SparkFlumeEvent](storageLevel) with Logging { + + lazy val channelFactoryExecutor = + Executors.newCachedThreadPool(new ThreadFactoryBuilder().setDaemon(true). + setNameFormat("Flume Receiver Channel Thread - %d").build()) + + lazy val channelFactory = + new NioClientSocketChannelFactory(channelFactoryExecutor, channelFactoryExecutor) + + lazy val receiverExecutor = Executors.newFixedThreadPool(parallelism, + new ThreadFactoryBuilder().setDaemon(true).setNameFormat("Flume Receiver Thread - %d").build()) + + private lazy val connections = new LinkedBlockingQueue[FlumeConnection]() + + override def onStart(): Unit = { + // Create the connections to each Flume agent. + addresses.foreach(host => { + val transceiver = new NettyTransceiver(host, channelFactory) + val client = SpecificRequestor.getClient(classOf[SparkFlumeProtocol.Callback], transceiver) + connections.add(new FlumeConnection(transceiver, client)) + }) + for (i <- 0 until parallelism) { + logInfo("Starting Flume Polling Receiver worker threads starting..") + // Threads that pull data from Flume. + receiverExecutor.submit(new Runnable { + override def run(): Unit = { + while (true) { + val connection = connections.poll() + val client = connection.client + try { + val eventBatch = client.getEventBatch(maxBatchSize) + if (!SparkSinkUtils.isErrorBatch(eventBatch)) { + // No error, proceed with processing data + val seq = eventBatch.getSequenceNumber + val events: java.util.List[SparkSinkEvent] = eventBatch.getEvents + logDebug( + "Received batch of " + events.size() + " events with sequence number: " + seq) + try { + // Convert each Flume event to a serializable SparkFlumeEvent + val buffer = new ArrayBuffer[SparkFlumeEvent](events.size()) + var j = 0 + while (j < events.size()) { + buffer += toSparkFlumeEvent(events(j)) + j += 1 + } + store(buffer) + logDebug("Sending ack for sequence number: " + seq) + // Send an ack to Flume so that Flume discards the events from its channels. + client.ack(seq) + logDebug("Ack sent for sequence number: " + seq) + } catch { + case e: Exception => + try { + // Let Flume know that the events need to be pushed back into the channel. + logDebug("Sending nack for sequence number: " + seq) + client.nack(seq) // If the agent is down, even this could fail and throw + logDebug("Nack sent for sequence number: " + seq) + } catch { + case e: Exception => logError( + "Sending Nack also failed. A Flume agent is down.") + } + TimeUnit.SECONDS.sleep(2L) // for now just leave this as a fixed 2 seconds. + logWarning("Error while attempting to store events", e) + } + } else { + logWarning("Did not receive events from Flume agent due to error on the Flume " + + "agent: " + eventBatch.getErrorMsg) + } + } catch { + case e: Exception => + logWarning("Error while reading data from Flume", e) + } finally { + connections.add(connection) + } + } + } + }) + } + } + + override def onStop(): Unit = { + logInfo("Shutting down Flume Polling Receiver") + receiverExecutor.shutdownNow() + connections.foreach(connection => { + connection.transceiver.close() + }) + channelFactory.releaseExternalResources() + } + + /** + * Utility method to convert [[SparkSinkEvent]] to [[SparkFlumeEvent]] + * @param event - Event to convert to SparkFlumeEvent + * @return - The SparkFlumeEvent generated from SparkSinkEvent + */ + private def toSparkFlumeEvent(event: SparkSinkEvent): SparkFlumeEvent = { + val sparkFlumeEvent = new SparkFlumeEvent() + sparkFlumeEvent.event.setBody(event.getBody) + sparkFlumeEvent.event.setHeaders(event.getHeaders) + sparkFlumeEvent + } +} + +/** + * A wrapper around the transceiver and the Avro IPC API. + * @param transceiver The transceiver to use for communication with Flume + * @param client The client that the callbacks are received on. + */ +private class FlumeConnection(val transceiver: NettyTransceiver, + val client: SparkFlumeProtocol.Callback) + + + diff --git a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeUtils.scala b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeUtils.scala index 716db9fa76031..4b732c1592ab2 100644 --- a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeUtils.scala +++ b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeUtils.scala @@ -17,12 +17,19 @@ package org.apache.spark.streaming.flume +import java.net.InetSocketAddress + +import org.apache.spark.annotation.Experimental import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.StreamingContext -import org.apache.spark.streaming.api.java.{JavaReceiverInputDStream, JavaInputDStream, JavaStreamingContext, JavaDStream} -import org.apache.spark.streaming.dstream.{ReceiverInputDStream, DStream} +import org.apache.spark.streaming.api.java.{JavaReceiverInputDStream, JavaStreamingContext} +import org.apache.spark.streaming.dstream.ReceiverInputDStream + object FlumeUtils { + private val DEFAULT_POLLING_PARALLELISM = 5 + private val DEFAULT_POLLING_BATCH_SIZE = 1000 + /** * Create a input stream from a Flume source. * @param ssc StreamingContext object @@ -56,7 +63,7 @@ object FlumeUtils { ): ReceiverInputDStream[SparkFlumeEvent] = { val inputStream = new FlumeInputDStream[SparkFlumeEvent]( ssc, hostname, port, storageLevel, enableDecompression) - + inputStream } @@ -105,4 +112,135 @@ object FlumeUtils { ): JavaReceiverInputDStream[SparkFlumeEvent] = { createStream(jssc.ssc, hostname, port, storageLevel, enableDecompression) } + + /** + * Creates an input stream that is to be used with the Spark Sink deployed on a Flume agent. + * This stream will poll the sink for data and will pull events as they are available. + * This stream will use a batch size of 1000 events and run 5 threads to pull data. + * @param hostname Address of the host on which the Spark Sink is running + * @param port Port of the host at which the Spark Sink is listening + * @param storageLevel Storage level to use for storing the received objects + */ + @Experimental + def createPollingStream( + ssc: StreamingContext, + hostname: String, + port: Int, + storageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK_SER_2 + ): ReceiverInputDStream[SparkFlumeEvent] = { + createPollingStream(ssc, Seq(new InetSocketAddress(hostname, port)), storageLevel) + } + + /** + * Creates an input stream that is to be used with the Spark Sink deployed on a Flume agent. + * This stream will poll the sink for data and will pull events as they are available. + * This stream will use a batch size of 1000 events and run 5 threads to pull data. + * @param addresses List of InetSocketAddresses representing the hosts to connect to. + * @param storageLevel Storage level to use for storing the received objects + */ + @Experimental + def createPollingStream( + ssc: StreamingContext, + addresses: Seq[InetSocketAddress], + storageLevel: StorageLevel + ): ReceiverInputDStream[SparkFlumeEvent] = { + createPollingStream(ssc, addresses, storageLevel, + DEFAULT_POLLING_BATCH_SIZE, DEFAULT_POLLING_PARALLELISM) + } + + /** + * Creates an input stream that is to be used with the Spark Sink deployed on a Flume agent. + * This stream will poll the sink for data and will pull events as they are available. + * @param addresses List of InetSocketAddresses representing the hosts to connect to. + * @param maxBatchSize Maximum number of events to be pulled from the Spark sink in a + * single RPC call + * @param parallelism Number of concurrent requests this stream should send to the sink. Note + * that having a higher number of requests concurrently being pulled will + * result in this stream using more threads + * @param storageLevel Storage level to use for storing the received objects + */ + @Experimental + def createPollingStream( + ssc: StreamingContext, + addresses: Seq[InetSocketAddress], + storageLevel: StorageLevel, + maxBatchSize: Int, + parallelism: Int + ): ReceiverInputDStream[SparkFlumeEvent] = { + new FlumePollingInputDStream[SparkFlumeEvent](ssc, addresses, maxBatchSize, + parallelism, storageLevel) + } + + /** + * Creates an input stream that is to be used with the Spark Sink deployed on a Flume agent. + * This stream will poll the sink for data and will pull events as they are available. + * This stream will use a batch size of 1000 events and run 5 threads to pull data. + * @param hostname Hostname of the host on which the Spark Sink is running + * @param port Port of the host at which the Spark Sink is listening + */ + @Experimental + def createPollingStream( + jssc: JavaStreamingContext, + hostname: String, + port: Int + ): JavaReceiverInputDStream[SparkFlumeEvent] = { + createPollingStream(jssc, hostname, port, StorageLevel.MEMORY_AND_DISK_SER_2) + } + + /** + * Creates an input stream that is to be used with the Spark Sink deployed on a Flume agent. + * This stream will poll the sink for data and will pull events as they are available. + * This stream will use a batch size of 1000 events and run 5 threads to pull data. + * @param hostname Hostname of the host on which the Spark Sink is running + * @param port Port of the host at which the Spark Sink is listening + * @param storageLevel Storage level to use for storing the received objects + */ + @Experimental + def createPollingStream( + jssc: JavaStreamingContext, + hostname: String, + port: Int, + storageLevel: StorageLevel + ): JavaReceiverInputDStream[SparkFlumeEvent] = { + createPollingStream(jssc, Array(new InetSocketAddress(hostname, port)), storageLevel) + } + + /** + * Creates an input stream that is to be used with the Spark Sink deployed on a Flume agent. + * This stream will poll the sink for data and will pull events as they are available. + * This stream will use a batch size of 1000 events and run 5 threads to pull data. + * @param addresses List of InetSocketAddresses on which the Spark Sink is running. + * @param storageLevel Storage level to use for storing the received objects + */ + @Experimental + def createPollingStream( + jssc: JavaStreamingContext, + addresses: Array[InetSocketAddress], + storageLevel: StorageLevel + ): JavaReceiverInputDStream[SparkFlumeEvent] = { + createPollingStream(jssc, addresses, storageLevel, + DEFAULT_POLLING_BATCH_SIZE, DEFAULT_POLLING_PARALLELISM) + } + + /** + * Creates an input stream that is to be used with the Spark Sink deployed on a Flume agent. + * This stream will poll the sink for data and will pull events as they are available. + * @param addresses List of InetSocketAddresses on which the Spark Sink is running + * @param maxBatchSize The maximum number of events to be pulled from the Spark sink in a + * single RPC call + * @param parallelism Number of concurrent requests this stream should send to the sink. Note + * that having a higher number of requests concurrently being pulled will + * result in this stream using more threads + * @param storageLevel Storage level to use for storing the received objects + */ + @Experimental + def createPollingStream( + jssc: JavaStreamingContext, + addresses: Array[InetSocketAddress], + storageLevel: StorageLevel, + maxBatchSize: Int, + parallelism: Int + ): JavaReceiverInputDStream[SparkFlumeEvent] = { + createPollingStream(jssc.ssc, addresses, storageLevel, maxBatchSize, parallelism) + } } diff --git a/external/flume/src/test/java/org/apache/spark/streaming/flume/JavaFlumePollingStreamSuite.java b/external/flume/src/test/java/org/apache/spark/streaming/flume/JavaFlumePollingStreamSuite.java new file mode 100644 index 0000000000000..79c5b91654b42 --- /dev/null +++ b/external/flume/src/test/java/org/apache/spark/streaming/flume/JavaFlumePollingStreamSuite.java @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.flume; + +import java.net.InetSocketAddress; + +import org.apache.spark.storage.StorageLevel; +import org.apache.spark.streaming.LocalJavaStreamingContext; + +import org.apache.spark.streaming.api.java.JavaReceiverInputDStream; +import org.junit.Test; + +public class JavaFlumePollingStreamSuite extends LocalJavaStreamingContext { + @Test + public void testFlumeStream() { + // tests the API, does not actually test data receiving + InetSocketAddress[] addresses = new InetSocketAddress[] { + new InetSocketAddress("localhost", 12345) + }; + JavaReceiverInputDStream test1 = + FlumeUtils.createPollingStream(ssc, "localhost", 12345); + JavaReceiverInputDStream test2 = FlumeUtils.createPollingStream( + ssc, "localhost", 12345, StorageLevel.MEMORY_AND_DISK_SER_2()); + JavaReceiverInputDStream test3 = FlumeUtils.createPollingStream( + ssc, addresses, StorageLevel.MEMORY_AND_DISK_SER_2()); + JavaReceiverInputDStream test4 = FlumeUtils.createPollingStream( + ssc, addresses, StorageLevel.MEMORY_AND_DISK_SER_2(), 100, 5); + } +} diff --git a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala new file mode 100644 index 0000000000000..47071d0cc4714 --- /dev/null +++ b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala @@ -0,0 +1,195 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.spark.streaming.flume + +import java.net.InetSocketAddress +import java.util.concurrent.{Callable, ExecutorCompletionService, Executors} + +import scala.collection.JavaConversions._ +import scala.collection.mutable.{SynchronizedBuffer, ArrayBuffer} + +import org.apache.flume.Context +import org.apache.flume.channel.MemoryChannel +import org.apache.flume.conf.Configurables +import org.apache.flume.event.EventBuilder + +import org.apache.spark.storage.StorageLevel +import org.apache.spark.streaming.dstream.ReceiverInputDStream +import org.apache.spark.streaming.util.ManualClock +import org.apache.spark.streaming.{TestSuiteBase, TestOutputStream, StreamingContext} +import org.apache.spark.streaming.flume.sink._ + +class FlumePollingStreamSuite extends TestSuiteBase { + + val testPort = 9999 + val batchCount = 5 + val eventsPerBatch = 100 + val totalEventsPerChannel = batchCount * eventsPerBatch + val channelCapacity = 5000 + + test("flume polling test") { + // Set up the streaming context and input streams + val ssc = new StreamingContext(conf, batchDuration) + val flumeStream: ReceiverInputDStream[SparkFlumeEvent] = + FlumeUtils.createPollingStream(ssc, Seq(new InetSocketAddress("localhost", testPort)), + StorageLevel.MEMORY_AND_DISK, eventsPerBatch, 1) + val outputBuffer = new ArrayBuffer[Seq[SparkFlumeEvent]] + with SynchronizedBuffer[Seq[SparkFlumeEvent]] + val outputStream = new TestOutputStream(flumeStream, outputBuffer) + outputStream.register() + + // Start the channel and sink. + val context = new Context() + context.put("capacity", channelCapacity.toString) + context.put("transactionCapacity", "1000") + context.put("keep-alive", "0") + val channel = new MemoryChannel() + Configurables.configure(channel, context) + + val sink = new SparkSink() + context.put(SparkSinkConfig.CONF_HOSTNAME, "localhost") + context.put(SparkSinkConfig.CONF_PORT, String.valueOf(testPort)) + Configurables.configure(sink, context) + sink.setChannel(channel) + sink.start() + ssc.start() + + writeAndVerify(Seq(channel), ssc, outputBuffer) + assertChannelIsEmpty(channel) + sink.stop() + channel.stop() + } + + test("flume polling test multiple hosts") { + // Set up the streaming context and input streams + val ssc = new StreamingContext(conf, batchDuration) + val addresses = Seq(testPort, testPort + 1).map(new InetSocketAddress("localhost", _)) + val flumeStream: ReceiverInputDStream[SparkFlumeEvent] = + FlumeUtils.createPollingStream(ssc, addresses, StorageLevel.MEMORY_AND_DISK, + eventsPerBatch, 5) + val outputBuffer = new ArrayBuffer[Seq[SparkFlumeEvent]] + with SynchronizedBuffer[Seq[SparkFlumeEvent]] + val outputStream = new TestOutputStream(flumeStream, outputBuffer) + outputStream.register() + + // Start the channel and sink. + val context = new Context() + context.put("capacity", channelCapacity.toString) + context.put("transactionCapacity", "1000") + context.put("keep-alive", "0") + val channel = new MemoryChannel() + Configurables.configure(channel, context) + + val channel2 = new MemoryChannel() + Configurables.configure(channel2, context) + + val sink = new SparkSink() + context.put(SparkSinkConfig.CONF_HOSTNAME, "localhost") + context.put(SparkSinkConfig.CONF_PORT, String.valueOf(testPort)) + Configurables.configure(sink, context) + sink.setChannel(channel) + sink.start() + + val sink2 = new SparkSink() + context.put(SparkSinkConfig.CONF_HOSTNAME, "localhost") + context.put(SparkSinkConfig.CONF_PORT, String.valueOf(testPort + 1)) + Configurables.configure(sink2, context) + sink2.setChannel(channel2) + sink2.start() + ssc.start() + writeAndVerify(Seq(channel, channel2), ssc, outputBuffer) + assertChannelIsEmpty(channel) + assertChannelIsEmpty(channel2) + sink.stop() + channel.stop() + } + + def writeAndVerify(channels: Seq[MemoryChannel], ssc: StreamingContext, + outputBuffer: ArrayBuffer[Seq[SparkFlumeEvent]]) { + val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] + val executor = Executors.newCachedThreadPool() + val executorCompletion = new ExecutorCompletionService[Void](executor) + channels.map(channel => { + executorCompletion.submit(new TxnSubmitter(channel, clock)) + }) + for (i <- 0 until channels.size) { + executorCompletion.take() + } + val startTime = System.currentTimeMillis() + while (outputBuffer.size < batchCount * channels.size && + System.currentTimeMillis() - startTime < 15000) { + logInfo("output.size = " + outputBuffer.size) + Thread.sleep(100) + } + val timeTaken = System.currentTimeMillis() - startTime + assert(timeTaken < 15000, "Operation timed out after " + timeTaken + " ms") + logInfo("Stopping context") + ssc.stop() + + val flattenedBuffer = outputBuffer.flatten + assert(flattenedBuffer.size === totalEventsPerChannel * channels.size) + var counter = 0 + for (k <- 0 until channels.size; i <- 0 until totalEventsPerChannel) { + val eventToVerify = EventBuilder.withBody((channels(k).getName + " - " + + String.valueOf(i)).getBytes("utf-8"), + Map[String, String]("test-" + i.toString -> "header")) + var found = false + var j = 0 + while (j < flattenedBuffer.size && !found) { + val strToCompare = new String(flattenedBuffer(j).event.getBody.array(), "utf-8") + if (new String(eventToVerify.getBody, "utf-8") == strToCompare && + eventToVerify.getHeaders.get("test-" + i.toString) + .equals(flattenedBuffer(j).event.getHeaders.get("test-" + i.toString))) { + found = true + counter += 1 + } + j += 1 + } + } + assert(counter === totalEventsPerChannel * channels.size) + } + + def assertChannelIsEmpty(channel: MemoryChannel) = { + val queueRemaining = channel.getClass.getDeclaredField("queueRemaining"); + queueRemaining.setAccessible(true) + val m = queueRemaining.get(channel).getClass.getDeclaredMethod("availablePermits") + assert(m.invoke(queueRemaining.get(channel)).asInstanceOf[Int] === 5000) + } + + private class TxnSubmitter(channel: MemoryChannel, clock: ManualClock) extends Callable[Void] { + override def call(): Void = { + var t = 0 + for (i <- 0 until batchCount) { + val tx = channel.getTransaction + tx.begin() + for (j <- 0 until eventsPerBatch) { + channel.put(EventBuilder.withBody((channel.getName + " - " + String.valueOf(t)).getBytes( + "utf-8"), + Map[String, String]("test-" + t.toString -> "header"))) + t += 1 + } + tx.commit() + tx.close() + Thread.sleep(500) // Allow some time for the events to reach + clock.addToTime(batchDuration.milliseconds) + } + null + } + } +} diff --git a/pom.xml b/pom.xml index 93ef3b91b5bce..8b1435cfe5d19 100644 --- a/pom.xml +++ b/pom.xml @@ -100,6 +100,7 @@ external/twitter external/kafka external/flume + external/flume-sink external/zeromq external/mqtt examples diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 1629bc2cba8ba..0a6326e72297a 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -30,11 +30,12 @@ object BuildCommons { private val buildLocation = file(".").getAbsoluteFile.getParentFile - val allProjects@Seq(bagel, catalyst, core, graphx, hive, hiveThriftServer, mllib, repl, spark, sql, - streaming, streamingFlume, streamingKafka, streamingMqtt, streamingTwitter, streamingZeromq) = + val allProjects@Seq(bagel, catalyst, core, graphx, hive, hiveThriftServer, mllib, repl, spark, + sql, streaming, streamingFlumeSink, streamingFlume, streamingKafka, streamingMqtt, + streamingTwitter, streamingZeromq) = Seq("bagel", "catalyst", "core", "graphx", "hive", "hive-thriftserver", "mllib", "repl", - "spark", "sql", "streaming", "streaming-flume", "streaming-kafka", "streaming-mqtt", - "streaming-twitter", "streaming-zeromq").map(ProjectRef(buildLocation, _)) + "spark", "sql", "streaming", "streaming-flume-sink", "streaming-flume", "streaming-kafka", + "streaming-mqtt", "streaming-twitter", "streaming-zeromq").map(ProjectRef(buildLocation, _)) val optionallyEnabledProjects@Seq(yarn, yarnStable, yarnAlpha, java8Tests, sparkGangliaLgpl) = Seq("yarn", "yarn-stable", "yarn-alpha", "java8-tests", "ganglia-lgpl") @@ -156,10 +157,9 @@ object SparkBuild extends PomBuild { /* Enable tests settings for all projects except examples, assembly and tools */ (allProjects ++ optionallyEnabledProjects).foreach(enable(TestSettings.settings)) - /* Enable Mima for all projects except spark, hive, catalyst, sql and repl */ // TODO: Add Sql to mima checks - allProjects.filterNot(x => Seq(spark, sql, hive, hiveThriftServer, catalyst, repl).contains(x)). - foreach (x => enable(MimaBuild.mimaSettings(sparkHome, x))(x)) + allProjects.filterNot(x => Seq(spark, sql, hive, hiveThriftServer, catalyst, repl, + streamingFlumeSink).contains(x)).foreach(x => enable(MimaBuild.mimaSettings(sparkHome, x))(x)) /* Enable Assembly for all assembly projects */ assemblyProjects.foreach(enable(Assembly.settings)) @@ -173,6 +173,8 @@ object SparkBuild extends PomBuild { /* Hive console settings */ enable(Hive.settings)(hive) + enable(Flume.settings)(streamingFlumeSink) + // TODO: move this to its upstream project. override def projectDefinitions(baseDirectory: File): Seq[Project] = { super.projectDefinitions(baseDirectory).map { x => @@ -183,6 +185,10 @@ object SparkBuild extends PomBuild { } +object Flume { + lazy val settings = sbtavro.SbtAvro.avroSettings +} + object SQL { lazy val settings = Seq( diff --git a/project/plugins.sbt b/project/plugins.sbt index d3ac4bf335e87..06d18e193076e 100644 --- a/project/plugins.sbt +++ b/project/plugins.sbt @@ -24,3 +24,5 @@ addSbtPlugin("com.typesafe" % "sbt-mima-plugin" % "0.1.6") addSbtPlugin("com.alpinenow" % "junit_xml_listener" % "0.5.1") addSbtPlugin("com.eed3si9n" % "sbt-unidoc" % "0.3.0") + +addSbtPlugin("com.cavorite" % "sbt-avro" % "0.3.2") From 0c5c6a63d19bed2a813a09309c46971ecdd173f0 Mon Sep 17 00:00:00 2001 From: Daoyuan Date: Tue, 29 Jul 2014 12:22:48 -0700 Subject: [PATCH 012/170] [SQL]change some test lists 1. there's no `hook_context.q` but a `hook_context_cs.q` in query folder 2. there's no `compute_stats_table.q` in query folder 3. there's no `having1.q` in query folder 4. `udf_E` and `udf_PI` appear twice in white list Author: Daoyuan Closes #1634 from adrian-wang/testcases and squashes the following commits: d7482ce [Daoyuan] change some test lists --- .../spark/sql/hive/execution/HiveCompatibilitySuite.scala | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala index c69e93ba2b9ba..4fef071161719 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala @@ -52,7 +52,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { override def blackList = Seq( // These tests use hooks that are not on the classpath and thus break all subsequent execution. "hook_order", - "hook_context", + "hook_context_cs", "mapjoin_hook", "multi_sahooks", "overridden_confs", @@ -289,7 +289,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "compute_stats_empty_table", "compute_stats_long", "compute_stats_string", - "compute_stats_table", "convert_enum_to_string", "correlationoptimizer1", "correlationoptimizer10", @@ -395,7 +394,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "groupby_sort_9", "groupby_sort_test_1", "having", - "having1", "implicit_cast1", "innerjoin", "inoutdriver", @@ -697,8 +695,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "udf7", "udf8", "udf9", - "udf_E", - "udf_PI", "udf_abs", "udf_acos", "udf_add", From e3643485de8fdaf5c52b266fead1b13214f29d5e Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Tue, 29 Jul 2014 12:23:34 -0700 Subject: [PATCH 013/170] [SPARK-2730][SQL] When retrieving a value from a Map, GetItem evaluates key twice JIRA: https://issues.apache.org/jira/browse/SPARK-2730 Author: Yin Huai Closes #1637 from yhuai/SPARK-2730 and squashes the following commits: 1a9f24e [Yin Huai] Remove unnecessary key evaluation. --- .../org/apache/spark/sql/catalyst/expressions/complexTypes.scala | 1 - 1 file changed, 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala index 5d3bb25ad568c..0acb29012f314 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala @@ -61,7 +61,6 @@ case class GetItem(child: Expression, ordinal: Expression) extends Expression { } } else { val baseValue = value.asInstanceOf[Map[Any, _]] - val key = ordinal.eval(input) baseValue.get(key).orNull } } From f0d880e288eba97c86dceb1b5edab4f3a935943b Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Tue, 29 Jul 2014 12:31:39 -0700 Subject: [PATCH 014/170] [SPARK-2674] [SQL] [PySpark] support datetime type for SchemaRDD Datetime and time in Python will be converted into java.util.Calendar after serialization, it will be converted into java.sql.Timestamp during inferSchema(). In javaToPython(), Timestamp will be converted into Calendar, then be converted into datetime in Python after pickling. Author: Davies Liu Closes #1601 from davies/date and squashes the following commits: f0599b0 [Davies Liu] remove tests for sets and tuple in sql, fix list of list c9d607a [Davies Liu] convert datetype for runtime 709d40d [Davies Liu] remove brackets 96db384 [Davies Liu] support datetime type for SchemaRDD --- .../apache/spark/api/python/PythonRDD.scala | 4 +- python/pyspark/sql.py | 22 +++++---- .../org/apache/spark/sql/SQLContext.scala | 40 ++++++++++++++-- .../org/apache/spark/sql/SchemaRDD.scala | 46 +++++++------------ 4 files changed, 68 insertions(+), 44 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index d87783efd2d01..0d8453fb184a3 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -550,11 +550,11 @@ private[spark] object PythonRDD extends Logging { def pythonToJavaMap(pyRDD: JavaRDD[Array[Byte]]): JavaRDD[Map[String, _]] = { pyRDD.rdd.mapPartitions { iter => val unpickle = new Unpickler - // TODO: Figure out why flatMap is necessay for pyspark iter.flatMap { row => unpickle.loads(row) match { + // in case of objects are pickled in batch mode case objs: java.util.ArrayList[JMap[String, _] @unchecked] => objs.map(_.toMap) - // Incase the partition doesn't have a collection + // not in batch mode case obj: JMap[String @unchecked, _] => Seq(obj.toMap) } } diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py index cb83e89176823..a6b3277db3266 100644 --- a/python/pyspark/sql.py +++ b/python/pyspark/sql.py @@ -47,12 +47,14 @@ def __init__(self, sparkContext, sqlContext=None): ... ValueError:... - >>> allTypes = sc.parallelize([{"int" : 1, "string" : "string", "double" : 1.0, "long": 1L, - ... "boolean" : True}]) + >>> from datetime import datetime + >>> allTypes = sc.parallelize([{"int": 1, "string": "string", "double": 1.0, "long": 1L, + ... "boolean": True, "time": datetime(2010, 1, 1, 1, 1, 1), "dict": {"a": 1}, + ... "list": [1, 2, 3]}]) >>> srdd = sqlCtx.inferSchema(allTypes).map(lambda x: (x.int, x.string, x.double, x.long, - ... x.boolean)) + ... x.boolean, x.time, x.dict["a"], x.list)) >>> srdd.collect()[0] - (1, u'string', 1.0, 1, True) + (1, u'string', 1.0, 1, True, datetime.datetime(2010, 1, 1, 1, 1, 1), 1, [1, 2, 3]) """ self._sc = sparkContext self._jsc = self._sc._jsc @@ -88,13 +90,13 @@ def inferSchema(self, rdd): >>> from array import array >>> srdd = sqlCtx.inferSchema(nestedRdd1) - >>> srdd.collect() == [{"f1" : array('i', [1, 2]), "f2" : {"row1" : 1.0}}, - ... {"f1" : array('i', [2, 3]), "f2" : {"row2" : 2.0}}] + >>> srdd.collect() == [{"f1" : [1, 2], "f2" : {"row1" : 1.0}}, + ... {"f1" : [2, 3], "f2" : {"row2" : 2.0}}] True >>> srdd = sqlCtx.inferSchema(nestedRdd2) - >>> srdd.collect() == [{"f1" : [[1, 2], [2, 3]], "f2" : set([1, 2]), "f3" : (1, 2)}, - ... {"f1" : [[2, 3], [3, 4]], "f2" : set([2, 3]), "f3" : (2, 3)}] + >>> srdd.collect() == [{"f1" : [[1, 2], [2, 3]], "f2" : [1, 2]}, + ... {"f1" : [[2, 3], [3, 4]], "f2" : [2, 3]}] True """ if (rdd.__class__ is SchemaRDD): @@ -509,8 +511,8 @@ def _test(): {"f1": array('i', [1, 2]), "f2": {"row1": 1.0}}, {"f1": array('i', [2, 3]), "f2": {"row2": 2.0}}]) globs['nestedRdd2'] = sc.parallelize([ - {"f1": [[1, 2], [2, 3]], "f2": set([1, 2]), "f3": (1, 2)}, - {"f1": [[2, 3], [3, 4]], "f2": set([2, 3]), "f3": (2, 3)}]) + {"f1": [[1, 2], [2, 3]], "f2": [1, 2]}, + {"f1": [[2, 3], [3, 4]], "f2": [2, 3]}]) (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) globs['sc'].stop() if failure_count: diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 4abd89955bd27..c178dad662532 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -352,8 +352,10 @@ class SQLContext(@transient val sparkContext: SparkContext) case c: java.lang.Long => LongType case c: java.lang.Double => DoubleType case c: java.lang.Boolean => BooleanType + case c: java.math.BigDecimal => DecimalType + case c: java.sql.Timestamp => TimestampType + case c: java.util.Calendar => TimestampType case c: java.util.List[_] => ArrayType(typeFor(c.head)) - case c: java.util.Set[_] => ArrayType(typeFor(c.head)) case c: java.util.Map[_, _] => val (key, value) = c.head MapType(typeFor(key), typeFor(value)) @@ -362,11 +364,43 @@ class SQLContext(@transient val sparkContext: SparkContext) ArrayType(typeFor(elem)) case c => throw new Exception(s"Object of type $c cannot be used") } - val schema = rdd.first().map { case (fieldName, obj) => + val firstRow = rdd.first() + val schema = firstRow.map { case (fieldName, obj) => AttributeReference(fieldName, typeFor(obj), true)() }.toSeq - val rowRdd = rdd.mapPartitions { iter => + def needTransform(obj: Any): Boolean = obj match { + case c: java.util.List[_] => true + case c: java.util.Map[_, _] => true + case c if c.getClass.isArray => true + case c: java.util.Calendar => true + case c => false + } + + // convert JList, JArray into Seq, convert JMap into Map + // convert Calendar into Timestamp + def transform(obj: Any): Any = obj match { + case c: java.util.List[_] => c.map(transform).toSeq + case c: java.util.Map[_, _] => c.map { + case (key, value) => (key, transform(value)) + }.toMap + case c if c.getClass.isArray => + c.asInstanceOf[Array[_]].map(transform).toSeq + case c: java.util.Calendar => + new java.sql.Timestamp(c.getTime().getTime()) + case c => c + } + + val need = firstRow.exists {case (key, value) => needTransform(value)} + val transformed = if (need) { + rdd.mapPartitions { iter => + iter.map { + m => m.map {case (key, value) => (key, transform(value))} + } + } + } else rdd + + val rowRdd = transformed.mapPartitions { iter => iter.map { map => new GenericRow(map.values.toArray.asInstanceOf[Array[Any]]): Row } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala index 31d27bb4f0571..019ff9d300a18 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala @@ -32,7 +32,7 @@ import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.{Inner, JoinType} -import org.apache.spark.sql.catalyst.types.{ArrayType, BooleanType, StructType} +import org.apache.spark.sql.catalyst.types.{DataType, ArrayType, BooleanType, StructType, MapType} import org.apache.spark.sql.execution.{ExistingRdd, SparkLogicalPlan} import org.apache.spark.api.java.JavaRDD @@ -376,39 +376,27 @@ class SchemaRDD( * Converts a JavaRDD to a PythonRDD. It is used by pyspark. */ private[sql] def javaToPython: JavaRDD[Array[Byte]] = { + def toJava(obj: Any, dataType: DataType): Any = dataType match { + case struct: StructType => rowToMap(obj.asInstanceOf[Row], struct) + case array: ArrayType => obj match { + case seq: Seq[Any] => seq.map(x => toJava(x, array.elementType)).asJava + case list: JList[_] => list.map(x => toJava(x, array.elementType)).asJava + case arr if arr != null && arr.getClass.isArray => + arr.asInstanceOf[Array[Any]].map(x => toJava(x, array.elementType)) + case other => other + } + case mt: MapType => obj.asInstanceOf[Map[_, _]].map { + case (k, v) => (k, toJava(v, mt.valueType)) // key should be primitive type + }.asJava + // Pyrolite can handle Timestamp + case other => obj + } def rowToMap(row: Row, structType: StructType): JMap[String, Any] = { val fields = structType.fields.map(field => (field.name, field.dataType)) val map: JMap[String, Any] = new java.util.HashMap row.zip(fields).foreach { - case (obj, (attrName, dataType)) => - dataType match { - case struct: StructType => map.put(attrName, rowToMap(obj.asInstanceOf[Row], struct)) - case array @ ArrayType(struct: StructType) => - val arrayValues = obj match { - case seq: Seq[Any] => - seq.map(element => rowToMap(element.asInstanceOf[Row], struct)).asJava - case list: JList[_] => - list.map(element => rowToMap(element.asInstanceOf[Row], struct)) - case set: JSet[_] => - set.map(element => rowToMap(element.asInstanceOf[Row], struct)) - case arr if arr != null && arr.getClass.isArray => - arr.asInstanceOf[Array[Any]].map { - element => rowToMap(element.asInstanceOf[Row], struct) - } - case other => other - } - map.put(attrName, arrayValues) - case array: ArrayType => { - val arrayValues = obj match { - case seq: Seq[Any] => seq.asJava - case other => other - } - map.put(attrName, arrayValues) - } - case other => map.put(attrName, obj) - } + case (obj, (attrName, dataType)) => map.put(attrName, toJava(obj, dataType)) } - map } From dc9653641f8806960d79652afa043c3fb84f25d2 Mon Sep 17 00:00:00 2001 From: Doris Xin Date: Tue, 29 Jul 2014 12:49:44 -0700 Subject: [PATCH 015/170] [SPARK-2082] stratified sampling in PairRDDFunctions that guarantees exact sample size Implemented stratified sampling that guarantees exact sample size using ScaRSR with two passes over the RDD for sampling without replacement and three passes for sampling with replacement. Author: Doris Xin Author: Xiangrui Meng Closes #1025 from dorx/stratified and squashes the following commits: 245439e [Doris Xin] moved minSamplingRate to getUpperBound eaf5771 [Doris Xin] bug fixes. 17a381b [Doris Xin] fixed a merge issue and a failed unit ea7d27f [Doris Xin] merge master b223529 [Xiangrui Meng] use approx bounds for poisson fix poisson mean for waitlisting add unit tests for Java b3013a4 [Xiangrui Meng] move math3 back to test scope eecee5f [Doris Xin] Merge branch 'master' into stratified f4c21f3 [Doris Xin] Reviewer comments a10e68d [Doris Xin] style fix a2bf756 [Doris Xin] Merge branch 'master' into stratified 680b677 [Doris Xin] use mapPartitionWithIndex instead 9884a9f [Doris Xin] style fix bbfb8c9 [Doris Xin] Merge branch 'master' into stratified ee9d260 [Doris Xin] addressed reviewer comments 6b5b10b [Doris Xin] Merge branch 'master' into stratified 254e03c [Doris Xin] minor fixes and Java API. 4ad516b [Doris Xin] remove unused imports from PairRDDFunctions bd9dc6e [Doris Xin] unit bug and style violation fixed 1fe1cff [Doris Xin] Changed fractionByKey to a map to enable arg check 944a10c [Doris Xin] [SPARK-2145] Add lower bound on sampling rate 0214a76 [Doris Xin] cleanUp 90d94c0 [Doris Xin] merge master 9e74ab5 [Doris Xin] Separated out most of the logic in sampleByKey 7327611 [Doris Xin] merge master 50581fc [Doris Xin] added a TODO for logging in python 46f6c8c [Doris Xin] fixed the NPE caused by closures being cleaned before being passed into the aggregate function 7e1a481 [Doris Xin] changed the permission on SamplingUtil 1d413ce [Doris Xin] fixed checkstyle issues 9ee94ee [Doris Xin] [SPARK-2082] stratified sampling in PairRDDFunctions that guarantees exact sample size e3fd6a6 [Doris Xin] Merge branch 'master' into takeSample 7cab53a [Doris Xin] fixed import bug in rdd.py ffea61a [Doris Xin] SPARK-1939: Refactor takeSample method in RDD 1441977 [Doris Xin] SPARK-1939 Refactor takeSample method in RDD to use ScaSRS --- .../apache/spark/api/java/JavaPairRDD.scala | 69 +++- .../apache/spark/rdd/PairRDDFunctions.scala | 54 ++- .../spark/util/random/SamplingUtils.scala | 74 +++- .../util/random/StratifiedSamplingUtils.scala | 316 ++++++++++++++++++ .../java/org/apache/spark/JavaAPISuite.java | 37 ++ .../spark/rdd/PairRDDFunctionsSuite.scala | 116 +++++++ pom.xml | 6 + 7 files changed, 656 insertions(+), 16 deletions(-) create mode 100644 core/src/main/scala/org/apache/spark/util/random/StratifiedSamplingUtils.scala diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala index 4f3081433a542..31bf8dced2638 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala @@ -17,7 +17,7 @@ package org.apache.spark.api.java -import java.util.{Comparator, List => JList} +import java.util.{Comparator, List => JList, Map => JMap} import java.lang.{Iterable => JIterable} import scala.collection.JavaConversions._ @@ -129,6 +129,73 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) def sample(withReplacement: Boolean, fraction: Double, seed: Long): JavaPairRDD[K, V] = new JavaPairRDD[K, V](rdd.sample(withReplacement, fraction, seed)) + /** + * Return a subset of this RDD sampled by key (via stratified sampling). + * + * Create a sample of this RDD using variable sampling rates for different keys as specified by + * `fractions`, a key to sampling rate map. + * + * If `exact` is set to false, create the sample via simple random sampling, with one pass + * over the RDD, to produce a sample of size that's approximately equal to the sum of + * math.ceil(numItems * samplingRate) over all key values; otherwise, use additional passes over + * the RDD to create a sample size that's exactly equal to the sum of + * math.ceil(numItems * samplingRate) over all key values. + */ + def sampleByKey(withReplacement: Boolean, + fractions: JMap[K, Double], + exact: Boolean, + seed: Long): JavaPairRDD[K, V] = + new JavaPairRDD[K, V](rdd.sampleByKey(withReplacement, fractions, exact, seed)) + + /** + * Return a subset of this RDD sampled by key (via stratified sampling). + * + * Create a sample of this RDD using variable sampling rates for different keys as specified by + * `fractions`, a key to sampling rate map. + * + * If `exact` is set to false, create the sample via simple random sampling, with one pass + * over the RDD, to produce a sample of size that's approximately equal to the sum of + * math.ceil(numItems * samplingRate) over all key values; otherwise, use additional passes over + * the RDD to create a sample size that's exactly equal to the sum of + * math.ceil(numItems * samplingRate) over all key values. + * + * Use Utils.random.nextLong as the default seed for the random number generator + */ + def sampleByKey(withReplacement: Boolean, + fractions: JMap[K, Double], + exact: Boolean): JavaPairRDD[K, V] = + sampleByKey(withReplacement, fractions, exact, Utils.random.nextLong) + + /** + * Return a subset of this RDD sampled by key (via stratified sampling). + * + * Create a sample of this RDD using variable sampling rates for different keys as specified by + * `fractions`, a key to sampling rate map. + * + * Produce a sample of size that's approximately equal to the sum of + * math.ceil(numItems * samplingRate) over all key values with one pass over the RDD via + * simple random sampling. + */ + def sampleByKey(withReplacement: Boolean, + fractions: JMap[K, Double], + seed: Long): JavaPairRDD[K, V] = + sampleByKey(withReplacement, fractions, false, seed) + + /** + * Return a subset of this RDD sampled by key (via stratified sampling). + * + * Create a sample of this RDD using variable sampling rates for different keys as specified by + * `fractions`, a key to sampling rate map. + * + * Produce a sample of size that's approximately equal to the sum of + * math.ceil(numItems * samplingRate) over all key values with one pass over the RDD via + * simple random sampling. + * + * Use Utils.random.nextLong as the default seed for the random number generator + */ + def sampleByKey(withReplacement: Boolean, fractions: JMap[K, Double]): JavaPairRDD[K, V] = + sampleByKey(withReplacement, fractions, false, Utils.random.nextLong) + /** * Return the union of this RDD and another one. Any identical elements will appear multiple * times (use `.distinct()` to eliminate them). diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala index c04d162a39616..1af4e5f0b6d08 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala @@ -19,12 +19,10 @@ package org.apache.spark.rdd import java.nio.ByteBuffer import java.text.SimpleDateFormat -import java.util.Date -import java.util.{HashMap => JHashMap} +import java.util.{Date, HashMap => JHashMap} +import scala.collection.{Map, mutable} import scala.collection.JavaConversions._ -import scala.collection.Map -import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import scala.reflect.ClassTag @@ -34,19 +32,19 @@ import org.apache.hadoop.fs.FileSystem import org.apache.hadoop.io.SequenceFile.CompressionType import org.apache.hadoop.io.compress.CompressionCodec import org.apache.hadoop.mapred.{FileOutputCommitter, FileOutputFormat, JobConf, OutputFormat} -import org.apache.hadoop.mapreduce.{OutputFormat => NewOutputFormat, Job => NewAPIHadoopJob, +import org.apache.hadoop.mapreduce.{Job => NewAPIHadoopJob, OutputFormat => NewOutputFormat, RecordWriter => NewRecordWriter, SparkHadoopMapReduceUtil} -import org.apache.hadoop.mapreduce.lib.output.{FileOutputFormat => NewFileOutputFormat} import org.apache.spark._ -import org.apache.spark.annotation.Experimental -import org.apache.spark.deploy.SparkHadoopUtil -import org.apache.spark.SparkHadoopWriter import org.apache.spark.Partitioner.defaultPartitioner import org.apache.spark.SparkContext._ +import org.apache.spark.annotation.Experimental +import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.partial.{BoundedDouble, PartialResult} import org.apache.spark.serializer.Serializer +import org.apache.spark.util.Utils import org.apache.spark.util.collection.CompactBuffer +import org.apache.spark.util.random.StratifiedSamplingUtils /** * Extra functions available on RDDs of (key, value) pairs through an implicit conversion. @@ -195,6 +193,41 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) foldByKey(zeroValue, defaultPartitioner(self))(func) } + /** + * Return a subset of this RDD sampled by key (via stratified sampling). + * + * Create a sample of this RDD using variable sampling rates for different keys as specified by + * `fractions`, a key to sampling rate map. + * + * If `exact` is set to false, create the sample via simple random sampling, with one pass + * over the RDD, to produce a sample of size that's approximately equal to the sum of + * math.ceil(numItems * samplingRate) over all key values; otherwise, use + * additional passes over the RDD to create a sample size that's exactly equal to the sum of + * math.ceil(numItems * samplingRate) over all key values with a 99.99% confidence. When sampling + * without replacement, we need one additional pass over the RDD to guarantee sample size; + * when sampling with replacement, we need two additional passes. + * + * @param withReplacement whether to sample with or without replacement + * @param fractions map of specific keys to sampling rates + * @param seed seed for the random number generator + * @param exact whether sample size needs to be exactly math.ceil(fraction * size) per key + * @return RDD containing the sampled subset + */ + def sampleByKey(withReplacement: Boolean, + fractions: Map[K, Double], + exact: Boolean = false, + seed: Long = Utils.random.nextLong): RDD[(K, V)]= { + + require(fractions.values.forall(v => v >= 0.0), "Negative sampling rates.") + + val samplingFunc = if (withReplacement) { + StratifiedSamplingUtils.getPoissonSamplingFunction(self, fractions, exact, seed) + } else { + StratifiedSamplingUtils.getBernoulliSamplingFunction(self, fractions, exact, seed) + } + self.mapPartitionsWithIndex(samplingFunc, preservesPartitioning = true) + } + /** * Merge the values for each key using an associative reduce function. This will also perform * the merging locally on each mapper before sending results to a reducer, similarly to a @@ -531,6 +564,9 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) /** * Return the key-value pairs in this RDD to the master as a Map. + * + * Warning: this doesn't return a multimap (so if you have multiple values to the same key, only + * one value per key is preserved in the map returned) */ def collectAsMap(): Map[K, V] = { val data = self.collect() diff --git a/core/src/main/scala/org/apache/spark/util/random/SamplingUtils.scala b/core/src/main/scala/org/apache/spark/util/random/SamplingUtils.scala index d10141b90e621..c9a864ae62778 100644 --- a/core/src/main/scala/org/apache/spark/util/random/SamplingUtils.scala +++ b/core/src/main/scala/org/apache/spark/util/random/SamplingUtils.scala @@ -81,6 +81,9 @@ private[spark] object SamplingUtils { * ~ Binomial(total, fraction) and our choice of q guarantees 1-delta, or 0.9999 success * rate, where success rate is defined the same as in sampling with replacement. * + * The smallest sampling rate supported is 1e-10 (in order to avoid running into the limit of the + * RNG's resolution). + * * @param sampleSizeLowerBound sample size * @param total size of RDD * @param withReplacement whether sampling with replacement @@ -88,14 +91,73 @@ private[spark] object SamplingUtils { */ def computeFractionForSampleSize(sampleSizeLowerBound: Int, total: Long, withReplacement: Boolean): Double = { - val fraction = sampleSizeLowerBound.toDouble / total if (withReplacement) { - val numStDev = if (sampleSizeLowerBound < 12) 9 else 5 - fraction + numStDev * math.sqrt(fraction / total) + PoissonBounds.getUpperBound(sampleSizeLowerBound) / total } else { - val delta = 1e-4 - val gamma = - math.log(delta) / total - math.min(1, fraction + gamma + math.sqrt(gamma * gamma + 2 * gamma * fraction)) + val fraction = sampleSizeLowerBound.toDouble / total + BinomialBounds.getUpperBound(1e-4, total, fraction) } } } + +/** + * Utility functions that help us determine bounds on adjusted sampling rate to guarantee exact + * sample sizes with high confidence when sampling with replacement. + */ +private[spark] object PoissonBounds { + + /** + * Returns a lambda such that Pr[X > s] is very small, where X ~ Pois(lambda). + */ + def getLowerBound(s: Double): Double = { + math.max(s - numStd(s) * math.sqrt(s), 1e-15) + } + + /** + * Returns a lambda such that Pr[X < s] is very small, where X ~ Pois(lambda). + * + * @param s sample size + */ + def getUpperBound(s: Double): Double = { + math.max(s + numStd(s) * math.sqrt(s), 1e-10) + } + + private def numStd(s: Double): Double = { + // TODO: Make it tighter. + if (s < 6.0) { + 12.0 + } else if (s < 16.0) { + 9.0 + } else { + 6.0 + } + } +} + +/** + * Utility functions that help us determine bounds on adjusted sampling rate to guarantee exact + * sample size with high confidence when sampling without replacement. + */ +private[spark] object BinomialBounds { + + val minSamplingRate = 1e-10 + + /** + * Returns a threshold `p` such that if we conduct n Bernoulli trials with success rate = `p`, + * it is very unlikely to have more than `fraction * n` successes. + */ + def getLowerBound(delta: Double, n: Long, fraction: Double): Double = { + val gamma = - math.log(delta) / n * (2.0 / 3.0) + fraction + gamma - math.sqrt(gamma * gamma + 3 * gamma * fraction) + } + + /** + * Returns a threshold `p` such that if we conduct n Bernoulli trials with success rate = `p`, + * it is very unlikely to have less than `fraction * n` successes. + */ + def getUpperBound(delta: Double, n: Long, fraction: Double): Double = { + val gamma = - math.log(delta) / n + math.min(1, + math.max(minSamplingRate, fraction + gamma + math.sqrt(gamma * gamma + 2 * gamma * fraction))) + } +} diff --git a/core/src/main/scala/org/apache/spark/util/random/StratifiedSamplingUtils.scala b/core/src/main/scala/org/apache/spark/util/random/StratifiedSamplingUtils.scala new file mode 100644 index 0000000000000..8f95d7c6b799b --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/random/StratifiedSamplingUtils.scala @@ -0,0 +1,316 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.random + +import scala.collection.Map +import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer +import scala.reflect.ClassTag + +import cern.jet.random.Poisson +import cern.jet.random.engine.DRand + +import org.apache.spark.Logging +import org.apache.spark.SparkContext._ +import org.apache.spark.rdd.RDD + +/** + * Auxiliary functions and data structures for the sampleByKey method in PairRDDFunctions. + * + * Essentially, when exact sample size is necessary, we make additional passes over the RDD to + * compute the exact threshold value to use for each stratum to guarantee exact sample size with + * high probability. This is achieved by maintaining a waitlist of size O(log(s)), where s is the + * desired sample size for each stratum. + * + * Like in simple random sampling, we generate a random value for each item from the + * uniform distribution [0.0, 1.0]. All items with values <= min(values of items in the waitlist) + * are accepted into the sample instantly. The threshold for instant accept is designed so that + * s - numAccepted = O(sqrt(s)), where s is again the desired sample size. Thus, by maintaining a + * waitlist size = O(sqrt(s)), we will be able to create a sample of the exact size s by adding + * a portion of the waitlist to the set of items that are instantly accepted. The exact threshold + * is computed by sorting the values in the waitlist and picking the value at (s - numAccepted). + * + * Note that since we use the same seed for the RNG when computing the thresholds and the actual + * sample, our computed thresholds are guaranteed to produce the desired sample size. + * + * For more theoretical background on the sampling techniques used here, please refer to + * http://jmlr.org/proceedings/papers/v28/meng13a.html + */ + +private[spark] object StratifiedSamplingUtils extends Logging { + + /** + * Count the number of items instantly accepted and generate the waitlist for each stratum. + * + * This is only invoked when exact sample size is required. + */ + def getAcceptanceResults[K, V](rdd: RDD[(K, V)], + withReplacement: Boolean, + fractions: Map[K, Double], + counts: Option[Map[K, Long]], + seed: Long): mutable.Map[K, AcceptanceResult] = { + val combOp = getCombOp[K] + val mappedPartitionRDD = rdd.mapPartitionsWithIndex { case (partition, iter) => + val zeroU: mutable.Map[K, AcceptanceResult] = new mutable.HashMap[K, AcceptanceResult]() + val rng = new RandomDataGenerator() + rng.reSeed(seed + partition) + val seqOp = getSeqOp(withReplacement, fractions, rng, counts) + Iterator(iter.aggregate(zeroU)(seqOp, combOp)) + } + mappedPartitionRDD.reduce(combOp) + } + + /** + * Returns the function used by aggregate to collect sampling statistics for each partition. + */ + def getSeqOp[K, V](withReplacement: Boolean, + fractions: Map[K, Double], + rng: RandomDataGenerator, + counts: Option[Map[K, Long]]): + (mutable.Map[K, AcceptanceResult], (K, V)) => mutable.Map[K, AcceptanceResult] = { + val delta = 5e-5 + (result: mutable.Map[K, AcceptanceResult], item: (K, V)) => { + val key = item._1 + val fraction = fractions(key) + if (!result.contains(key)) { + result += (key -> new AcceptanceResult()) + } + val acceptResult = result(key) + + if (withReplacement) { + // compute acceptBound and waitListBound only if they haven't been computed already + // since they don't change from iteration to iteration. + // TODO change this to the streaming version + if (acceptResult.areBoundsEmpty) { + val n = counts.get(key) + val sampleSize = math.ceil(n * fraction).toLong + val lmbd1 = PoissonBounds.getLowerBound(sampleSize) + val lmbd2 = PoissonBounds.getUpperBound(sampleSize) + acceptResult.acceptBound = lmbd1 / n + acceptResult.waitListBound = (lmbd2 - lmbd1) / n + } + val acceptBound = acceptResult.acceptBound + val copiesAccepted = if (acceptBound == 0.0) 0L else rng.nextPoisson(acceptBound) + if (copiesAccepted > 0) { + acceptResult.numAccepted += copiesAccepted + } + val copiesWaitlisted = rng.nextPoisson(acceptResult.waitListBound) + if (copiesWaitlisted > 0) { + acceptResult.waitList ++= ArrayBuffer.fill(copiesWaitlisted)(rng.nextUniform()) + } + } else { + // We use the streaming version of the algorithm for sampling without replacement to avoid + // using an extra pass over the RDD for computing the count. + // Hence, acceptBound and waitListBound change on every iteration. + acceptResult.acceptBound = + BinomialBounds.getLowerBound(delta, acceptResult.numItems, fraction) + acceptResult.waitListBound = + BinomialBounds.getUpperBound(delta, acceptResult.numItems, fraction) + + val x = rng.nextUniform() + if (x < acceptResult.acceptBound) { + acceptResult.numAccepted += 1 + } else if (x < acceptResult.waitListBound) { + acceptResult.waitList += x + } + } + acceptResult.numItems += 1 + result + } + } + + /** + * Returns the function used combine results returned by seqOp from different partitions. + */ + def getCombOp[K]: (mutable.Map[K, AcceptanceResult], mutable.Map[K, AcceptanceResult]) + => mutable.Map[K, AcceptanceResult] = { + (result1: mutable.Map[K, AcceptanceResult], result2: mutable.Map[K, AcceptanceResult]) => { + // take union of both key sets in case one partition doesn't contain all keys + result1.keySet.union(result2.keySet).foreach { key => + // Use result2 to keep the combined result since r1 is usual empty + val entry1 = result1.get(key) + if (result2.contains(key)) { + result2(key).merge(entry1) + } else { + if (entry1.isDefined) { + result2 += (key -> entry1.get) + } + } + } + result2 + } + } + + /** + * Given the result returned by getCounts, determine the threshold for accepting items to + * generate exact sample size. + * + * To do so, we compute sampleSize = math.ceil(size * samplingRate) for each stratum and compare + * it to the number of items that were accepted instantly and the number of items in the waitlist + * for that stratum. Most of the time, numAccepted <= sampleSize <= (numAccepted + numWaitlisted), + * which means we need to sort the elements in the waitlist by their associated values in order + * to find the value T s.t. |{elements in the stratum whose associated values <= T}| = sampleSize. + * Note that all elements in the waitlist have values >= bound for instant accept, so a T value + * in the waitlist range would allow all elements that were instantly accepted on the first pass + * to be included in the sample. + */ + def computeThresholdByKey[K](finalResult: Map[K, AcceptanceResult], + fractions: Map[K, Double]): Map[K, Double] = { + val thresholdByKey = new mutable.HashMap[K, Double]() + for ((key, acceptResult) <- finalResult) { + val sampleSize = math.ceil(acceptResult.numItems * fractions(key)).toLong + if (acceptResult.numAccepted > sampleSize) { + logWarning("Pre-accepted too many") + thresholdByKey += (key -> acceptResult.acceptBound) + } else { + val numWaitListAccepted = (sampleSize - acceptResult.numAccepted).toInt + if (numWaitListAccepted >= acceptResult.waitList.size) { + logWarning("WaitList too short") + thresholdByKey += (key -> acceptResult.waitListBound) + } else { + thresholdByKey += (key -> acceptResult.waitList.sorted.apply(numWaitListAccepted)) + } + } + } + thresholdByKey + } + + /** + * Return the per partition sampling function used for sampling without replacement. + * + * When exact sample size is required, we make an additional pass over the RDD to determine the + * exact sampling rate that guarantees sample size with high confidence. + * + * The sampling function has a unique seed per partition. + */ + def getBernoulliSamplingFunction[K, V](rdd: RDD[(K, V)], + fractions: Map[K, Double], + exact: Boolean, + seed: Long): (Int, Iterator[(K, V)]) => Iterator[(K, V)] = { + var samplingRateByKey = fractions + if (exact) { + // determine threshold for each stratum and resample + val finalResult = getAcceptanceResults(rdd, false, fractions, None, seed) + samplingRateByKey = computeThresholdByKey(finalResult, fractions) + } + (idx: Int, iter: Iterator[(K, V)]) => { + val rng = new RandomDataGenerator + rng.reSeed(seed + idx) + // Must use the same invoke pattern on the rng as in getSeqOp for without replacement + // in order to generate the same sequence of random numbers when creating the sample + iter.filter(t => rng.nextUniform() < samplingRateByKey(t._1)) + } + } + + /** + * Return the per partition sampling function used for sampling with replacement. + * + * When exact sample size is required, we make two additional passed over the RDD to determine + * the exact sampling rate that guarantees sample size with high confidence. The first pass + * counts the number of items in each stratum (group of items with the same key) in the RDD, and + * the second pass uses the counts to determine exact sampling rates. + * + * The sampling function has a unique seed per partition. + */ + def getPoissonSamplingFunction[K: ClassTag, V: ClassTag](rdd: RDD[(K, V)], + fractions: Map[K, Double], + exact: Boolean, + seed: Long): (Int, Iterator[(K, V)]) => Iterator[(K, V)] = { + // TODO implement the streaming version of sampling w/ replacement that doesn't require counts + if (exact) { + val counts = Some(rdd.countByKey()) + val finalResult = getAcceptanceResults(rdd, true, fractions, counts, seed) + val thresholdByKey = computeThresholdByKey(finalResult, fractions) + (idx: Int, iter: Iterator[(K, V)]) => { + val rng = new RandomDataGenerator() + rng.reSeed(seed + idx) + iter.flatMap { item => + val key = item._1 + val acceptBound = finalResult(key).acceptBound + // Must use the same invoke pattern on the rng as in getSeqOp for with replacement + // in order to generate the same sequence of random numbers when creating the sample + val copiesAccepted = if (acceptBound == 0) 0L else rng.nextPoisson(acceptBound) + val copiesWailisted = rng.nextPoisson(finalResult(key).waitListBound) + val copiesInSample = copiesAccepted + + (0 until copiesWailisted).count(i => rng.nextUniform() < thresholdByKey(key)) + if (copiesInSample > 0) { + Iterator.fill(copiesInSample.toInt)(item) + } else { + Iterator.empty + } + } + } + } else { + (idx: Int, iter: Iterator[(K, V)]) => { + val rng = new RandomDataGenerator() + rng.reSeed(seed + idx) + iter.flatMap { item => + val count = rng.nextPoisson(fractions(item._1)) + if (count > 0) { + Iterator.fill(count)(item) + } else { + Iterator.empty + } + } + } + } + } + + /** A random data generator that generates both uniform values and Poisson values. */ + private class RandomDataGenerator { + val uniform = new XORShiftRandom() + var poisson = new Poisson(1.0, new DRand) + + def reSeed(seed: Long) { + uniform.setSeed(seed) + poisson = new Poisson(1.0, new DRand(seed.toInt)) + } + + def nextPoisson(mean: Double): Int = { + poisson.nextInt(mean) + } + + def nextUniform(): Double = { + uniform.nextDouble() + } + } +} + +/** + * Object used by seqOp to keep track of the number of items accepted and items waitlisted per + * stratum, as well as the bounds for accepting and waitlisting items. + * + * `[random]` here is necessary since it's in the return type signature of seqOp defined above + */ +private[random] class AcceptanceResult(var numItems: Long = 0L, var numAccepted: Long = 0L) + extends Serializable { + + val waitList = new ArrayBuffer[Double] + var acceptBound: Double = Double.NaN // upper bound for accepting item instantly + var waitListBound: Double = Double.NaN // upper bound for adding item to waitlist + + def areBoundsEmpty = acceptBound.isNaN || waitListBound.isNaN + + def merge(other: Option[AcceptanceResult]): Unit = { + if (other.isDefined) { + waitList ++= other.get.waitList + numAccepted += other.get.numAccepted + numItems += other.get.numItems + } + } +} diff --git a/core/src/test/java/org/apache/spark/JavaAPISuite.java b/core/src/test/java/org/apache/spark/JavaAPISuite.java index f882a8623fd84..e8bd65f8e4507 100644 --- a/core/src/test/java/org/apache/spark/JavaAPISuite.java +++ b/core/src/test/java/org/apache/spark/JavaAPISuite.java @@ -29,6 +29,7 @@ import com.google.common.collect.Iterables; import com.google.common.collect.Iterators; import com.google.common.collect.Lists; +import com.google.common.collect.Maps; import com.google.common.base.Optional; import com.google.common.base.Charsets; import com.google.common.io.Files; @@ -1208,4 +1209,40 @@ public Tuple2 call(Integer x) { pairRDD.collect(); // Works fine pairRDD.collectAsMap(); // Used to crash with ClassCastException } + + @Test + @SuppressWarnings("unchecked") + public void sampleByKey() { + JavaRDD rdd1 = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8), 3); + JavaPairRDD rdd2 = rdd1.mapToPair( + new PairFunction() { + @Override + public Tuple2 call(Integer i) { + return new Tuple2(i % 2, 1); + } + }); + Map fractions = Maps.newHashMap(); + fractions.put(0, 0.5); + fractions.put(1, 1.0); + JavaPairRDD wr = rdd2.sampleByKey(true, fractions, 1L); + Map wrCounts = (Map) (Object) wr.countByKey(); + Assert.assertTrue(wrCounts.size() == 2); + Assert.assertTrue(wrCounts.get(0) > 0); + Assert.assertTrue(wrCounts.get(1) > 0); + JavaPairRDD wor = rdd2.sampleByKey(false, fractions, 1L); + Map worCounts = (Map) (Object) wor.countByKey(); + Assert.assertTrue(worCounts.size() == 2); + Assert.assertTrue(worCounts.get(0) > 0); + Assert.assertTrue(worCounts.get(1) > 0); + JavaPairRDD wrExact = rdd2.sampleByKey(true, fractions, true, 1L); + Map wrExactCounts = (Map) (Object) wrExact.countByKey(); + Assert.assertTrue(wrExactCounts.size() == 2); + Assert.assertTrue(wrExactCounts.get(0) == 2); + Assert.assertTrue(wrExactCounts.get(1) == 4); + JavaPairRDD worExact = rdd2.sampleByKey(false, fractions, true, 1L); + Map worExactCounts = (Map) (Object) worExact.countByKey(); + Assert.assertTrue(worExactCounts.size() == 2); + Assert.assertTrue(worExactCounts.get(0) == 2); + Assert.assertTrue(worExactCounts.get(1) == 4); + } } diff --git a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala index 447e38ec9dbd0..4f49d4a1d4d34 100644 --- a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala @@ -83,6 +83,122 @@ class PairRDDFunctionsSuite extends FunSuite with SharedSparkContext { assert(valuesFor2.toList.sorted === List(1)) } + test("sampleByKey") { + def stratifier (fractionPositive: Double) = { + (x: Int) => if (x % 10 < (10 * fractionPositive).toInt) "1" else "0" + } + + def checkSize(exact: Boolean, + withReplacement: Boolean, + expected: Long, + actual: Long, + p: Double): Boolean = { + if (exact) { + return expected == actual + } + val stdev = if (withReplacement) math.sqrt(expected) else math.sqrt(expected * p * (1 - p)) + // Very forgiving margin since we're dealing with very small sample sizes most of the time + math.abs(actual - expected) <= 6 * stdev + } + + // Without replacement validation + def takeSampleAndValidateBernoulli(stratifiedData: RDD[(String, Int)], + exact: Boolean, + samplingRate: Double, + seed: Long, + n: Long) = { + val expectedSampleSize = stratifiedData.countByKey() + .mapValues(count => math.ceil(count * samplingRate).toInt) + val fractions = Map("1" -> samplingRate, "0" -> samplingRate) + val sample = stratifiedData.sampleByKey(false, fractions, exact, seed) + val sampleCounts = sample.countByKey() + val takeSample = sample.collect() + sampleCounts.foreach { case(k, v) => + assert(checkSize(exact, false, expectedSampleSize(k), v, samplingRate)) } + assert(takeSample.size === takeSample.toSet.size) + takeSample.foreach { x => assert(1 <= x._2 && x._2 <= n, s"elements not in [1, $n]") } + } + + // With replacement validation + def takeSampleAndValidatePoisson(stratifiedData: RDD[(String, Int)], + exact: Boolean, + samplingRate: Double, + seed: Long, + n: Long) = { + val expectedSampleSize = stratifiedData.countByKey().mapValues(count => + math.ceil(count * samplingRate).toInt) + val fractions = Map("1" -> samplingRate, "0" -> samplingRate) + val sample = stratifiedData.sampleByKey(true, fractions, exact, seed) + val sampleCounts = sample.countByKey() + val takeSample = sample.collect() + sampleCounts.foreach { case(k, v) => + assert(checkSize(exact, true, expectedSampleSize(k), v, samplingRate)) } + val groupedByKey = takeSample.groupBy(_._1) + for ((key, v) <- groupedByKey) { + if (expectedSampleSize(key) >= 100 && samplingRate >= 0.1) { + // sample large enough for there to be repeats with high likelihood + assert(v.toSet.size < expectedSampleSize(key)) + } else { + if (exact) { + assert(v.toSet.size <= expectedSampleSize(key)) + } else { + assert(checkSize(false, true, expectedSampleSize(key), v.toSet.size, samplingRate)) + } + } + } + takeSample.foreach { x => assert(1 <= x._2 && x._2 <= n, s"elements not in [1, $n]") } + } + + def checkAllCombos(stratifiedData: RDD[(String, Int)], + samplingRate: Double, + seed: Long, + n: Long) = { + takeSampleAndValidateBernoulli(stratifiedData, true, samplingRate, seed, n) + takeSampleAndValidateBernoulli(stratifiedData, false, samplingRate, seed, n) + takeSampleAndValidatePoisson(stratifiedData, true, samplingRate, seed, n) + takeSampleAndValidatePoisson(stratifiedData, false, samplingRate, seed, n) + } + + val defaultSeed = 1L + + // vary RDD size + for (n <- List(100, 1000, 1000000)) { + val data = sc.parallelize(1 to n, 2) + val fractionPositive = 0.3 + val stratifiedData = data.keyBy(stratifier(fractionPositive)) + + val samplingRate = 0.1 + checkAllCombos(stratifiedData, samplingRate, defaultSeed, n) + } + + // vary fractionPositive + for (fractionPositive <- List(0.1, 0.3, 0.5, 0.7, 0.9)) { + val n = 100 + val data = sc.parallelize(1 to n, 2) + val stratifiedData = data.keyBy(stratifier(fractionPositive)) + + val samplingRate = 0.1 + checkAllCombos(stratifiedData, samplingRate, defaultSeed, n) + } + + // Use the same data for the rest of the tests + val fractionPositive = 0.3 + val n = 100 + val data = sc.parallelize(1 to n, 2) + val stratifiedData = data.keyBy(stratifier(fractionPositive)) + + // vary seed + for (seed <- defaultSeed to defaultSeed + 5L) { + val samplingRate = 0.1 + checkAllCombos(stratifiedData, samplingRate, seed, n) + } + + // vary sampling rate + for (samplingRate <- List(0.01, 0.05, 0.1, 0.5)) { + checkAllCombos(stratifiedData, samplingRate, defaultSeed, n) + } + } + test("reduceByKey") { val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1))) val sums = pairs.reduceByKey(_+_).collect() diff --git a/pom.xml b/pom.xml index 8b1435cfe5d19..39538f9660623 100644 --- a/pom.xml +++ b/pom.xml @@ -257,6 +257,12 @@ commons-codec 1.5 + + org.apache.commons + commons-math3 + 3.3 + test + com.google.code.findbugs jsr305 From c7db274be79f448fda566208946cb50958ea9b1a Mon Sep 17 00:00:00 2001 From: Zongheng Yang Date: Tue, 29 Jul 2014 15:32:50 -0700 Subject: [PATCH 016/170] [SPARK-2393][SQL] Cost estimation optimization framework for Catalyst logical plans & sample usage. The idea is that every Catalyst logical plan gets hold of a Statistics class, the usage of which provides useful estimations on various statistics. See the implementations of `MetastoreRelation`. This patch also includes several usages of the estimation interface in the planner. For instance, we now use physical table sizes from the estimate interface to convert an equi-join to a broadcast join (when doing so is beneficial, as determined by a size threshold). Finally, there are a couple minor accompanying changes including: - Remove the not-in-use `BaseRelation`. - Make SparkLogicalPlan take a `SQLContext` in the second param list. Author: Zongheng Yang Closes #1238 from concretevitamin/estimates and squashes the following commits: 329071d [Zongheng Yang] Address review comments; turn config name from string to field in SQLConf. 8663e84 [Zongheng Yang] Use BigInt for stat; for logical leaves, by default throw an exception. 2f2fb89 [Zongheng Yang] Fix statistics for SparkLogicalPlan. 9951305 [Zongheng Yang] Remove childrenStats. 16fc60a [Zongheng Yang] Avoid calling statistics on plans if auto join conversion is disabled. 8bd2816 [Zongheng Yang] Add a note on performance of statistics. 6e594b8 [Zongheng Yang] Get size info from metastore for MetastoreRelation. 01b7a3e [Zongheng Yang] Update scaladoc for a field and move it to @param section. 549061c [Zongheng Yang] Remove numTuples in Statistics for now. 729a8e2 [Zongheng Yang] Update docs to be more explicit. 573e644 [Zongheng Yang] Remove singleton SQLConf and move back `settings` to the trait. 2d99eb5 [Zongheng Yang] {Cleanup, use synchronized in, enrich} StatisticsSuite. ca5b825 [Zongheng Yang] Inject SQLContext into SparkLogicalPlan, removing SQLConf mixin from it. 43d38a6 [Zongheng Yang] Revert optimization for BroadcastNestedLoopJoin (this fixes tests). 0ef9e5b [Zongheng Yang] Use multiplication instead of sum for default estimates. 4ef0d26 [Zongheng Yang] Make Statistics a case class. 3ba8f3e [Zongheng Yang] Add comment. e5bcf5b [Zongheng Yang] Fix optimization conditions & update scala docs to explain. 7d9216a [Zongheng Yang] Apply estimation to planning ShuffleHashJoin & BroadcastNestedLoopJoin. 73cde01 [Zongheng Yang] Move SQLConf back. Assign default sizeInBytes to SparkLogicalPlan. 73412be [Zongheng Yang] Move SQLConf to Catalyst & add default val for sizeInBytes. 7a60ab7 [Zongheng Yang] s/Estimates/Statistics, s/cardinality/numTuples. de3ae13 [Zongheng Yang] Add parquetAfter() properly in test. dcff9bd [Zongheng Yang] Cleanups. 84301a4 [Zongheng Yang] Refactors. 5bf5586 [Zongheng Yang] Typo. 56a8e6e [Zongheng Yang] Prototype impl of estimations for Catalyst logical plans. --- .../sql/catalyst/analysis/unresolved.scala | 4 +- .../catalyst/plans/logical/BaseRelation.scala | 24 ----- .../catalyst/plans/logical/LogicalPlan.scala | 22 +++++ .../scala/org/apache/spark/sql/SQLConf.scala | 61 +++++++----- .../org/apache/spark/sql/SQLContext.scala | 20 ++-- .../org/apache/spark/sql/SchemaRDD.scala | 3 +- .../org/apache/spark/sql/SchemaRDDLike.scala | 2 +- .../spark/sql/api/java/JavaSQLContext.scala | 4 +- .../spark/sql/execution/SparkPlan.scala | 18 ++-- .../spark/sql/execution/SparkStrategies.scala | 57 ++++++----- .../org/apache/spark/sql/json/JsonRDD.scala | 11 ++- .../spark/sql/parquet/ParquetRelation.scala | 4 +- .../org/apache/spark/sql/JoinSuite.scala | 2 - .../spark/sql/hive/HiveMetastoreCatalog.scala | 47 ++++++--- .../spark/sql/hive/StatisticsSuite.scala | 95 +++++++++++++++++++ .../hive/execution/HiveComparisonTest.scala | 2 +- .../sql/hive/execution/HiveQuerySuite.scala | 2 +- .../spark/sql/parquet/HiveParquetSuite.scala | 2 +- 18 files changed, 256 insertions(+), 124 deletions(-) delete mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/BaseRelation.scala create mode 100644 sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala index 7abeb032964e1..a0e25775da6dd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.catalyst.{errors, trees} import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.logical.BaseRelation +import org.apache.spark.sql.catalyst.plans.logical.LeafNode import org.apache.spark.sql.catalyst.trees.TreeNode /** @@ -36,7 +36,7 @@ class UnresolvedException[TreeType <: TreeNode[_]](tree: TreeType, function: Str case class UnresolvedRelation( databaseName: Option[String], tableName: String, - alias: Option[String] = None) extends BaseRelation { + alias: Option[String] = None) extends LeafNode { override def output = Nil override lazy val resolved = false } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/BaseRelation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/BaseRelation.scala deleted file mode 100644 index 582334aa42590..0000000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/BaseRelation.scala +++ /dev/null @@ -1,24 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.plans.logical - -abstract class BaseRelation extends LeafNode { - self: Product => - - def tableName: String -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index edc37e3877c0e..ac85f95b52a2f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -26,6 +26,25 @@ import org.apache.spark.sql.catalyst.trees abstract class LogicalPlan extends QueryPlan[LogicalPlan] { self: Product => + /** + * Estimates of various statistics. The default estimation logic simply lazily multiplies the + * corresponding statistic produced by the children. To override this behavior, override + * `statistics` and assign it an overriden version of `Statistics`. + * + * '''NOTE''': concrete and/or overriden versions of statistics fields should pay attention to the + * performance of the implementations. The reason is that estimations might get triggered in + * performance-critical processes, such as query plan planning. + * + * @param sizeInBytes Physical size in bytes. For leaf operators this defaults to 1, otherwise it + * defaults to the product of children's `sizeInBytes`. + */ + case class Statistics( + sizeInBytes: BigInt + ) + lazy val statistics: Statistics = Statistics( + sizeInBytes = children.map(_.statistics).map(_.sizeInBytes).product + ) + /** * Returns the set of attributes that are referenced by this node * during evaluation. @@ -92,6 +111,9 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] { abstract class LeafNode extends LogicalPlan with trees.LeafNode[LogicalPlan] { self: Product => + override lazy val statistics: Statistics = + throw new UnsupportedOperationException("default leaf nodes don't have meaningful Statistics") + // Leaf nodes by definition cannot reference any input attributes. override def references = Set.empty } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala index 41920c00b5a2c..be8d4e15ec4b0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala @@ -21,17 +21,31 @@ import java.util.Properties import scala.collection.JavaConverters._ +object SQLConf { + val AUTO_BROADCASTJOIN_THRESHOLD = "spark.sql.autoBroadcastJoinThreshold" + val SHUFFLE_PARTITIONS = "spark.sql.shuffle.partitions" + val DEFAULT_SIZE_IN_BYTES = "spark.sql.defaultSizeInBytes" + + object Deprecated { + val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks" + } +} + /** - * SQLConf holds mutable config parameters and hints. These can be set and - * queried either by passing SET commands into Spark SQL's DSL - * functions (sql(), hql(), etc.), or by programmatically using setters and - * getters of this class. + * A trait that enables the setting and getting of mutable config parameters/hints. + * + * In the presence of a SQLContext, these can be set and queried by passing SET commands + * into Spark SQL's query functions (sql(), hql(), etc.). Otherwise, users of this trait can + * modify the hints by programmatically calling the setters and getters of this trait. * - * SQLConf is thread-safe (internally synchronized so safe to be used in multiple threads). + * SQLConf is thread-safe (internally synchronized, so safe to be used in multiple threads). */ trait SQLConf { import SQLConf._ + @transient protected[spark] val settings = java.util.Collections.synchronizedMap( + new java.util.HashMap[String, String]()) + /** ************************ Spark SQL Params/Hints ******************* */ // TODO: refactor so that these hints accessors don't pollute the name space of SQLContext? @@ -40,28 +54,33 @@ trait SQLConf { /** * Upper bound on the sizes (in bytes) of the tables qualified for the auto conversion to - * a broadcast value during the physical executions of join operations. Setting this to 0 + * a broadcast value during the physical executions of join operations. Setting this to -1 * effectively disables auto conversion. - * Hive setting: hive.auto.convert.join.noconditionaltask.size. + * + * Hive setting: hive.auto.convert.join.noconditionaltask.size, whose default value is also 10000. */ - private[spark] def autoConvertJoinSize: Int = get(AUTO_CONVERT_JOIN_SIZE, "10000").toInt + private[spark] def autoBroadcastJoinThreshold: Int = + get(AUTO_BROADCASTJOIN_THRESHOLD, "10000").toInt - /** A comma-separated list of table names marked to be broadcasted during joins. */ - private[spark] def joinBroadcastTables: String = get(JOIN_BROADCAST_TABLES, "") + /** + * The default size in bytes to assign to a logical operator's estimation statistics. By default, + * it is set to a larger value than `autoConvertJoinSize`, hence any logical operator without a + * properly implemented estimation of this statistic will not be incorrectly broadcasted in joins. + */ + private[spark] def defaultSizeInBytes: Long = + getOption(DEFAULT_SIZE_IN_BYTES).map(_.toLong).getOrElse(autoBroadcastJoinThreshold + 1) /** ********************** SQLConf functionality methods ************ */ - @transient - private val settings = java.util.Collections.synchronizedMap( - new java.util.HashMap[String, String]()) - def set(props: Properties): Unit = { - props.asScala.foreach { case (k, v) => this.settings.put(k, v) } + settings.synchronized { + props.asScala.foreach { case (k, v) => settings.put(k, v) } + } } def set(key: String, value: String): Unit = { require(key != null, "key cannot be null") - require(value != null, s"value cannot be null for $key") + require(value != null, s"value cannot be null for key: $key") settings.put(key, value) } @@ -90,13 +109,3 @@ trait SQLConf { } } - -object SQLConf { - val AUTO_CONVERT_JOIN_SIZE = "spark.sql.auto.convert.join.size" - val SHUFFLE_PARTITIONS = "spark.sql.shuffle.partitions" - val JOIN_BROADCAST_TABLES = "spark.sql.join.broadcastTables" - - object Deprecated { - val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks" - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index c178dad662532..a136c7b3ffef5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -24,14 +24,14 @@ import org.apache.hadoop.conf.Configuration import org.apache.spark.annotation.{AlphaComponent, DeveloperApi, Experimental} import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.ScalaReflection -import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.dsl.ExpressionConversions -import org.apache.spark.sql.catalyst.types._ +import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.optimizer.Optimizer import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.rules.RuleExecutor +import org.apache.spark.sql.catalyst.types._ import org.apache.spark.sql.columnar.InMemoryRelation import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.SparkStrategies @@ -86,7 +86,7 @@ class SQLContext(@transient val sparkContext: SparkContext) * @group userf */ implicit def createSchemaRDD[A <: Product: TypeTag](rdd: RDD[A]) = - new SchemaRDD(this, SparkLogicalPlan(ExistingRdd.fromProductRdd(rdd))) + new SchemaRDD(this, SparkLogicalPlan(ExistingRdd.fromProductRdd(rdd))(self)) /** * Loads a Parquet file, returning the result as a [[SchemaRDD]]. @@ -127,7 +127,7 @@ class SQLContext(@transient val sparkContext: SparkContext) */ @Experimental def jsonRDD(json: RDD[String], samplingRatio: Double): SchemaRDD = - new SchemaRDD(this, JsonRDD.inferSchema(json, samplingRatio)) + new SchemaRDD(this, JsonRDD.inferSchema(self, json, samplingRatio)) /** * :: Experimental :: @@ -170,11 +170,7 @@ class SQLContext(@transient val sparkContext: SparkContext) * @group userf */ def registerRDDAsTable(rdd: SchemaRDD, tableName: String): Unit = { - val name = tableName - val newPlan = rdd.logicalPlan transform { - case s @ SparkLogicalPlan(ExistingRdd(_, _), _) => s.copy(tableName = name) - } - catalog.registerTable(None, tableName, newPlan) + catalog.registerTable(None, tableName, rdd.logicalPlan) } /** @@ -212,7 +208,7 @@ class SQLContext(@transient val sparkContext: SparkContext) case inMem @ InMemoryRelation(_, _, e: ExistingRdd) => inMem.cachedColumnBuffers.unpersist() catalog.unregisterTable(None, tableName) - catalog.registerTable(None, tableName, SparkLogicalPlan(e)) + catalog.registerTable(None, tableName, SparkLogicalPlan(e)(self)) case inMem: InMemoryRelation => inMem.cachedColumnBuffers.unpersist() catalog.unregisterTable(None, tableName) @@ -405,7 +401,7 @@ class SQLContext(@transient val sparkContext: SparkContext) new GenericRow(map.values.toArray.asInstanceOf[Array[Any]]): Row } } - new SchemaRDD(this, SparkLogicalPlan(ExistingRdd(schema, rowRdd))) + new SchemaRDD(this, SparkLogicalPlan(ExistingRdd(schema, rowRdd))(self)) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala index 019ff9d300a18..172b6e0e7f26b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala @@ -418,7 +418,8 @@ class SchemaRDD( * @group schema */ private def applySchema(rdd: RDD[Row]): SchemaRDD = { - new SchemaRDD(sqlContext, SparkLogicalPlan(ExistingRdd(queryExecution.analyzed.output, rdd))) + new SchemaRDD(sqlContext, + SparkLogicalPlan(ExistingRdd(queryExecution.analyzed.output, rdd))(sqlContext)) } // ======================================================================= diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDDLike.scala b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDDLike.scala index fe81721943202..fd751031b26e5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDDLike.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDDLike.scala @@ -56,7 +56,7 @@ private[sql] trait SchemaRDDLike { // happen right away to let these side effects take place eagerly. case _: Command | _: InsertIntoTable | _: InsertIntoCreatedTable | _: WriteToFile => queryExecution.toRdd - SparkLogicalPlan(queryExecution.executedPlan) + SparkLogicalPlan(queryExecution.executedPlan)(sqlContext) case _ => baseLogicalPlan } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala index 790d9ef22cf16..806097c917b91 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala @@ -92,7 +92,7 @@ class JavaSQLContext(val sqlContext: SQLContext) { new GenericRow(extractors.map(e => e.invoke(row)).toArray[Any]): ScalaRow } } - new JavaSchemaRDD(sqlContext, SparkLogicalPlan(ExistingRdd(schema, rowRdd))) + new JavaSchemaRDD(sqlContext, SparkLogicalPlan(ExistingRdd(schema, rowRdd))(sqlContext)) } /** @@ -120,7 +120,7 @@ class JavaSQLContext(val sqlContext: SQLContext) { * @group userf */ def jsonRDD(json: JavaRDD[String]): JavaSchemaRDD = - new JavaSchemaRDD(sqlContext, JsonRDD.inferSchema(json, 1.0)) + new JavaSchemaRDD(sqlContext, JsonRDD.inferSchema(sqlContext, json, 1.0)) /** * Registers the given RDD as a temporary table in the catalog. Temporary tables exist only diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index 27dc091b85812..77c874d0315ee 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -19,12 +19,12 @@ package org.apache.spark.sql.execution import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{Logging, Row} +import org.apache.spark.sql.{Logging, Row, SQLContext} import org.apache.spark.sql.catalyst.trees import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation import org.apache.spark.sql.catalyst.expressions.GenericRow import org.apache.spark.sql.catalyst.plans.QueryPlan -import org.apache.spark.sql.catalyst.plans.logical.BaseRelation +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.plans.physical._ /** @@ -66,8 +66,8 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging { * linking. */ @DeveloperApi -case class SparkLogicalPlan(alreadyPlanned: SparkPlan, tableName: String = "SparkLogicalPlan") - extends BaseRelation with MultiInstanceRelation { +case class SparkLogicalPlan(alreadyPlanned: SparkPlan)(@transient sqlContext: SQLContext) + extends LogicalPlan with MultiInstanceRelation { def output = alreadyPlanned.output override def references = Set.empty @@ -78,9 +78,15 @@ case class SparkLogicalPlan(alreadyPlanned: SparkPlan, tableName: String = "Spar alreadyPlanned match { case ExistingRdd(output, rdd) => ExistingRdd(output.map(_.newInstance), rdd) case _ => sys.error("Multiple instance of the same relation detected.") - }, tableName) - .asInstanceOf[this.type] + })(sqlContext).asInstanceOf[this.type] } + + @transient override lazy val statistics = Statistics( + // TODO: Instead of returning a default value here, find a way to return a meaningful size + // estimate for RDDs. See PR 1238 for more discussions. + sizeInBytes = BigInt(sqlContext.defaultSizeInBytes) + ) + } private[sql] trait LeafNode extends SparkPlan with trees.LeafNode[SparkPlan] { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index c078e71fe0290..404d48ae05b45 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -17,11 +17,13 @@ package org.apache.spark.sql.execution +import scala.util.Try + import org.apache.spark.sql.{SQLContext, execution} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.planning._ import org.apache.spark.sql.catalyst.plans._ -import org.apache.spark.sql.catalyst.plans.logical.{BaseRelation, LogicalPlan} +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.columnar.{InMemoryRelation, InMemoryColumnarTableScan} import org.apache.spark.sql.parquet._ @@ -47,9 +49,18 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { /** * Uses the ExtractEquiJoinKeys pattern to find joins where at least some of the predicates can be * evaluated by matching hash keys. + * + * This strategy applies a simple optimization based on the estimates of the physical sizes of + * the two join sides. When planning a [[execution.BroadcastHashJoin]], if one side has an + * estimated physical size smaller than the user-settable threshold + * [[org.apache.spark.sql.SQLConf.AUTO_BROADCASTJOIN_THRESHOLD]], the planner would mark it as the + * ''build'' relation and mark the other relation as the ''stream'' side. The build table will be + * ''broadcasted'' to all of the executors involved in the join, as a + * [[org.apache.spark.broadcast.Broadcast]] object. If both estimates exceed the threshold, they + * will instead be used to decide the build side in a [[execution.ShuffledHashJoin]]. */ object HashJoin extends Strategy with PredicateHelper { - private[this] def broadcastHashJoin( + private[this] def makeBroadcastHashJoin( leftKeys: Seq[Expression], rightKeys: Seq[Expression], left: LogicalPlan, @@ -61,33 +72,27 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { condition.map(Filter(_, broadcastHashJoin)).getOrElse(broadcastHashJoin) :: Nil } - def broadcastTables: Seq[String] = sqlContext.joinBroadcastTables.split(",").toBuffer - def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { - case ExtractEquiJoinKeys( - Inner, - leftKeys, - rightKeys, - condition, - left, - right @ PhysicalOperation(_, _, b: BaseRelation)) - if broadcastTables.contains(b.tableName) => - broadcastHashJoin(leftKeys, rightKeys, left, right, condition, BuildRight) + case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right) + if Try(sqlContext.autoBroadcastJoinThreshold > 0 && + right.statistics.sizeInBytes <= sqlContext.autoBroadcastJoinThreshold).getOrElse(false) => + makeBroadcastHashJoin(leftKeys, rightKeys, left, right, condition, BuildRight) - case ExtractEquiJoinKeys( - Inner, - leftKeys, - rightKeys, - condition, - left @ PhysicalOperation(_, _, b: BaseRelation), - right) - if broadcastTables.contains(b.tableName) => - broadcastHashJoin(leftKeys, rightKeys, left, right, condition, BuildLeft) + case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right) + if Try(sqlContext.autoBroadcastJoinThreshold > 0 && + left.statistics.sizeInBytes <= sqlContext.autoBroadcastJoinThreshold).getOrElse(false) => + makeBroadcastHashJoin(leftKeys, rightKeys, left, right, condition, BuildLeft) case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right) => + val buildSide = + if (Try(right.statistics.sizeInBytes <= left.statistics.sizeInBytes).getOrElse(false)) { + BuildRight + } else { + BuildLeft + } val hashJoin = execution.ShuffledHashJoin( - leftKeys, rightKeys, BuildRight, planLater(left), planLater(right)) + leftKeys, rightKeys, buildSide, planLater(left), planLater(right)) condition.map(Filter(_, hashJoin)).getOrElse(hashJoin) :: Nil case _ => Nil @@ -273,8 +278,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { execution.Limit(limit, planLater(child))(sqlContext) :: Nil case Unions(unionChildren) => execution.Union(unionChildren.map(planLater))(sqlContext) :: Nil - case logical.Except(left,right) => - execution.Except(planLater(left),planLater(right)) :: Nil + case logical.Except(left,right) => + execution.Except(planLater(left),planLater(right)) :: Nil case logical.Intersect(left, right) => execution.Intersect(planLater(left), planLater(right)) :: Nil case logical.Generate(generator, join, outer, _, child) => @@ -283,7 +288,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { execution.ExistingRdd(Nil, singleRowRdd) :: Nil case logical.Repartition(expressions, child) => execution.Exchange(HashPartitioning(expressions, numPartitions), planLater(child)) :: Nil - case SparkLogicalPlan(existingPlan, _) => existingPlan :: Nil + case SparkLogicalPlan(existingPlan) => existingPlan :: Nil case _ => Nil } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala index b48c70ee73a27..6c2b553bb908e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala @@ -28,11 +28,12 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.types._ import org.apache.spark.sql.execution.{ExistingRdd, SparkLogicalPlan} -import org.apache.spark.sql.Logging +import org.apache.spark.sql.{SQLContext, Logging} private[sql] object JsonRDD extends Logging { private[sql] def inferSchema( + sqlContext: SQLContext, json: RDD[String], samplingRatio: Double = 1.0): LogicalPlan = { require(samplingRatio > 0, s"samplingRatio ($samplingRatio) should be greater than 0") @@ -40,15 +41,17 @@ private[sql] object JsonRDD extends Logging { val allKeys = parseJson(schemaData).map(allKeysWithValueTypes).reduce(_ ++ _) val baseSchema = createSchema(allKeys) - createLogicalPlan(json, baseSchema) + createLogicalPlan(json, baseSchema, sqlContext) } private def createLogicalPlan( json: RDD[String], - baseSchema: StructType): LogicalPlan = { + baseSchema: StructType, + sqlContext: SQLContext): LogicalPlan = { val schema = nullTypeToStringType(baseSchema) - SparkLogicalPlan(ExistingRdd(asAttributes(schema), parseJson(json).map(asRow(_, schema)))) + SparkLogicalPlan( + ExistingRdd(asAttributes(schema), parseJson(json).map(asRow(_, schema))))(sqlContext) } private def createSchema(allKeys: Set[(String, DataType)]): StructType = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala index 9c4771d1a9846..8c7dbd5eb4a09 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala @@ -27,6 +27,7 @@ import parquet.hadoop.ParquetOutputFormat import parquet.hadoop.metadata.CompressionCodecName import parquet.schema.MessageType +import org.apache.spark.sql.SQLContext import org.apache.spark.sql.catalyst.analysis.{MultiInstanceRelation, UnresolvedException} import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, LeafNode} @@ -45,7 +46,8 @@ import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, LeafNode} */ private[sql] case class ParquetRelation( path: String, - @transient conf: Option[Configuration] = None) extends LeafNode with MultiInstanceRelation { + @transient conf: Option[Configuration] = None) + extends LeafNode with MultiInstanceRelation { self: Product => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index e17ecc87fd52a..025c396ef0629 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -19,8 +19,6 @@ package org.apache.spark.sql import org.apache.spark.sql.TestData._ import org.apache.spark.sql.catalyst.plans.{LeftOuter, RightOuter, FullOuter, Inner} -import org.apache.spark.sql.execution._ -import org.apache.spark.sql.test.TestSQLContext import org.apache.spark.sql.test.TestSQLContext._ class JoinSuite extends QueryTest { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index 156b090712df2..dff1d6a4b93bb 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -19,15 +19,16 @@ package org.apache.spark.sql.hive import scala.util.parsing.combinator.RegexParsers +import org.apache.hadoop.fs.Path +import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.metastore.api.{FieldSchema, StorageDescriptor, SerDeInfo} import org.apache.hadoop.hive.metastore.api.{Table => TTable, Partition => TPartition} import org.apache.hadoop.hive.ql.metadata.{Hive, Partition, Table} import org.apache.hadoop.hive.ql.plan.TableDesc -import org.apache.hadoop.hive.ql.session.SessionState import org.apache.hadoop.hive.serde2.Deserializer import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.sql.Logging +import org.apache.spark.sql.{SQLContext, Logging} import org.apache.spark.sql.catalyst.analysis.{EliminateAnalysisOperators, Catalog} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical @@ -64,9 +65,8 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with // Since HiveQL is case insensitive for table names we make them all lowercase. MetastoreRelation( - databaseName, - tblName, - alias)(table.getTTable, partitions.map(part => part.getTPartition)) + databaseName, tblName, alias)( + table.getTTable, partitions.map(part => part.getTPartition))(hive) } def createTable( @@ -251,7 +251,11 @@ object HiveMetastoreTypes extends RegexParsers { private[hive] case class MetastoreRelation (databaseName: String, tableName: String, alias: Option[String]) (val table: TTable, val partitions: Seq[TPartition]) - extends BaseRelation { + (@transient sqlContext: SQLContext) + extends LeafNode { + + self: Product => + // TODO: Can we use org.apache.hadoop.hive.ql.metadata.Table as the type of table and // use org.apache.hadoop.hive.ql.metadata.Partition as the type of elements of partitions. // Right now, using org.apache.hadoop.hive.ql.metadata.Table and @@ -264,6 +268,21 @@ private[hive] case class MetastoreRelation new Partition(hiveQlTable, p) } + @transient override lazy val statistics = Statistics( + sizeInBytes = { + // TODO: check if this estimate is valid for tables after partition pruning. + // NOTE: getting `totalSize` directly from params is kind of hacky, but this should be + // relatively cheap if parameters for the table are populated into the metastore. An + // alternative would be going through Hadoop's FileSystem API, which can be expensive if a lot + // of RPCs are involved. Besides `totalSize`, there are also `numFiles`, `numRows`, + // `rawDataSize` keys that we can look at in the future. + BigInt( + Option(hiveQlTable.getParameters.get("totalSize")) + .map(_.toLong) + .getOrElse(sqlContext.defaultSizeInBytes)) + } + ) + val tableDesc = new TableDesc( Class.forName(hiveQlTable.getSerializationLib).asInstanceOf[Class[Deserializer]], hiveQlTable.getInputFormatClass, @@ -275,14 +294,14 @@ private[hive] case class MetastoreRelation hiveQlTable.getMetadata ) - implicit class SchemaAttribute(f: FieldSchema) { - def toAttribute = AttributeReference( - f.getName, - HiveMetastoreTypes.toDataType(f.getType), - // Since data can be dumped in randomly with no validation, everything is nullable. - nullable = true - )(qualifiers = tableName +: alias.toSeq) - } + implicit class SchemaAttribute(f: FieldSchema) { + def toAttribute = AttributeReference( + f.getName, + HiveMetastoreTypes.toDataType(f.getType), + // Since data can be dumped in randomly with no validation, everything is nullable. + nullable = true + )(qualifiers = tableName +: alias.toSeq) + } // Must be a stable value since new attributes are born here. val partitionKeys = hiveQlTable.getPartitionKeys.map(_.toAttribute) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala new file mode 100644 index 0000000000000..a61fd9df95c94 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala @@ -0,0 +1,95 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive + +import scala.reflect.ClassTag + +import org.apache.spark.sql.{SQLConf, QueryTest} +import org.apache.spark.sql.execution.{BroadcastHashJoin, ShuffledHashJoin} +import org.apache.spark.sql.hive.test.TestHive +import org.apache.spark.sql.hive.test.TestHive._ + +class StatisticsSuite extends QueryTest { + + test("estimates the size of a test MetastoreRelation") { + val rdd = hql("""SELECT * FROM src""") + val sizes = rdd.queryExecution.analyzed.collect { case mr: MetastoreRelation => + mr.statistics.sizeInBytes + } + assert(sizes.size === 1) + assert(sizes(0).equals(BigInt(5812)), + s"expected exact size 5812 for test table 'src', got: ${sizes(0)}") + } + + test("auto converts to broadcast hash join, by size estimate of a relation") { + def mkTest( + before: () => Unit, + after: () => Unit, + query: String, + expectedAnswer: Seq[Any], + ct: ClassTag[_]) = { + before() + + var rdd = hql(query) + + // Assert src has a size smaller than the threshold. + val sizes = rdd.queryExecution.analyzed.collect { + case r if ct.runtimeClass.isAssignableFrom(r.getClass) => r.statistics.sizeInBytes + } + assert(sizes.size === 2 && sizes(0) <= autoBroadcastJoinThreshold, + s"query should contain two relations, each of which has size smaller than autoConvertSize") + + // Using `sparkPlan` because for relevant patterns in HashJoin to be + // matched, other strategies need to be applied. + var bhj = rdd.queryExecution.sparkPlan.collect { case j: BroadcastHashJoin => j } + assert(bhj.size === 1, + s"actual query plans do not contain broadcast join: ${rdd.queryExecution}") + + checkAnswer(rdd, expectedAnswer) // check correctness of output + + TestHive.settings.synchronized { + val tmp = autoBroadcastJoinThreshold + + hql(s"""SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD}=-1""") + rdd = hql(query) + bhj = rdd.queryExecution.sparkPlan.collect { case j: BroadcastHashJoin => j } + assert(bhj.isEmpty, "BroadcastHashJoin still planned even though it is switched off") + + val shj = rdd.queryExecution.sparkPlan.collect { case j: ShuffledHashJoin => j } + assert(shj.size === 1, + "ShuffledHashJoin should be planned when BroadcastHashJoin is turned off") + + hql(s"""SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD}=$tmp""") + } + + after() + } + + /** Tests for MetastoreRelation */ + val metastoreQuery = """SELECT * FROM src a JOIN src b ON a.key = 238 AND a.key = b.key""" + val metastoreAnswer = Seq.fill(4)((238, "val_238", 238, "val_238")) + mkTest( + () => (), + () => (), + metastoreQuery, + metastoreAnswer, + implicitly[ClassTag[MetastoreRelation]] + ) + } + +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala index b4dbf2b115799..6c8fe4b196dea 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala @@ -132,7 +132,7 @@ abstract class HiveComparisonTest answer: Seq[String]): Seq[String] = { def isSorted(plan: LogicalPlan): Boolean = plan match { - case _: Join | _: Aggregate | _: BaseRelation | _: Generate | _: Sample | _: Distinct => false + case _: Join | _: Aggregate | _: Generate | _: Sample | _: Distinct => false case PhysicalOperation(_, _, Sort(_, _)) => true case _ => plan.children.iterator.exists(isSorted) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index a022a1e2dc70e..50f85289fdad8 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -21,7 +21,7 @@ import scala.util.Try import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.hive.test.TestHive._ -import org.apache.spark.sql.{SchemaRDD, Row} +import org.apache.spark.sql.{Row, SchemaRDD} case class TestData(a: Int, b: String) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/parquet/HiveParquetSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/parquet/HiveParquetSuite.scala index 91ad59d7f82c0..3bfe49a760be5 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/parquet/HiveParquetSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/parquet/HiveParquetSuite.scala @@ -35,7 +35,7 @@ class HiveParquetSuite extends FunSuite with BeforeAndAfterAll with BeforeAndAft override def beforeAll() { // write test data - ParquetTestData.writeFile + ParquetTestData.writeFile() testRDD = parquetFile(ParquetTestData.testDir.toString) testRDD.registerAsTable("testsource") } From 2c356665c986564482ccfb3f880f0a2c023a7cb7 Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Tue, 29 Jul 2014 17:52:48 -0700 Subject: [PATCH 017/170] MAINTENANCE: Automated closing of pull requests. This commit exists to close the following pull requests on Github: Closes #740 (close requested by 'rxin') Closes #647 (close requested by 'rxin') Closes #1383 (close requested by 'rxin') Closes #1485 (close requested by 'pwendell') Closes #693 (close requested by 'rxin') Closes #478 (close requested by 'JoshRosen') From 39b8193102ebf32ef6b40631a949318b281d44a1 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Tue, 29 Jul 2014 18:14:20 -0700 Subject: [PATCH 018/170] [SPARK-2716][SQL] Don't check resolved for having filters. For queries like `... HAVING COUNT(*) > 9` the expression is always resolved since it contains no attributes. This was causing us to avoid doing the Having clause aggregation rewrite. Author: Michael Armbrust Closes #1640 from marmbrus/havingNoRef and squashes the following commits: 92d3901 [Michael Armbrust] Don't check resolved for having filters. --- .../sql/catalyst/analysis/Analyzer.scala | 2 +- ...erences-0-d2de3ba23759d25ef77cdfbab72cbb63 | 136 ++++++++++++++++++ .../sql/hive/execution/HiveQuerySuite.scala | 3 + 3 files changed, 140 insertions(+), 1 deletion(-) create mode 100644 sql/hive/src/test/resources/golden/having no references-0-d2de3ba23759d25ef77cdfbab72cbb63 diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 02bdb64f308a5..74c0104e5b17f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -159,7 +159,7 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool object UnresolvedHavingClauseAttributes extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { case filter @ Filter(havingCondition, aggregate @ Aggregate(_, originalAggExprs, _)) - if !filter.resolved && aggregate.resolved && containsAggregate(havingCondition) => { + if aggregate.resolved && containsAggregate(havingCondition) => { val evaluatedCondition = Alias(havingCondition, "havingCondition")() val aggExprsWithHaving = evaluatedCondition +: originalAggExprs diff --git a/sql/hive/src/test/resources/golden/having no references-0-d2de3ba23759d25ef77cdfbab72cbb63 b/sql/hive/src/test/resources/golden/having no references-0-d2de3ba23759d25ef77cdfbab72cbb63 new file mode 100644 index 0000000000000..3f2cab688ccc2 --- /dev/null +++ b/sql/hive/src/test/resources/golden/having no references-0-d2de3ba23759d25ef77cdfbab72cbb63 @@ -0,0 +1,136 @@ +0 +5 +12 +15 +18 +24 +26 +35 +37 +42 +51 +58 +67 +70 +72 +76 +83 +84 +90 +95 +97 +98 +100 +103 +104 +113 +118 +119 +120 +125 +128 +129 +134 +137 +138 +146 +149 +152 +164 +165 +167 +169 +172 +174 +175 +176 +179 +187 +191 +193 +195 +197 +199 +200 +203 +205 +207 +208 +209 +213 +216 +217 +219 +221 +223 +224 +229 +230 +233 +237 +238 +239 +242 +255 +256 +265 +272 +273 +277 +278 +280 +281 +282 +288 +298 +307 +309 +311 +316 +317 +318 +321 +322 +325 +327 +331 +333 +342 +344 +348 +353 +367 +369 +382 +384 +395 +396 +397 +399 +401 +403 +404 +406 +409 +413 +414 +417 +424 +429 +430 +431 +438 +439 +454 +458 +459 +462 +463 +466 +468 +469 +478 +480 +489 +492 +498 diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index 50f85289fdad8..aadfd2e900151 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -30,6 +30,9 @@ case class TestData(a: Int, b: String) */ class HiveQuerySuite extends HiveComparisonTest { + createQueryTest("having no references", + "SELECT key FROM src GROUP BY key HAVING COUNT(*) > 1") + createQueryTest("boolean = number", """ |SELECT From 86534d0f5255362618c05a07b0171ec35c915822 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Tue, 29 Jul 2014 18:20:51 -0700 Subject: [PATCH 019/170] [SPARK-2631][SQL] Use SQLConf to configure in-memory columnar caching Author: Michael Armbrust Closes #1638 from marmbrus/cachedConfig and squashes the following commits: 2362082 [Michael Armbrust] Use SQLConf to configure in-memory columnar caching --- sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala | 4 ++++ sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala | 2 -- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala index be8d4e15ec4b0..5d85a0fd4eebb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala @@ -22,6 +22,7 @@ import java.util.Properties import scala.collection.JavaConverters._ object SQLConf { + val COMPRESS_CACHED = "spark.sql.inMemoryColumnarStorage.compressed" val AUTO_BROADCASTJOIN_THRESHOLD = "spark.sql.autoBroadcastJoinThreshold" val SHUFFLE_PARTITIONS = "spark.sql.shuffle.partitions" val DEFAULT_SIZE_IN_BYTES = "spark.sql.defaultSizeInBytes" @@ -49,6 +50,9 @@ trait SQLConf { /** ************************ Spark SQL Params/Hints ******************* */ // TODO: refactor so that these hints accessors don't pollute the name space of SQLContext? + /** When true tables cached using the in-memory columnar caching will be compressed. */ + private[spark] def useCompression: Boolean = get(COMPRESS_CACHED, "false").toBoolean + /** Number of partitions to use for shuffle operators. */ private[spark] def numShufflePartitions: Int = get(SHUFFLE_PARTITIONS, "200").toInt diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index a136c7b3ffef5..c2bdef732372c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -192,8 +192,6 @@ class SQLContext(@transient val sparkContext: SparkContext) currentTable.logicalPlan case _ => - val useCompression = - sparkContext.conf.getBoolean("spark.sql.inMemoryColumnarStorage.compressed", false) InMemoryRelation(useCompression, executePlan(currentTable).executedPlan) } From 22649b6cde8e18f043f122bce46f446174d00f6c Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Tue, 29 Jul 2014 19:02:06 -0700 Subject: [PATCH 020/170] [SPARK-2305] [PySpark] Update Py4J to version 0.8.2.1 Author: Josh Rosen Closes #1626 from JoshRosen/SPARK-2305 and squashes the following commits: 03fb283 [Josh Rosen] Update Py4J to version 0.8.2.1. --- LICENSE | 4 ++-- bin/pyspark | 2 +- bin/pyspark2.cmd | 2 +- core/pom.xml | 2 +- .../apache/spark/api/python/PythonUtils.scala | 2 +- python/lib/py4j-0.8.1-src.zip | Bin 37662 -> 0 bytes python/lib/py4j-0.8.2.1-src.zip | Bin 0 -> 37562 bytes sbin/spark-config.sh | 2 +- sbin/spark-executor | 2 +- 9 files changed, 8 insertions(+), 8 deletions(-) delete mode 100644 python/lib/py4j-0.8.1-src.zip create mode 100644 python/lib/py4j-0.8.2.1-src.zip diff --git a/LICENSE b/LICENSE index 65e1f480d9b14..76a3601c66918 100644 --- a/LICENSE +++ b/LICENSE @@ -272,7 +272,7 @@ SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. ======================================================================== -For Py4J (python/lib/py4j0.7.egg and files in assembly/lib/net/sf/py4j): +For Py4J (python/lib/py4j-0.8.2.1-src.zip) ======================================================================== Copyright (c) 2009-2011, Barthelemy Dagenais All rights reserved. @@ -532,7 +532,7 @@ The following components are provided under a BSD-style license. See project lin (New BSD license) Protocol Buffer Java API (org.spark-project.protobuf:protobuf-java:2.4.1-shaded - http://code.google.com/p/protobuf) (The BSD License) Fortran to Java ARPACK (net.sourceforge.f2j:arpack_combined_all:0.1 - http://f2j.sourceforge.net) (The BSD License) xmlenc Library (xmlenc:xmlenc:0.52 - http://xmlenc.sourceforge.net) - (The New BSD License) Py4J (net.sf.py4j:py4j:0.8.1 - http://py4j.sourceforge.net/) + (The New BSD License) Py4J (net.sf.py4j:py4j:0.8.2.1 - http://py4j.sourceforge.net/) (Two-clause BSD-style license) JUnit-Interface (com.novocode:junit-interface:0.10 - http://github.com/szeiger/junit-interface/) (ISC/BSD License) jbcrypt (org.mindrot:jbcrypt:0.3m - http://www.mindrot.org/) diff --git a/bin/pyspark b/bin/pyspark index 69b056fe28f2c..39a20e2a24a3c 100755 --- a/bin/pyspark +++ b/bin/pyspark @@ -52,7 +52,7 @@ export PYSPARK_PYTHON # Add the PySpark classes to the Python path: export PYTHONPATH=$SPARK_HOME/python/:$PYTHONPATH -export PYTHONPATH=$SPARK_HOME/python/lib/py4j-0.8.1-src.zip:$PYTHONPATH +export PYTHONPATH=$SPARK_HOME/python/lib/py4j-0.8.2.1-src.zip:$PYTHONPATH # Load the PySpark shell.py script when ./pyspark is used interactively: export OLD_PYTHONSTARTUP=$PYTHONSTARTUP diff --git a/bin/pyspark2.cmd b/bin/pyspark2.cmd index 0ef9eea95342e..2c4b08af8d4c3 100644 --- a/bin/pyspark2.cmd +++ b/bin/pyspark2.cmd @@ -45,7 +45,7 @@ rem Figure out which Python to use. if [%PYSPARK_PYTHON%] == [] set PYSPARK_PYTHON=python set PYTHONPATH=%FWDIR%python;%PYTHONPATH% -set PYTHONPATH=%FWDIR%python\lib\py4j-0.8.1-src.zip;%PYTHONPATH% +set PYTHONPATH=%FWDIR%python\lib\py4j-0.8.2.1-src.zip;%PYTHONPATH% set OLD_PYTHONSTARTUP=%PYTHONSTARTUP% set PYTHONSTARTUP=%FWDIR%python\pyspark\shell.py diff --git a/core/pom.xml b/core/pom.xml index a24743495b0e1..4f061099a477d 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -275,7 +275,7 @@ net.sf.py4j py4j - 0.8.1 + 0.8.2.1 diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala b/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala index 6d3e257c4d5df..52c70712eea3d 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala @@ -29,7 +29,7 @@ private[spark] object PythonUtils { val pythonPath = new ArrayBuffer[String] for (sparkHome <- sys.env.get("SPARK_HOME")) { pythonPath += Seq(sparkHome, "python").mkString(File.separator) - pythonPath += Seq(sparkHome, "python", "lib", "py4j-0.8.1-src.zip").mkString(File.separator) + pythonPath += Seq(sparkHome, "python", "lib", "py4j-0.8.2.1-src.zip").mkString(File.separator) } pythonPath ++= SparkContext.jarOfObject(this) pythonPath.mkString(File.pathSeparator) diff --git a/python/lib/py4j-0.8.1-src.zip b/python/lib/py4j-0.8.1-src.zip deleted file mode 100644 index 2069a328d1f2e6a94df057c6a3930048ae3f3832..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 37662 zcmb4qW00lml5Uw@R(IL9ZQHhO+qP|2mu=g&ZCkhZnK^f6pF0yXcSYnE>(7_#ja+X& zPf}I_7z7I7uTNqGf#5&B{QCnG029F0jnSM&Nf8nNc*I~>{`HSIxk3W~f}8*Y0Q}=2 z``^9}{m<7OjT{^u{zHka;eV+F`OmLH08DZd@;A3k22KCb9|R2mfbn-F=6cS0I#znN zI)CY;wsrf*$>`sYWdD+XIoYN1woL&my!UbS&1{Hoq5$NkR~hx(5#pM4b#ocBQj`e}AF+Qsz!IK%djA7H>1%Ch;U{;tu|W`SF9aXjcT zl;^-?i^*str1(K=^1}5Z765nW{w5InPGw{M`jK;zH~e|@G6vbQ6Y@gPeFyMyqV~$< z;S`V+a}(ml@^Uw+;%P@1)!NL)RppgRHb%d~sPpF{)u*|W^6TiRu(vNeuv$r6fmz&f zW1?oEd#1=siOQIaM6fMPaYl8nsc!}!#*uZT`tR}L4W()~lF4`BT8-2owFt^^SyI0Q z=zDi!QJ))ZH;|xv?=joZkcKjoDz#Px48PH#T^}znus&CaTh@vn{LSf5{EO)1$fhqf zo^w6$L8`cz=`iy$GB6l*a%Cdmv#7aqK=Gi^PGVwGQb;-BW)_bJDOIUwhZX8!0*oN$ zdgn%Hs!i%~Q!xa9q|)bDcN_&44`wzu#a9?iq#G$n;fX}@qkk#d2@4|WdA&^%KS81k z*yFptZ$T*2{l2JvLLwcC1Q04N+GEHLz(L{9%PjuL-kJ<8Y?MksFRuzBAkzSLT{uX- zoeV~`MeCfo2r#9FIZACrZ1681Mk+!=}aULvzNXL^eYp#FUxJL19RH+ z3UgA;mOiBb-GH{}^N5!t&y0KSGAYD+E1a|gXHVI08CxKv6o(2lLo6UAJL=gXX7O;a za#soIbnV2FxmN+QXpJpbs^}^)EV0=N-&)@^DTeEmFPh*Wv41Zy7L_ny3I%Jr9`Y+0iU!4K%Ayw0uV)2sUK;hXe7Tbn4$YSEHi&>au>e%jr`iC(k;U+;_A zzS@0}4UUFu=i)wj@(?@AbP3KlvaEZr4g99G4?M~g7clMHF-GLN*_(O@JNN}Ej5z8r zeXovwqXHin%7sn^ucGmCkMITub(APK`6{4t$V=FL8qtk@PLqI57Pt_IlOcEQ0_Ihh zU_%q98TAetPtpB8X%>WXBF`$10-xcydo5cX$heJ=3i$Gs}C1hyVZ%dH?`8 zf7ka9%rVh(G;+~%`y1{wakt3-CI5;$oafH#L$lwn$QY;KJ&F9Fv2`d}1=-~EDC3Kl zNMV^EMlKD1<3vT{Q z7>^DVqb)d8jYmb^`&A0Om}ugV+qfiQF|~l7*ES-No$y1owPBr%I92Ki?C zA$!e5cOASf)j>1{YPJG!ujp-y43K7CQDBA%mKW9)^Dl&k{HjRm9o92(An~ zvBe}D^`aO?ZzRe}BPNyQQq+pqGXmdzIJ!~kr|G4lE(raJA)7VwytF<*oJS64GKF9VETBXd z2n4hHIk`S%aFVkSkae~`h6feunlA7^uqoPNS1edh(~ZL4ksPTP#)%}n zPKA#U1LqSwEo>>#iMk@)dotg^xR^^Swft8Zv=K~(Xp`Ot00nP2QR(k^3HZHEw??*U z<=>5j;Y!xzTZHU*>`(j;Um?H4an%PteW=Syt~-pf(IWTcUY>Cucw-6>28N2JS=%n_ z!AVO!i%lFed{l)CsY1TT9VSchYG-`Kz!eh_OnyfI<6@xOV~L8V88n8^)(f&9{1BH( z*3VxrTLBx;Nt3^OPy#iZ8i^6{hT^>2#SvVlY>y0}kc5%rNo!l(Eoc*9%%jYyfI8uw zSK8^fnP)AHBQOB?;Oq7OjHCG#>52A*Ll6^(1Bo3!Is%6<6EX|QS@;3!9`(kcmpw)P z!($xtXbTM!3&~i@Ea43Rc|rJqu4zW9NJ)w(EcG0-tat{B6c`vs8FvjTvFg{vaBEl< z)&A*L-jcQNP*O9@Uh+raZ7Aa=7*zWXCwz)7EI)52YXolFO{jH7WST$<{D5Yv2`|9C zZ8QL2G=FHu&kRsIyfUk0&>M#H9igP7Zc-Mu3GLYZK?18?SH6L}0DOkNaJDwi-+dcA z@D3dY7f#89@9Kz3)hhyYHC3&K{uO&?KcTwEY*(fFj68q=XT-rO(Z2>Vz~%fww^rd{ zf@YRuledQA0&m;!gsX#_;K53sR$`j)XZ6UG#b&dJ)5pdzd=}2yd7Z_zYCr^UYln{r zZ8MF2)`cbWQx7z-Q#DQHJ+(ov;G>^*!7I;`_FsaIDk145L=GTkUb{v7Z)#W1!&vHV~VM(xxk^d33`o{LP|9 z9C1*kmuv-?g?T{SiH;6mKRq1JL0~Ig!2N;lo&QXWexmb!Cwn&=@sR{*1rDM)pC9th zW#p^XQNT$*AHBlSctV0ouRZc0RikE@wYqou*`e#! z)ODeo0gR5e?cI$nYY#il-`E#!FFcnSO5}?a6U`75do&YQEDiI*;up4$%l1Th)kF&J2P;8k+P&4GSm_S0`Tc>5m1%Y{Qr= z)^$`V&ClY~q9(#(4{V_+HFuj>lB(DpNN7n(5kOUpN2{xft*$^92~@cfh#X0pzZ;s= zIXhysrlVARUheYZw7Zp}?69*)YMDi1*RVnQ$9C&ME-^@4?XQ7)kd5R#q|qKtZOjr`-nf zSfqzZn|q`)^kFo0LMo0EN1X#N#9{T~WN_fe-lSHiWEM%oX^N~X%dW3k<9k#sDeU(p ziwO)Z5l;c=Nklst5SVzj#&xQ>7*_9~7A2p=erUE$BqbOWR)BKActViP)?0~wS1)m* z`T4k9p}}TwR~jlU5#Od-ZX!3P_%%&zop0$qD_z&H!RlzrAz09|S<-mie55z#7Zh5B+r z3$k~pB50ZuK>2Hswl9 zT{?G_k+TPE)b90IoU`9F1{1f3%X81Yh^E6H25jThCZ9W58O>zcL!vO^&0QCWA`t28 z!bHstmoaBYCdnk`T)K~O7T5nB2nT7B)Tc&1=#M{yI~sUsi`eCaj;XO2s+^ZE*tz3) zwu=7)n>yB_NX&6{S}0XUoRyy8OeS`bmM7PtI1w|sg;+~2DAH^wTNzR)-Pm;NeV}`R z&?+f_lnuSLeTHqh@4mOY6*_CLIZ$JijcQRCNTJpTlFdTG9&}!vMo)uVv!^GA`9!Vp zcSE*QB6{SBm=$<|5~0h;6{EUdTkrwbpor^Z>5!Jle zIdf2PtjRDdAGb)5tN4Zp_XggC8hPNwy`_rHY8h|ztjK&7!qC)0ERcMPSe?@+!sn1& z+H%5C_n)EqVqW7{H^CI$h7l#NvVmF--zBg-OVP(@qy3SuAM|8PV%T&0U>H2Tv8sGbyj;) zI2pNbhl0pCR)?N{1K|iC&V>csn3ca@tE>`|F2+WLFfm>pW<@bjJFcT%o7QahX*rmZ z;ZS#HCO?IZXF0MbMc*S;Q^yi$T!5Ti6((F-+<~F4@6yPWu_-_qO;M`4Bl)wd*w_i@ zy9r<0tpfMmVZjW$ifLC(*H4%QX1Z>G#>!91O(SOA!>X{GEh1cKFh34R?Hxn%DHAl8 zIjH>e$KQB7}kHJ zbvTkF^7Nq`Zg?94#i~T4<1>K2unW$}L!_VW(oHJ{{pKv+Ps0$D06Y_{AE>KP;Sia& z1qi$duLzl+j$!+GuMowqtLf3QGtQIxD)t3HLLa??OseWy6RxjpY`9pAS!};VfQ`Ya z`Em=-`92q<3)mm>QB+fzAs1lg)F@45V!iK#Ie_F7BA>&$NAQS7X%Mi@nhYwk%R7)r-T9Jfu)uHVwC}2Yk2w9GX9=a-;U^;AaY;z}}K+ zrTl#+zi<&-y4&zdyt^=HA4Rqh@R-IYlr~5nGfX}YH{vKvBtt3Tn2c!(M*@yX zRA-nT0q6l~e4q1P+(Z##51SC zQ&)GUjpt;dv$I!a*hY=sEZnFXnD4R39a|z@UaBPXesbknx?OVrR*U8?3z;xePV;9( zgayTPIa@}9%+Mjv#r%0i@6Ca^X;ILj#t~JQN-~f~`xqh<>G2T8Jg?G8NFKC0(sk@O z?hddwE_dIC<-P|`%;0+yv30}E3AqoNLCUu0!ZOHtQ`f|DU9?0V#e49xx;3e*=!$NR z^Sd=QNQ8maSAp_{qu$Y@BN66a{Fg9RqA;i&1|wVDiYy=y>%7%g?J=!$^YUj|VF5kV z1~r}Zh(YvN&vds&e8B2mWtY73BUi+pzZ)0@3DWRjvoLvK`B3#oq8FwoS%)a1snvCY zrvuP4k^HSaP4Nhg;Z98R*3)j-x)mDkMp-I;<0#zl0G0fRb0|g=vhfe{PkGn)({#?& zvqp_}x0X4DK7KwY33#bMB>zg83j6u*WyAQ3bmGutH$KODG!|ns%;OZXZY6Gi(!TIs zsOEvv5GAOt7nEIZ(N3Ws!kT=semfN_J@gz$+#im*OO;RcS;GzCsi8x03Eh_$7v;PK z?J!XdE&*(h^qUq+&eW$E2wD3+CBof@;&M>gSv zhT<9RuScF-Iq)!Dmj1QqJ&tUWi0TKfZMhQLX|*-I658o8C3S-#cX!>6kpl4|dGS=S z=%9xdf$>spy0^$nIYdGGF&|0?)$aFMcMeucdkvZs&#F!<%-UJowU=y8Uh7*3d%u;U zD`)1v!T+Opf(4fcsw*uMG4oE%o&bEdEx#%Jn;z{Y(C}dZlJ% zv(60vb*0nWfQsQ}ljQjN-^4p+<4Va5z{KYeB$-*xVZ`8%Xn{)l^AeDPSl7|B|fN>C)!feoT& zZG%hNW5PvmqSpwD;Q~4T*6co8PLiP$l&;<@Oj(UR&GF2h_WnI3jI3h-tsfFClhb4{ zQDuFy>he3`DiuoKpL+vbzdSKpjsr&JFp6A-Hyq>aa;W!(3w9HJiL+TkPy&dEK?DSD z!W*wvqeE0NNk)hC&%A9Q7pGe-cwQ0RD9q9vhgOhfDX=S;C~o|SC;Y)V0gTQ2_4^rAsrr<3Jkar1t!QJB*apGLP9)W#1VfRBY8!#e~#l_7(VL+*=k!5;ja1fN8 zRp5Gx_+)bRy?+)&YHCJjxTM6?o{|6fdOEn^_9s?%Tuj2OE6`j)Z6c1(Pgs!x;4)G^ zFDwMqpIJy-)d9WKfmm$DC^C_}oa}|0>ud|L{`+utNvvq>#GY7TANq(o0r_q zSDEmslK+imq4B&Yt?IV@QM7MX^s5IpS*dC)QmI1N@&IX;%FmSdgFC{Y6hsJ*QPyb_ zfIRb@Tr>^Y!AQ!Y0Awry(28!AOiJyYCL5CtVT1lLktk3a9-ikA5DHTmDY-0ZoN5~C z2!lHm3wpr1axy7zbHxWT4X#b)D2s4kS=UCiaudMKivKObz~}RU?S@9E1vJChn6y~) zeAQ7vQ=#XkjZwY`s*emmBZ0_-vTo8UO_sTpg325})SfIFE17$hb#D^B1i#>S?6G<| zQ1@Er7{ZXQUF~grTDZM^%x`H7Cl1@ya3y;JEx$W+&*d&kD~CMsaAa|cSJ5s_qZJa# z&V0U)ku*dL-#A+8Qgbq?hc#_L&g5IN+Eqz?TSeHixJejOrHkh1z>$;}B%{1-f5`-U zFheC46$_4G`J11SR_S7P-iGS&D8R6hOi2W#)tGA;qxilGF5!i3SxcXfd7S%?DK8es z>&sWo$-pPVKlflaYI0IH!ig}{q-0f;S(S*-Sv&gzZ{Wg??is9T3LWII-|vJS(Pigk z4ml*UAJUzbmH>rH4uS)dv+ZfuzHM8`LcG2q5Je!$8xB$hD23lDJ7F`$T9*NyReE&$dd%<)Mdv7M4|XLpsZb^Mk9 ziOM9?B*v;{?fLXUCKnk)^A0t^JpD)+Q+}X$#qH;k-P$H23YwVGbGq=rM3WRHGEuWN zJSpRqMELgg1E$w*NcG52KYe1Ul`C}YHsKOZFyvR(&ER81evR6iD)kptQjL2TK#1kd zc-O!}gaWhYSe;FI%^(Edo9tJ$M5djxn0xRci>S38Ih zoG~NVuyoD?!BlC@5;Iv!p;Kfb0tSl#GnH~>3swoO(md8ppNKNodk03qf1mR85z6cO zxblIRtm0j@R!C{T+#_vq7~x$nk;;g|l+qJj3L=7Te0dvXK?=qM22A{XTSSKos_fTI z?4(e>G8~Cc>j(S5lXqWHa+B+4(2yi!VvWSlk_D`$nFXLhGRZCe=Ehj_Z^U=>_<}RC` zI-SU6&%0qQfu~D~%@#@iB1-Pk2x0S)2m@%T{W$*NoWx*7)F;1+LHZ(j;Q+N-hX;9l zb=7-Zu6B20WU??2KtMq@d=xvWgZyUXbdSo-)b z+9vr|!+-2iNxq={2buoWS&FB+G5GjHrsre;f2j1|JK+D;!v04m$kNQg@oz->-?g9B zwHz{7k-eU)WwXF-BYs6*Ixp34tuor(*QOCGv}auzaz^SO{6r%_6-4}gPZ|gTh1hFS z+X$JI91?xbt|+J=zyk@-Z;%E8fvf|3>LFA2-Qs{r4ZhW=cqEA&fdVJ61qZIrs zaea2TPYZ4hi>NI8DeU6i70imUu+}IB0X0AKMDd;_Td+%O(vo8hiC z*H+&mep3j8)X`%tY{k)Hn4N&z<2MxXA@55|uU7>bp=5;cg{`TcGOQC_nX$jx3dY|e zsN51B8imuTTwzo|1OX}49F)B?^c?jXR9eU04e0Xwet_jT5a z+A%W|HE5{$G_ZGH8h_aZe?GPCL0q88owcZdqBIqcXc6o;dzJBi`MriogN2*`Zvz3EK8{f$vj<7OkCvu}X*Zp}@=) zvX>M32^1HuiWFW+!gJ2_nD5&qv9DGk>i>5+?q5AhPAMS$!X=UiHWpH5-vxC zJ0$@eRR6+zLrQ6cBH}%$)BX^0w}oMN5MU*WlV;X+w)qPN27c9DL(wV%lBW2C zeX1xL{4&3ffjpR@aApni-Sa@d!^4BwP2NtgEzbvnvtMvge{*s!{z!nj>>lfMdXN!B z<4y{>bfnk%q$y^nW2(klV#mnOF!bUH^ktPV5fjX=BJ*Bqc=*RJMXn6sq=gHlyr1zhk6HIgXh<4ncxSJz zzJD?YeHMazex9?{LAL8wkh*6&QK3VYl1ezAUvkTnj80o(o6n;+Q8bxigY zQh<)3<5Y2_xl_)kE1d-?Y+2|5O@ z>cT3SJzNO`Mk{?-?OEsIqdtWq9m9M|j-QPM1A(fTey37!aGoizn}p9%P*Frg27BC{ zm7KM?d00umu`)tJIy%7;4opR;m=uh$rD_U8SuLIYkg04XLs%J^J`Z!F|D`o`gqJ79 zxZRL7B$kR&F$HETvuZCmh$>87P>In)11~qAPA&$Xwq)8HHe4ClGVkQA70 zHRbR<0Dy2vGupkjpx2rU#xxMtO@={SSQE5hE;CCmHFKu)n(UeF=W8KwS54B$H53Qf zA+S!m#YAS6w;IFV=l(ihq7Jfdg=FgJsgifrTj&l{S_{Y1hX+kLY|qDohk_z9fsYlb zO^v3MM`!S#%I;7J&h%qqWWkHRE72-}MTa1&8yPi^@_7^`@*8Q1FlOehdaC@QKj`w; z_fhLNLMP@Hw1iwZ!Q^15u8id%l8!Ljst2{IcvB|X9Ggo9JF9&qF4#I!(TjAjN``2* zG5SLxRa``m>}rJgo@=X(LmoSBDr(N2>^n?vRrHLHLKnB5wX>J6M*3p0RFTQ+9|iW* ze@ium&l$uH!SXpXkaxOuB}@>tRyL52<`OvoV&}qm^WGUBHVIh9Br zRoeR0c4c;d>oTcze19jK;tR;oX{B}-4k8uoM`%&?H(W5~at*}mOyJqExn0&V40?%h zgJ)-++gpx};k2>z`3McLsUDEk7KLwMnqgQ$MTKvZ6UKcd4FoGRI+?v;Y;3CVg_aV|=oTs|QIw?1gy; zNg!^sQo}fLO;>@J@%8HAIyfSiA| zDqttI+EF?%x{c`&6pTttP3JNn>w?eP_XgoC`kdg{dVK*vOuRARIPlHKK&3x0T*FL=(1LVU6S)*n}HC$psrvOC_0LwN$t zxeveF9UQjgJ{WgUlC{w37rWvr-D=GpJ^A3?HdWv}HYaL@)fGs$A&B{~HuM_Md$W*n zokzo$fk051YtIBNKApM&|3*H|_^HzLu~keaKV|>;shLF9Yee4lhIn-_`~?ihZAfmO z)a5Wk?d;USTjJPkeavy7!R_h2Ih#pTrDK5-6)&ePAjKmJ{xk?SZ5jG(o(atSb5)!R<_<=x9%Rpuc3P&X7RG+lh1ucicBI=F#M;e>I2eVZxJ7GL#u)o2hwH$}lI zJlgGOYd=?ACc%fXTW1&7AaA-iKo)uNi7clKgis3C*_BM`b*Rf5%ipiz-kP1b*PpS% z!5;tsSbx8UR!)|VW{#%zMtX+-i^0N0m38ZNR(P*vm71#n!X6ZR4aeCaKPN0Ev$F-h z-^gmnfW4vyA!I~Qgog`0-2!rTMl4e4Lg7}RE2yJv4r~nRabDZc!ceJxkt6<+d4%oN z?}j!eF2na<#!0PlXE3eCV;jD{o~N%6sbpzL;wDLD$g<+D+Tuo{N;iteyr3+(sWV=E zRgK6=fMi7ECOwE<)49ySeMO9T8fFaB*C)U9&%G_lv>URAOH%qX; zfZG~1sNt*&6px4)sFZ8;DfM64?a%pT8R+{~Mm`92(M`DKE0 zf*&kXF+H+VdLvB(IOSi)?&JX$b5V2sZ=ytQ)I)w1Nrvt?OrKpnid`$I?@br;Z&Y=f zI-Zg`v-H;?a~cJY=ewY6StKVBSy47N>BNtZOW3pG;gCSOv9R%kBgIROggo!liA5XO zJ9GaUoF{YpuF@vHW#l_`RUen35iVfm{*97kI8#i+Ti`$&@c!q-wu(8R)8c=D===%Hn)_h?KJm)F8tkff(VI zx5DZGiM`SfA}y1HTLqe|0x|`Cda;l$_mYB~3M20^!I6?YRN5fAH7}o%abJmwI3L$# z-ul#0>eloNLQtc1blu95RKVM!1fNXsUB$BL?<N2#3(@Jt1Zu_s z&V+)_O#0~TAQrS}MiTwwOQo^*Y@`PhTPt=!T2`9ZpiXdCLF;u04qrp2UaA2wIi7wlUfewHq zIb&{qH9Oveab0DS#<$xpr3Mli*739^{SDE^<1bl?I=p5NZ};Z(j+McWND|937dg90 z&9{skFv_{JjcN|o4onGOzz@oI_^!CjD)q;-z`s|Kn+PB=RLRP`#qZ#}I}bcIV5|HOpM3gsdEFC_BGD3`(p9FPsC>%4 z`(y%XqM$M2uAjqf`pb11tP0Q_tUfBx8_f5VQPf}mC0P?Z@T&#-G~aiLaiq&6fApio0-C+JPf*LDnh+Qebi|1u{3h#Em*KN;u_47(d0 zxq^<5D%#&SkrXPxe@d=Y1|Tj5Ph2ZY809b20Wt;*>Th=n|3q$(O`glvH^R`LSb>Xh z3sm?czfprTnK9;djx?FxOA3~N9HMt+SrN)mEbC%wd()*!&oQdgYPE{6b9hAxyByD{ zn6pO#t32_8y}RT&AP4x3^~MIKU$rk>G)N)uCL{^}! z+!{Hle&*nXrJ->}m{XkT4~!k>H2S-$92@Sv6??bwd?Jx@L^dT?{66|MbEWkHqgHoJ zl3*RbMh>k}K*aV-ZAMXXm8bLIfdg|`sMMc>>!p~3c_56xAp#)gFZKtjJS7(8kA0|1 zl_kpk;7N1kew)g&XdVZPhG9wV;mW%cVC!o9Y*dak>w+~##Y7Kv7`#17LLa%pyQ7wu zQY{9c0^xV-Rv5a>7GM>0)-G)obJ}r(S+PR#!VyyM)SkA7x~#wY067%`JfC8}M>k3e z|5S~ntC%IuznwgUgk(ScCe*}Ef$vQ=wxd++>OPy~IX*qQFT0bTtlPhMeKwLnI%FA@ zqsUV+{em=2J@rCrd$HG{XBa(J{ls%(8;>Ewha*q)!ZB3xzZ3`AvaK{8^f_2^Vgs36 z=A07|Ru^Mun(G^OLMhon!`ZEW(p+DIP?k*hd8Q$fLdp5KKSQK1&-7#+<-ffONrfi-;G;V@vcVu>j~O&1-QuGTI?V0kJ%C%nYaqp!PEQvyF`GL_pH3<1J>-SG!j1whTto*-1< zzu*+C6PFgmlRg<<#=Yr8QlceaDWFw{bS9s`ZsIHtVIQ; z(>>j*T4{ZuT_=KRlStF8$*e-lo1r7{-(wDRrJSdUKV}sAV@CAqx2`$xz zQaA)IgsA?29A5n}wQ*%?30t%>);OK5*zE_kRA(4R2v$+)Y6v6LzjLsx>wu)$5SkoMq~ zvX7KE!SRv)P2FXC;+UX5&sPa2eC12A0VLI2(uf9>Hi3;b6i}yyvw5TB6Px#VmYrhp>xagj z>a`Wyq+b=QrrCRNGHs&y>N=q2CU`;1x|BdX1G#Ii?ooL=QFJ(I4phkY+=OtfD*WbC z0bx#8RGq&m8>Uhw9cCHRMl7u3qD4=$*)=1n8NPegyVSFo+Z2Ay!J88ga}zvvFv+*4 z22w}1rKrpL!2ZhCL&_!c`8k!^&G0Px9gr>eG}GFb8T^s+_dlub{hbi>uYDhc|ED5q zY-X)zY36QZ|G&xJ7F9Ma{$y`H&pLQ*AuE)rE68AeJ>pH=KiYbixT5e3UpLs3DkbC5 z+Pb|-eAXY=3_qU)^v|Wc+ut`sE>>{Tj2HZ2ch@Q67O&++7Eti7nwu<9OgnQ1jW$E#kTCs6|b4cHoKmbjO-3`z+f|!MRDG7 zYH#fg-_@O+Mf7oyM-^0$Ylas`S(QI~x*oZj@6@6credEos&idUG7Me;c<;s2Ygl5F zcVD)5J?amu5IA#nz-)-up3q6Ylq7^*rL5=SQkDs?!IvIkr4;o5dK@1V$pz^bmZb6* z>ut}Cbxz#iJH=ZuduG*>lgFC(VM;6d5r6fHWdeq#O@8d@BcgDY9?P}$YR3)4J`DO+ z;gVRgM-)WaSOS|sB`bzl;)l_r|| zr>gMxeE|Q%wbm%6_{k@DHue9wRyyo|z3S#*^!HxF$B%ZozvN%o^-@JDW=|Nt^Iat| z6}V}jY3RZjjshLFAT@$BT zyVF{rdEH4EDyWyJ5YiI?>h0QNru(TH>sTZ%L#mv+M=`IDD3Ei4n{xiQU7y*+`=n$y2J6E9{v36l>L5Tw!uv_zcO%S_+)7yQ_WQ zrC%rb!~5sf6K@2Hedp*Ty_Y!Vu-YC}171Aq7xkXxoPZ*SV%2jm;2UOMeNgP&=TdDE zJJ(5vn&HFZh|a)@ObI~qbeWngb+8(7K@hE6x?`6|<)KjWt=Aw_KME%|cb_X1kBq`kHem+S z^G%8@wc80Q7br1ONW(xkJWwh3)C|W&;sa~`z4glCkqc`=#;aa5tbDf-%_w`Jp^9+) zw5sJ&h_3+qu8pd zxjQ{zn)Tv8bBd~!fy|9+SxRebQx+$y&S+@EVTZ)=7s@A-Kcp(-x76J?uR1)8?=6!3 z<)A)}o_#&LSdI>iq4~OB0JuvW?B_u^-S>H(L>@xQ3qHYjmD65`MQU z(1*t76N%~e5bfjqs$Mymv%#iYTEVI7Rii+I@|x4nV~06_R>*z$k5@P7zWE>)~7 z7I@*mvUTVd;I>A!WDyd1dzHx`x_p;p@RvIc?N^=aF`QXCv-f;loLO24pXD`#kgd42 zi0|Jf7#KFVQn^$4%jv1nAU!#`(zqSsPj+^2-3#i}aaa7+Mre-_VQ5)*;~~F5nF{@X z0i#P574AC7PbL*VgR-D9v4IdX%<&VS5{+r}le*`{sZH>Tx%M-1X7%ic#9#}{^_gps zX7$j`H}U!%LTT55iVsD}kS&)7u2NQ8RoXEFYb%RuloFL7+7h#Z8uv;1PDUX%4yw(B z`WI1=)=)Gud48)Gp-m2(KOS&OE2+t~PMX-Oq?3m87nN+i0gu~7^@>hT(>3lB0=bj+ z6R@u5IfHHRHN%@v2?l6bK>$=Nx%MKif_5I3$`uD1Sy*;lBgD`Zw=pj!Z=e!NrNKKS zYe(*nROFx{(}C=9D{DimqLM;7(Cg2o^Vu#90w*c#<0_(llxUVT>b6B#>=FOiT7nnv zBBd1VAdFypn}CGOR&c4I>#||5=7wFg^-BtR9kckQ+*B^rNix9Mp>ITFyXVOyh8fh; zwPYYOyq|qw$`4g6+c?5YE`e9GCY*}g%0zlcys!lZ?kol4b(06-#DYy)N!-B)zlgaX zR|Hxb9%|kiUt7@lyv~FPyL-+;bI%;&S6pB><8igFjxQyC1j?AvQ)k)aOe?+#abvkB zktRn5!-o;yWz%L2W(6l1Z@7$nmR+cyomQ3q|4ku&UxQDGt zy4&OwC_?jsif-Wp*)e98?igi)be+orDCx4;&%rN3)0#J{EdM)LceXtDl8e;m1-_{E z_l0?aWheZlGnV^JkF=BuC69xaxGYbPbkg-$aneCCr) zc?7BFoQr~H!5+-$?ie;2OptN2rWdFg`;*p{Sff_4m91IJ+7s;@uQFSsdKy`~M?sdJ zZ+XRtUKM0c0t6(s=nZ0uW{tU%dQaP9-(R#}F3Z(A%M*2|Mu zOK{dp*u@w4qf!uib26M537bCH8`PF4>~{SXJUwOGV(ymlIR57Pje_>BON3rP??TcTuG6SXCiz4#%;~ovTgKy?AN*>fZX5uxZ+xrUeYXvk zR{@`VG=qnpi$20hRn=8hRYgUGjv%-WmSBQ@`k_n6?BvO0FV9{^=jZ2Vy?$GIL@9MX zWwlV6^w|kn#dX2nQ$gwiN=MY(Nyr0T)RvJ^r}*)#8hh)sJ;LOeK0=X4%6Z$!XUy%^ zR(r$5+4!5#4=aUZ40s)RWSlD z7ed0x^ykTW3eC5f%{8so;WziU@vDP-?K$hKNQO=mF%S?tt9Xq4Q^sRsDDzD zS2bzMM1+;gE}s+SDwgvn@y5QX(7C}gsZ=O8kJqJ!inBPYQ3`FvP|MJSwDqYo%WwDR z;^vk15xCFBRBZuCDt(9*!WE@nk_~jd{_H=Y*2_2Yqo2GMeY$;4O~$iZAx`^b`wLe&WfL zE20WHiWPgN(6)7bdD+?Jc5m#Oq?z-z4sY&t zf0y&+eQo%1baHe_%Er#timt8Ge?C=F{f*}h`Tbm8(5~3lxkT21NHNg9X;dB;J{Hf zz!%^HVpqhi4IZV3H#isvHpMKB8h`N%Z>D0V8@_Mei@BA=ddx7@zTPL^b>PP!<2d}! z#;VZkJ#HsP?Zs72?+hg;IJsu`d@(s5=f>Zm-Dhb>i=i{!@!s*HoWRmh{Z%*7NY6+~ z!beaKkb}Nr-5gk7AoDKH&W_&knK)T61w4D_2M<3+iCVkdF1*|mYF(O;IY`N;^O>|eha#2vm3Gg8>uoo@F${U5^4X-O0yx}sy-wr$_B zZQHhO+qP}nwr$+8J(HV-)6m?vJwK6uT@B>Xw*62JD@;w zV==ZU6!m*ZRo9h~KH3Ar1wh}13n0_z3WMC*hlw7|dY$5j0V(Y-QdP{W+t-YYb3G2( zk;YoETfX%BY*V5D4-l5W1K9IQ!Gu+9%K&#Z_io#x=_>r{n$=*EDuX8B;bVkKxEEw0 z!pSW>)no+$m8fD17cj~kbWl(!gOl~9`L2LV z&?^FGJz^@r_W~iJ0TKqrX3PP}Xfd!G(F&ks7eG7c)$qukvQ1b}6?L-Y&&iJL&vIC? zutnD~C`yBbvDTWY9GLmn-j!1;0j z8VlVj3qqm5HR5d`u$t=1S#@Y_{ol}J@It>S6!=tZo6bRLr`-c#b?t<39uk9fqHjL3o(}-kG1{!GjG zJ*>w=IVY@kFyNz8^E>>6Y7wN&5}?b6e2K)^GKb&1vV06B-@rQKrt0|N1xx|`dA0qq z`CODkj=Hsw3*YdT!`4WPs3u#QcsMxx4Z!LfZ6@~w4?!laD4%edl@#T0!#XD9>LPJv zhpsMbB4iP!H#uZdjXWx1*9EGRfys;vnGA5oz7=vl47z0_m3aBAqI`|nKdVhnftO|h zWb9UH(}pwJ3HMW~*6iqplQadIv@K9?gGa@vMM((7fSNx#!nCH+@@R4rF@HHHB?i%G zCJ2?C^*81e?t*k@cLRNe>yS{cL<_Yc>pzBz^h2^G<8}YUjF!G}B#xHp;}H_Zo@nA8 z4CVqRAR(J_GHA_MJc}9CuR@QSG6-3aiKJ#`2fp*%BZ6I8JeK}B1shI$gtYG^s#~H_ zi;l^xx5E1BAo?kJ19w9qD@^`&dkfP*v6Bt6lk+~w)bnL0+3_s0w6Tvv&Iis$*P39% z@J+?L^52u7RGm@?0C;>O;>6DF0c%od9h&Aq9ByjYzWttpzxJ+4LU-gausGOe5?!dP z``J{0p1DZAaOpD%CNY7|N2}5)eN-j&9WsCVqr1Y!Jd2cIRk7Kc`hj?0kh*y?HR;9m$K@~#msfOp6hesZ8X2pg5ipdVHdqR+EsnTz7(;qI739eWX7vmlsjmO+1HRU)Hj`~ za2Y@~CVn!(1-7+**2L4=uPu$rLULOi!U!OU9%Ya2Tcy^0aYBJbq7O+#R9j9mboBll zvG!v{=Zq>;YKJLt(A3UR`l@OaB?jQ}{Vb{zvq|%g7P5+(dA_7J28^ys)^?d2A}yPw z^F&ekDGp?Fac5d{xQJ%N?b&JJ1&C^dmA2w^7n1^ctztthyhZQK0-^Fj#*p5C5dM?p z%7BCxa!ojZ`r0vTDp(cZRa~El;R3du%E|7#NDU%qy^G!o6M=)S{8f%IkYKuujI(PsK$_QK}tee0+4q4d?#m z<4|=TprS9Zo;?j0Dlb>YOQgRw(89D3`nVLOXCkDXZ+hE)EXu$1Z>ZwIy=upqXGD=# z3z5K|;BEbp%tKO39h(LUTF-NB^`owZh*3!R_4M6m0c&a$c5Tv zeO%7+S8Hr`z?Ihlv8?AzVf)nO+}v#ldpVXD$bL5RnEnVsA=i;8hbn9t{uR29`jN|& z=~50ERjqx-RoqioE<+wpcHclU;$Vv31aQ`=E)45-H%w#xqDqN)xq)eeI0KsDNR@6R zSG5LeiJOMp06y|Z1ktxL<)`UaQ@0Tk)#L}C%0~YIr9zP-;QUYwkvij7$M91KqC$5` zp`E}_%fB>17R+Iz8!0>C$ujP5EuLzKH!Z^n-eRTn2~T{O$Gc=yt9G_lk8fGzc66V? zae3N&AcI!Ju$|3_*c)`}af}u#iiZ;B(4a-LiYz-vYmFyM41CbYAwdb8hxj93vIto* z!6bX&3YV`h4pyg0JswjxBcCL!k1*dliu}+xso1YBU?_zS`qgfuoRp*5Bmy&f&7Q%q z(l}lPBZ0xb;D29?0c|j&T0(;*RpXo*NOZp3&Jyp3uiW3Ckxr{lsZNND`|HGjdi>2g zk47{@-K6HE?~PsjB0Wz#s{j(U(mk7sAuKX;5NM3+5r8f07PwVuuLOY$f1Yku*GUk4 zyh>6{?!U4MZ-h~9&jv8E-{*rO8c}?}xCv$Da!ZBQt%l-;|D&Yp#h}Twf?j8L9v`Q! zdsfjR_rVppF>uTIL4G9lth6DLpWTF$=dRHB;}{lSo{uu)HK>*<>vc(aSW zxON%WyLgub+hj`gr@tbek;qisoEaKckAb*wnZ*jsa4u|T+kOXV1qdw*O?H?j#G?Pc zL?9d+HbcO|dU~TuqJ~6bP?Z4kb=d4^GF5MB^UgRf)gBG+=MDdF&m1^m>grFvqo1BAHcw+`>;k;d!z*W+10IECnyt;>Y5sK~bFh_KOESB+Ln%6swh zPLKUX3BKugg4d(X*zHDjw@mfM@qQn;w^}D&*wtWGhsV7YDz_Scn+GK<)VYR7h*dVf z#}zcs3pf4fR)9}*^f_g%;wH_nV$_ zCTqSu*pR2Kiyx&&X15da6u9Ipp#$ z-}f6hp-3OhcFxcNfuQ#F&j;9qQ7tB8oSueiB~-&8hl`@{cXKK>ns#Q3t0-76RtpeZ z>BW@ho5>XgirR70FOigEl`s>VB^Mqkp!c|f4XpnCa^P7#S6NDOMledfFnXrIGQ*SE{fu&4%l3xpwY+oq|co52*I`t=V(N`5SH*KjcCc>5xVB4 zxVHtW0PzSqKu|)qt0?9a2C^w&Fc+z)EJ`b5%PKO86&JugKtK5RFYC|ysYB2F0FU|7 z6+iGdu1e!*PD0wcEjmQqO!2DR8A=8%H)Z{l;Egrl5ttC=$0*|E<#YE}BkqyrMm`Ytb6wibrgEU`c&sU)6nKoZaN$csA$~lT8iYv1z*T+(I=Q z(bPS`U@v#gnNPrx*|Lh9Lyl9vY!EeUVwuW zk#=nr!{tk1i{i3w>FB0roLs#tb|qM4d#mLQp=_%LOT^IQm^4_Qw4Td5Ms6p%>b8k} z){A+IHs)K-|M~Hyh1mxCBBJh->R?5qiM@V?(UvM=6|`x-Ey%r-FEq(Vdn6V>sg^uA zFdfpEGu=)%F0j%&+NKzj>pZdq$P43y&fXwL=gTG6%Z&x=$K}W_YZNckez~Qw(iG%( z!jbcN}+u6|UdTYU(t zji?bv@#oek_%{t>{3amOsLd%1dzL{SJL|YW6YMYlh-l!NLb0rPaLab}5X3D!ysnRZ zRaXy(Blyg9NHo#4htz@BgO+NwZz&Fl%xJd$k|Wg2l7A54_$9F0FMlLm>hoq0DYp}K zzJys@&>M=)a?+iPT1ss22%ed~kAbjt*%C{hrjNrLi($*LKEeTlUu`302pYoaXRf-b zQSSYEN=)-RK9(kbn0#WU=*TIZX!yNw!S&XLgw#4znNjO{4afAU>SJQII_p^ZVnLlP zT9mrT+RL_$9P%hzY!Umst}c(oN?JC6w#?1r!w(tM5<3RLl%2|W`A9ECHX3im_!@;` zDRlCp1$#R@S#H6M=m~`qXnylx&ezZwnj2YUA;uaYpgo_t{uTI9-wqKOjhQ|cOcvJ1 zy2pNf#|jB}{@zffQR{wdq0u7j>DMItP-I7Cw-kyfSQbtzL_D zSrFL0-LgLOHs!if(PzociS^P1!j>mAsox+Y#`~+r4RakeE&|>nb(5@(ru2+Bs zCyb{CW3;wsYn5uiPy4L2tNw}IJ0(4cH3tq~$p+@^b94Yt=pRlwrW*s50KXcZ#l|2t zIm&}Uh<1kRFu&`nJm;)xr4T^cx$)d*1r!BvoN8nBZDqFro~q$j(hPI?X|->J$Q@)|wG-eGpOtk5rWS^g@X#U0NuRK9raI@)LZa-mfLty+C!) zX{6uNOt9}~p}Ut3Ch%_gIb7O+M9pq!(5maCELo7e(TqyA47jdTSdG9-k_2C2kt3EQ zg(=?XAXIu7qB6N=%?WW@#~GrP75h4}pZ^yCMJ+V`mnw-6WGsE1jeC&CG>*#_bbeuz z-EqJDMOi_wiw_Yjh2T#0p`B0IOvhjnJS?-Hjr8GkzqvxQp!M?4573Nmz40YX$3f@q zZExnhEW<)2mOB}S&5H1`MFy&?Avzdaz=a5QCs8ODt*Y|_SaMhyQ&kIfLZR3RXW}|* za*l#e+|G~xMMZSM7#o6IJbRi96*~fU*kv~0e#FeQ$0)-Jhq;II*dq!__KcXH4p_aD zB0yeEDEuccIVFEoX~U<3p11E&f}YiQ*YQpBquIfdGEQE}w)II=KtakQ=5T^*z|P8L z@ceIK1|(;;w(THqd4M0(K;W&1_dSiinu*qx;>~&5QZvB?1%xz%?P7&k9 zuKdP~1pbwa?C4UrA=$fLVA+q%>P*$e(r9ZRjGvWaI?2N;sPx!HRmR;bKPMrr(SPym zNZgJ*=~p^WtP~^!YI~jcOF{zr%)%G_u{n09DG0E~^F^w6rXU2|m=SP^zT@@BL(+H% zJw(5x&BAS_pW&_Q06XKykvMJ4aQK+iYdckY!TG(d$>aIJsl9uqXP%ot!~MzTNzxRq zf3JTbp-N{D2Oqw!4!bge<|z^?2UcuPFPo$&k3OuRpXImxkCdNmwfi3~tQwEM zSxBTghWSk&TR^CL)ohfA3m@jrTP8$y_Wi6T-ZY3BI8#I)bj%QI{*<4WzMM z<~2q27MVp3>kAA*niCGG8}?ryWX3>1&M$<~V<|~)%5~#8uhrD&t|qq(cgwe$1iptd z#!v@TIjY+2h=s>GGk(z`&L|)L^?ISm-h^pI4{wBG+ip8i& z!_<`*n@fZ(LlIdfdfyQtq*K$mrN1h`i6$J{jyxA1luerRL*8xOoS+OTkK6ZIP zyamOL8gMVZqh6$VI7X69Vs|W_g;efyp^)vOf+NlFpm^pA8@@c^%uuK8S~E+DM(M@E z@b7>%g!t-5uP~4gFr}|1oZVJQn1c}fRuTiRCGB;t>1-Nx$^4{B%E3E@0MamQaBdhVGOah$x_YNsdP6?(Rz1`rzg=zXH`steA zBow7H#oi6;E|%A@zNoDHRWOJY1XmcE?DXA|#0gs4B4e2r`{%;=LoBJ6!SdKs(*ea$ zx2v_mHfJZXkcXtePx2N*fyxd$?OV;`@2y19WF@(c;VC%f^dkkr##!<(5rUPTWUnX@ zVboo~5Yd6XccMt(5ptBL{D7aLTS!ZWATlmqMuq6owGd>O zWkd!PoyDNDm8e-PLJB&8gkJK~rkU*&AH;LPlS^U~GLf2OfBoy4R7#j-uWG@?C+tk} zjl=qy9@2zY+v7ehWGX7aYtT2_Q6nr!sa(Vd5g!^sJq$yGG?%#pUTH|v$5uON1)e3o zcN%+kJ6mpYoBG>#mofyLqm3ADv^=Uw0k|GRgKv%itpFUZFEWvrbO{2gv z=d2az$JXy7Iw7Zxy9c5js=B-+TkSL#Gv=eZgUUkwS&!Wmc&bvEbOLFR)z1;Y7NkpL z`mzh#VCj>0#q4vQ9Bdyn8Np?QfG<(Xe`%LXT`~llHHo51Q&$&!L>Y^?>W4O+j+;x` zW0$~{LggDeIm@U)O`EI*B?TYjbL5i4yf=l=_t)e4!08 zaB@fKl~<1~Q%OH%c_2#QSmw`57SK z8V8#yyIBR9TUpdp>&Y}nKFW!>yRR4%o@j3+sG!r~5k8AC^I-(yZ4(E zv+k34OA?rrL~pb0RC*p&8}G*#)R}E~qR73G)@Dn|M|A$-lPzy(mG^kqa-aYaTE-XUlrt(=GhT`uFwv*76( zp+yxC)g71b(r23-<;{VHq1V7RFB8im*4OWXU+{(+rS9Mrb(^hTL1Q4ES*tZt4&k%| z;IG$j6G-bZ!zi6??Jh~#K&>g4477M2+5&^eUlqW2!d_bE1VqkTQWzrxHshcH89EyN zwu?=@J-+{Ypd0P&WCIWq^tt6q>m3O#y>UwTlwet`yGc1P7Ip4Ucel7`qFgNr@Fnh+ zvWQ58Bh(Uyu-`tH4fMoOS@lLUi0y35Ch)d;54W8-s9tf|*B@By1t+mrtd)|x?2|=c z9xDx)pjM@b5h1k-H`4S_6CQ;UM|9$hjixdW#+?a6AV-_C+Bt!MFKINl!023(ukm*; z0*Ag`U=8)6X7Nhr4z%rIlFBiJa=GGy&ccLZxS$FPqXMZ7DAG)1!z$8V7ST@t87MV2 zhg??&4-zZPT_ELTn}4XBO0wgSLiz!6qdRYlCV z=eER-TjAIp@3t`kQ*C@}3e4V?Ac|ZP!a4(|k1LGWULv+;;rbXCw8uPq*n|<1LoBEx z5aX0(et_@DuY;{s)0&!*wxCwZ*&rsH;;cd$=6RB~PWy1p5idL|@-*+d&JeI^N_?xe z-XTu&sf~9RVvER~pX*Yv;WDh&*5pXQoY*#AgqAXsH}P^6mc^xP)xV+!b|7X5jJc zd1?l|uan!*Q9z|H5X`VQ&}3Bx9d7=Qf3YqE+^aY?iH`@L40;Rx!sWJSxTHyQ!cO5P zH@gU$K0V$>_=CENYMKnV*xLkyuS`~xL$i33&BMg>9yNWj;97y=!W(ZaYeOUnoH(bZ znPGYETGS;$L%l8)5hLx2zE>FW8J>k0JuXNSnb>GN*CZl(oy-L9sj%+@Zgq0RW*P1z zCSY&2R7i2gR-^F7M%yfS4R|vv2^gAuZT_kWxsc+-p5eudUiPayo#yy}uQczTO1b-$ zFs1VcgG4{tch^iM(L+Wn>PACOL(p?j_K9>KMj%Hn0h z;BS;TrMFg8xO;P0X4KGF zGYsA0MF1jEiI`b!Y(TL^z9<%Wh6jqy3^wlzoQPeV0pMlpH&M)->W$owSxpS~QC2OMPP zvchHndw;z~eXRh{qromXiK^aBC9Sm3({(Sh8=9Ek4ish03w(Dkz-~Dx@ye6EN94+3 zs-fGoIy*;j;iUgIM&JFC-<#x6hz@Ydsl6ZTVizTOkUGKi3A6qK?|Sk7mD?Im7kU&i zQ6@UDyMFd6_ZDPG@-6p`rSY;~shW?1oym22rq0C}6j0_wtRT2ky3rE*a70^~%5=t; zk8MV7MqAPD6#0b-E$QGD73@SyqP02b+RNPjy0jy~A@VYogcyZWUJd=Uz$&>UZdXPG z;JQQ!hpoYqmhUt0Z>X6){2i88KpEu7y_^(00o4Q|(alOQ0KmqNz{}I67MwAsBf-o$ zM+)Z|i+qsu)N!nJ0_oeBlnyQ|wW})860uEkM%$JKB8WZSA1_pG7T}9id6A6^7BU85 zOimqN6)9P%8W7QV$Y%w-rI8PM|5&zHKr9epR_^fY2E0+>@biN@x# z5GMwGr$EI+g2PU^RN=PMU$MHDD zqPMonw8D+tWHivzojcefxyZrB``3jg4L1j1T;;y6BO_{aN=@r>@C3W`Y+T|rtKnKd z*4Zqe)NH?@p(yzU6mxwcFB9rTb`gkSo0B^#8D2xd7Hifk>Nl zgs~$0R<9f9%IZku3Dxb(VQ62hhs#5WEsTBtKpoaE&`V97cCUMuz74p#BWrgD-r=Sl zmz$o=4DeO6dsYB$y^!q}eR6NEDdREs_hH804BolL=bFyJ#6un`ABkx3#}y_`+TT2G z!p19p)5%1QMGg=E@j`T}9p^ye`{u<;1I=WQ06K8j+v+938N(=rMycd1st%R>y0W8Em=8mgZs;IEPe# z9+UU#bJyj0VMlB8RSq8q;QFp~cf&3}NHT5H{Rc-^As~j0e~8RwDh>3^)s}3R+cev} z)%^AIm9yO&K?oU&E8t?W#%B0lXZP2n^?Kk}pXoMGkfE{%*b%n0Pu8rGi76e+Au^#?)>c*z)=lYcXf!k1k)34K%awI(iFN&+DQz28qvYs9U+%_H zrksS)nIAc~cTy*=ZtxgepLm(o=X_o_W6P#9J=-ITwzrHX8@8g1PfTckCryZ*l!ZTq zEuSM)XMyAxg z$|@zng@(V7XJ=*5P{;L?C^uJFO>H-9Cv!5#)tS`ejHrkq>=MTxXbqvHEg}scA2{eA zr&`W;1v)ZS)O)x(J-i>b5Z~@m71XXbD)DH6q-BP-PyMgkYw(1_o6FdwYaV;IWPan< z&~xMF7CnLHTc~QM%bu|Wr>pOn^fmTXi-ouluIB2R8&G0p3w5!*7#&9wP-%MO*q6ts zb^%q#_Zug=TuOBIHt7cYXevzbYRwbI7x=tQ<(W6%GPR*aSyhk6*6u#A$|X`-mrU%_ zn_G0SJABkF1aXMh<3ZzoS|YA=*1jt?8pbo*Ztd@T#kxMg@~_lj>rj)?a%MdSW4!f% zc1MPKEC*DBTJ%dcR3qLJsEXytE>puayb%+UIkgZv=j-lVi#l@Tjy>t^R?A23rgN&Lm^2-H6TgP%JNMtyz+T>hu=f3NMEV2&zb1V2vD z-e87L0GP~c;6B6(fiVrbiHu-KgmHntuk+4fIwn~P+q5)^zGBNKGc!uF%|wIv<@?tM zWgVpjRGXX)ZI&r8UZsQo24@Bb-4s!labPxa7T)d^%0W%8H?yW9dIl$RZ3rEh=qhWV zx}3;}GWALWF4a?TNE>QzNX!fCqvY(WR&7ZxQ`y>N%2)zRAg`ytLEObV@L?61EBAT5 z1zLmT`_sk5d63S8Hb?+XRHN`25-K}{b4RxsYD*;D1Jq2?Qg!?Fx-OMMKBwLE+v}bk zCUrL*m1vx=3(@Lg0rrV0bwVUc2rm-G2pU<<#u~1SEL|+ulu%z3se6|%_OgMV*QOh1 z=P2P)sw@TtaCr_J~@=Dhd6KNz3(7jORr2*Mj3UWUWoS^rgS21R^%0oK zE}p%kd6(LFjMGg^s^Ah&D9ZQeH0=0eY)Q%lFIzWTuWCyw5~@08th$(=L=j~=>vbR# zU!NrR4#FQuphQ%SeQ^|ESE#JrfQEaZC?Q|+i~N5)s{e$u?U|Kgb^ke3jQ_Ce|Aw>e z9qpX$jO?uc3#a}+?AjJAE4wX@hb~`G`cHw?Q7aOg`mQ@0Gnl~I2PSV($o)U>|v`$sf|RMmsNMa7fs{~b$(`X$gd(2qr65vG?W9DBx@n& zZIddMeQ=uDajFC38N*SFw<*`LHi^mVaFZ-!32&^>{z@6Nu4o>{UUGsAQE#|LSV&4e zW~AFf&?8C<+H^Sq256EUvPLum;d@1<6EIiJ-m=Wix~Z zxQXYM?SSTp{kB41vI+$xvWyUM#Ph&nOmN={4$RMEJM(U^UfPn{bkS;=uXrG<_gCZ- zK%L;}!Leo;^;0)_c%NBwuyhG&(f+?r)iS!-*}3YfX*d2Tw~5e`0*Ub=UXBt2JS^h0 z*gU2qN#X9hmTA` z-v=M$pKjfrFeWMY3q;<1{=9LHl=Jcj52n;UoZd6B=sl}&0+5FDk)AYL@;`J2Q+LpH zd8+z_`EyquVQ-u~aT#Z8=XK{rgRU@bWVkm*Ds#*LX3-NR_?ko!IF+aQjkiufp23|L zw_IzxJz!#E6{2Mi+b?wUQNBXO?+>t~=GHW|uvOr;lht82AXX_)DG2#meiPgdy!(%0 zsc!TezC9wo3*1^L7Ja4|>`}YX*DmT5s9e;k5=Tkc|J&2yu(!&_cUbyadmD=#vT?{( z-V|mKRooSlcd7Ew8T!f4(^VePL^qs)2OuFu5QN?CCc|$4$YLKBLs0Gf5CMH^9*l$* z79_n+W$cT_-%6DSc1;N9#QtR!?rKxmFhxsFkV4J|C>n@qj}2mBdWVR|h$|B;3HLcY z>~kbrYKx4p#;<7=BdOHOe)Wpwp#bX?mCyl5w1~hS#vh1O)MY?CE_v}yrGKy}AJXkR z8G!XICMcans3`K5Dgw!_qok87 z;Q@f4nA6Ic9eWRwFASO)new4&wo_*j%A#3a@u5QIMTbe+@U+<_htzKpRXE^IcnB9+ z{IE_ILWI+f^B89epx4MXCbm{OGhF|B9Vt#0wn00G<)HHL3Bj{N+)^+s@)p5T!Jd_H zGH_!_LPmsMi0FN0)B#La+=-fY; z?$L*go9bK(X+&Lc6!M9`JHP0zQ)?eNLA$v2e8V^?05`ihpWg$_#SvZHVi1XeFXU|o z(^-k}#hvG|7&Z#nt)0hws<*Fg%O3q;jIPeHW76P$cisUINh*HbHQDo0Rq>GZoE?2z z*-o~gyY|Tm6!*`kaFGoJU4yjiygwS3a;It|WmIu~0#kg2`AE6eG(Hq7mgeZUww;XU zYEzVxBj~89p#m{~di@BLEQ$Bkx6`O*)j*`boH>_{&`$k@Q9Y%L(@x|ZHqkTz8RjR9 zaHWy?d+c0<9d78%hS1W33lULfmNzTiFBx6kp>N-G%A*u=GChB6Y4&nEe8qz_0xN*q z03%?#Yh{;bSv*mnY$|VdM8U@^QmA>ltNhjQ6(k*x+47!0J1FJ7#pFEdXyG$CeU&td z8;R9z;|hJ>5Ng$B2~@V85a$!c$Yscz3U4LS>Ne%V@=>U&Pi7)9YXk_oUILhsa{3oB z9w;F>pDHb^#5_;C{xl**zeb!mOSFms4LEBJyyx(uCdXkU-8|j1gMw6&8=hl8;BrzE zw%t^c+-GVyHH@Gt;tkzBBT~}b87#7cp3sCtsbc`6qIpvO5SLB%iw%WsZ>gLz7HuBF zBxg2;qTo{=I0hUsn%PCRFw9uWzfD5Z${1SS#vCBh1l>%(4MxM9(KzS0j%*(C;>0f| z-g=Qn70MtO1jY^pv2b`rO5uOw`6eQqbLDuqOovhX@{z5IRr(g*OB;tWO>!%9{^)0- zW^WZ+6WC>u9CVa_HwS!3z_>9}<0@W2u;l8)B|u^zL7P!R$cYR=_}MkM4KH^m9JF%P zkFpY*6lg9)Tgm_Za*mCBCUwGjAT#c53R3d=$vI6*Z{blvRg(};_Kepz#i*}l&!~DJ zP-q7TwpZ3P$c9R!Yr+;Xn!tdb?{kWeW*p#VJNnn4_KpalPK0e>F!$q|0;f8MFP1AV z;hz=r#Qvf1XPo{r0IW~2Cus8#=?y|cXMvGZ8dn#s~S!dqy@(r*1B`DdqlH}aEsIGm9@kQ7dkF?* zFOPKhV(xs^>3M#mo-|3^Rj)BEG(F^V2HJTgZV ziwepzE)cKTXlnnkxdd!_iN)ad22Y>1d7|PbN z65f*}6u!Yt_Ni)rGC;Wbb2GvO?E}zn^-sOS%2RZK_ZWXKGK~sqM-uo0{dZ?8o2R(} z{$?ui$BN%9tKSSQkDm*6IYl1F_YILdz%&r2IM_G#;SvOZwQaXW*)E+R^k1r?y z9e9&V8f`v8ZD4T{r%iu@^v?z!lo23krI->EOO<_^AG7@MhojUnhQmKL`LIfdM$b61 z&Wq_|{8->&Qk;dV*-&66xU_^!)Hfbfxj6W_w;dHM4xw`bC#$2QzOs~sUhz@;%w!B^ zPSQ&y4B3m7o`w?f(AeR;av*5E`L#`&X1~wk07}_*zYRMSC6FfH;EsLQclzC_b1G3g z*y-uO1ABQH6*#MQe)``&;Wr^~ttUj)KPIWNkjHbJ2(Uuc49na7$sl|Ug$>2Z>vX?R zaA6Z>IU3Vo3lWT?i4aN+v#E;Z!ub!zGS2MXAnY zW^`S4cD1b(cV;yp<*)LGMloxZRW9c)ll#VFlzm>uSXE)qKY>&{qVyBZ{t9?7at`B7 z{;Fc4eymYKf;rmFwDD}}Fd8q}cVZI~k6&zM>W3*!Gj8|Pk-@iSXjKN;<-Mfj&pCya zt_{B`Z2j;&(C6Di#n1Q5Bi2JnXXzvk3I<1Qaq$F3jsE)wM&DDZfvMd*q05UxtU=ok zP77^8;9ib;I99d5Def`ED5wT5J*FQ=iU@OlB$#tsGW~FZU;Yhr2Gp=b>&S=#Pu5Dq zs3~Golp;*`;GqO_L*n}$HOV2=6jMtsS?Pyq9}dA6{HtS#E>)1r!NE%a2>pYXA3)9r~fCfUTa_4VROR$-t-Bk`QEj= z#1wbo!wEF;vt90%q=_MRLDO`OD7T)tZ{A1S;YoOl{Js)bjEO%+9+7O($ODU5M-)>; zenljkenj_p|0K#BX&~s_G`$|N)Bi@$=GeAvamJW=P^c3vs^F9ifb`BCfrJ-PCm8ql z?>;&OS0=J9Fg}3E#nJ3hZj(HK^5nT&>*f+LkrVIbo1)nDsr7XobtutX;ADQWD`6 zvxImc(P!MwJ<`Yp-!|!YL?%U_C;%v-`C=Fo9T;qyc_)q<;w;1b&pxt5_4|21k>LL=8D>}`{K+}hXx)GCg?M1>Dhwv2l^ht%$jSc0pHoIVQC z&$E8d>mM3q=mkR$q$6E`VCH%-A4O6(vS*bIQB%|8LE&tz)mg9S#sxHu}SpU^k59f}SH z#!p+1dsD0jULC>L5xR?xF=B(%A2nb=UwTF56l4hFlirpF1%!+$TjDXu#n6WP5Z%?_nJkMUBmhx2D!1@L zc$;>>m4b^F zTbhlfLo>Bnr=?oOW_6YtM0|^1{pZUy98fhebkwm1vRn92%7b?rrH^AC==@;g-?dmH zPkx~x&^#L32*AWspmV_pt*a%YQEH7;Z1?#ug{KEVj3NhwlYa}+KD&(q}e< z70Z9-^}cd}j9R+@Xk0)c5o(_~g73oLuzEp8Ce%k%cU$s>Mw^^n1b$*=9 zOh3@kN~I7tUB$XmKR0+J3ol?VM)ywl-#9?`AY8a^Trx;KZh$iAHmGjIY{Z`6wgh!B zdQ=>&S|2hJdjVCrkq2r0!L5WJLdx>Dbcky2AP)k9<%a!{W{hg*3|ePR+od8gIdq&# z)GEWwy9@N{-kL#faS8g7a1BOBEa#KrFJWnqT*(rBf^Z1SFeo0VwP-oY4lUAwe{-m) zDQfb)0|&J-rIK{VCeWF9;`*^JQI{V#HU0juZ~@k9E4<^)49Dcm-))@n%*eWD;&ul} z?g8RJ@bPmW6}2ASG_HYD0JH*-RJrJ#xB7#mlYG-{S zOEhzfN{M%GCA6@!)nj4HZ3Ul9+bZzQQU34yItpeb?|KP!aWgfVAc&ZzxxE zO!VZ14k+7s>KnZ!XihaGM2L}(i4B3XFopIylT34z$kx8L&;mRB^q>`#CZ==Pj#LVMEbCa*x?Sl3Lt2<}8gk+EO` zFHoM7d<9oxK%PmC7Yq>5>D$7nhc`6lCbE>tD$A> zrY?RA95_Fb>^BQmS1Sp9umt%QIVHdv;gjm)%4eCzXKUKeTdT}HPS^x=gFwI$+LNt@ zCzc61mWWGU7MY4M{Rp6_eWJoA=HGEUS}KqCn^JoK06c%2C?7Y(+sYRi> zLpp$4QJXHNhTdQ529TA@TeI|{FJsDZw+LJT+W>X9V>Moy9a%4)^ z2UrbkaKN+RB~*I)=!^&HXiUwStU}JNUV!?$?%fGMlkf0mjW#}oNnvY$wzSf?oD{k8 zjx#N_?0}>Vh6pc#fqC zJfd$?VZqvsD|Hs_kaKI`~K&B&pGq^wH=KrsF2cf48opJC;VytWLh^`OXej~ zW47;s8G(N>67H+^U7DrXwmLbkeR;#d-ZB(Wm`m@Ssa<)#Q$kc{}kr#*RuV+;c;AD(*ni(O&4zieLm-!z(i#ziIERHl! zk#gs5$Da?hdgI=k>7dDY7E_-T_d>TLk)bZ?VCy*03;d7Xh<|UB!Vjpt9*omqSP4?~ zb*j4mgTjpFdw*|9t@fk#7hjh|>HVA8OI%JCEC3fH^sWD{E=aw)k7 zwhr7VEwna*51W@h(z$v~DeM+4WRx%ML=f%EGE<)DGlFwC3)rMUbq+J=kbqM_#;}tCT z!B2$96AjB3N5$Xw{JxeK$>2zNVfMc~d+D3fM%>nsx{W0Xv|7@_bR4PHh;Qb!PH#gT zsXROMtXGIgW~+_@ex{B8pm1?V%HJk=c+Mn3wYR=%HKn93<#KUbjQ<{lqfBg*CRM6s zjts{;FUm5vKr9_4QO%USk2YoWi2;(h%u{e>3v|}?v@fvrS=srdBvhc8C!6pQ7Q1{M z7At2cw_QsEsiD1{7rMhz(im`=1Qcb{;V{^DEAY+)ypNOL3i=nG1KzP|h)vTHsu>qm zoP_;~woxYYo)N@keR>m!k(Z9_<{f^5ppM~!FV@&{IfoPvyG_w#aI{pW_ zZIS=o)8Jtc4rO&h^+UcbEyyp-*?L$0prYMlz(@U9|Km&PV?2$=x6Jk4xD&?Zfl@on z;mYayzUZA++Xp_Gsy@$Q&#Vc_i9v37?$*zt-Xrt>-5WzE(=UpjPkwk)F?UZT;@-}? z7X+O21Ng%q+G+o`KW?TF`AToh3vd)x37Z=g)(QMi#a2EtQ8J_5UZnCt_-s=PO@=h7 zke)t1Q*ZZHg;FqL@hZaCwq7}7f|9GV+4&jj6p#M#RJrOH`rv6x?a~;(+LX`T1B2W6;my^UZg`ZYb9=oplJo z=1rKNn9|pp7#lcj^AqRC?4xrdC_V_3@d^~nR%~NHQQ%^381o@_S$RW56QDX6!N(K} zWyJ=JbXexaa46uCmKfg2=l`O-K*Y>Z&o-aH5wlC2iLFM=B*+G=Tb3HniWnk$wh=Zf z_$EN~>h@P zJwJe^tYe7^W?`lY>JSHJ*(AG`1?m^~^<*&_ppH3U1{)O!!@B-C-^Ky8LjiNN{Q%)u z&m863D4^~PV2T2D2*px&uz59nonoy{J(z-93qrAhYCXWTj$%D~uS2EW!AzFgYni~p z?q6>Sv(KSY=wN(;&RYD{vHS0dpmN?|B2irk(W;;+WEy_4j*yP^_XSjc6-=PAdzoN0_@!U{jk{{73!w^r^E9IVR< zD!AsY#Ji60BYPU4Z#{vL%e^npSDS^)(u0lyR|9B}<)u1m#WyxcI@LKO4cKX^Z4^8f$< diff --git a/python/lib/py4j-0.8.2.1-src.zip b/python/lib/py4j-0.8.2.1-src.zip new file mode 100644 index 0000000000000000000000000000000000000000..5203b84d9119ea5415114939624411ba7d3daa8a GIT binary patch literal 37562 zcmaI+V~{R9xHOE8ZQHi(HMZ`tZQHhO8*6Oaw!Ox-zh}R7-XG`GK3`Hv>3k#AtH2BfiHX`GLH?61ie-V> ziuG;V0CDeXPq33i?y^aBwK?GwAdx%52f3mbTgnWq9>RSpPK)CXMx1NOGw5R^s$@WO z{a0h&$0gWUh5he2sn!_zskHjPvGDvm=vhs$ucLz6U`Pv5S{Yzd2-Ay(W2UWZSiG(i z%PY&$C+w$lDDZ0Zk06_y@6kW83+W4inEq*lg{0Mdodn(V%>nR+oStwYgIdfv;_>hG zT+1AH)?M@56v1D&$SmHn9q6#>j@ZT9+9;jDM|?~_^w6lHZc^2?#l z8k6M3Pi56CpvF_8cF>%fnM6IY3+GM>Hsjgk0wMoj0-0Ib8QNHSnL7P19J*Gw zwO$rN`pMI0UPj!V&{afE5gJgVhUyJiRUldGF>%^(Yr=Ko?8!UycX#9LAbVBT7DKZY z(4~0%m}X_&=Fbqw5UFBez=ZbU40cS4> z!Un~WD=j&2R-Q>Mdj;pjVCM#-U|kfUI47Uf9-{IpOw^nflJXd0duIVfXpfD#e@`vwA+}lLu%jbur zr|HEUjaTPmpfQ6T3aIJ9YG6=7JF^%qW(hm2j)11B91*Eve3k2zw-|IrTOU%u?XDt8 z_fpYGc9TVOe@sI|=c%~YGxyr_)CnN0I0UAKzE4_XtF>0i^-vD;bsL$II~@A3OW}p~ z_pTZ%j2-12TZq7vDYlFY(Z~?jZ;EH2b+A)CP%Q64K)T7n``#5oxpERv)lzf|Bdy>a zC6b38*gRn)EO)^c`?PF^_GSBL zhlS;avik>Ll3eyG;M~e7KG_Y)?{LE${vP}=twHv+_Rs~x6YA$T`UudE5`$=?pR)PpaFciZ3APz*(0Lq(ss;Ni_a@7GbgVp^L5n>6IYC10%HHw+ zNB8jVZ_<^ZKtMN0|8Kg-`oFrjGITT4H@3I2F*SCvw6}Bq-{h}zX^+c|@O#rIlJ0lc z?iyRtMTj8WB*byKTbeF`+yzV5IilKn;<0%jV^1LEBli1BS~({97Ax!U2;Hi$mIs+a}FWeiH8FLijssH@Kow@EUmi+Puw zUN@OBGn##5!gb-vEf&3l(-Pj*CTLKv`?y-U*|jTgnsde?-!SicAk&Ueq+7e<^rI#r zC}9irL}ARloqMEH2)S)C=!i;=IZ*;q!SKT}AvrMIH1|mwHOzz7S9T#YFZ8X_RCtUD zCYOJ=0>?OON_?|tD5^ZWBv`sL53h_rpGXk6q*6NP0UXt=3i=bs9-Ty)d*Lm^{m=`y zNp1rj1(ownC8=YKnKSkILzi{@vz`f;aT0ik(|SGstkjSfT{KT8c}xeXVNbr6Xvayb z0ic*OPpeOPBDj^KlU^7%t6)&-5%;z!JZ^1l0Be`TT%scet6IfBok8n#MlK;#M@=7v z8syu&=l2f{GW9}W1~HH?K(X>XSd5}*7(1}b2MNSBdXR_q8~GZMNCo_O5tE>pkjW1C z7GZVMb)rRDVSr|UN&%;XnnM>6J^C`DA9T;Li)Tli8bbcHW)Qz!d!YrE2797Y9?~AVLjqoAcnyuJ)%KYXUZeL zcB0u3^JHGWDpWeRmt#go5<;fdMk>xT6=5O|<$Pg2x4+^Mh8#c^G<>AoArM_;Z6>Hx z9`*=(i2{yYro~dxkc^nd<`!WJ7nI$W1qX(XE??p^%)`<__z>UK(EN8(QT;#*sjjff=O-(Y5aV-MF6WMg^fDZLU)TE%6JM)qxEskgPa{~{C6xi zsFPn9NOX@Tw!&};RG2()f7aDgFsQXhD!2OrmLf6&p+->yBPhQ`=%3w3K?*>4(LHM4 zy5<+JHztKaZ;;GF7>#rZ8_AZ2?6Uc58(%lj5uFXCh+hEm+ns9|C{Y#Y^bnSJf(P7m ztjMhuaZ=;*a`pvDC>vOn!JM~Gd78`dwVsKf@yN-v$#5DuTlNG?uX1USqFiLqps zTfX)V`%RkkTteP$vap#T3Rabs|!(BDnv zS9Io}q9x>B#C6)(g{hWWO>Zsg+g-I?gK^QOnMhByQN|Y1x&mu z^lGE5yK~H%-r7N4Nh!vW2u)@u9G8=lFHu>~Jn2#cq6jFfa2P(AwHO8J4sG%Qmhm83 zs@ejdph4{{nPk1O2}~Bg_{yyNy%68F`N^{O&;M zJzzX2K_T9w;?|>^#x+PPpjIHVYFGWU)&Q^!vTufsdW+f*v8EDJ764_thoWsC4y0xyF1c>G5o`)%XdCe%x*XMYrFW&QxOkad_G9 zCw0CFzEUrMAQ8dfxK^-xIHvKRq2ebkGPBOVukgyigQ-{w-?c(WG$ef=V6Ay68>*EZ z6FvE11FH632F7oRT2l>)krI?+5<`$|EMdJaWYb)w@_%1j=s_KSdN2x0RRe|Km4QBE zK+h&YBGqb|H>hcu2a^LoPOz8HSH#=l>Ib2z^miDcG<< zmuZ<2vm^6XPy%4w*|C&$L_W65b0L#X@Ox6#@!)5Gc=pwf~XQ@P#=f+5- zDN)AJd_Aqecvuc20yhCL}k&SIjm`$kB3`w*kyOPS^%^gFzt>J5a1f zB$bOel}buq7Mn@1{0O6|f1)EM72NSUS*ecqo6&j#0X=`4svb8a*eMn@sUQ~%@&;gz zLpwrR)0!=&h23B31yWQf^P+;kyL;l|G6N0{UzXRC)9LDh);2(AxvTIhvzuIsyTAQ0a^0cC!H;9ztdYLME6 zDu6Vmrn2XaxE$L3Q6Tib;t5bHb8(!!i*svPeI}2LI~iHu+EXz7=1>j->8_Yiabn3Z z09p-da3ru4AXa($=u80XXiUqUtVYeLS%A4%_vr+pD{y?X!5E*yrm}N5U0P{ePL5i6 z$D5W}c0|#EKt`0p!oK+=oLbVo%#|PiIRq`Jz0}u6-FYqs>2(LUsl(Rm%g8xLUJ~s{ z88NV}v}A8k)0Vfj!k;^(-jTV^q7ag7E2RJC5S?4kCIS}Uj_Tm4&H}~)k{xKoZ#h4*T5baw z8P9R-Vi8&fxQOnPlEn4st!=?NV#ujYcl_(thIw+B7jS&4n(XG9&$56M&!g9o5lEdRT#B&Fhi}OHGk(-cpOI>%d{s^U*h;voDt$ ze`AYg!dv2E_&n*Cw1@i*19zrH$X39>A0dJitGV*E!7C#skR#jWn5V_Ds8jOW>0Myx z{!6gc9yT%u;aWptU1_>f)Uwv}<{89p+k`K5K-gkyYwMj}{Fr>S(5^G7oXySk>uC3L zwEZPKJM#;(X7As6!3?Lbg0*knq>3+?N`@^W0kcD=y=z50W1--POpEATP}asjCCnzD z$CZr@?ARKh^C!qv+ObZB(A@A?gpiaN3k(pofATI(SzEpfCe<3Jv!Z33Us zBHh7gn;!b4yFdR-hSc1jbR&1IM1K}dZ=FfHS(Kj4D(yKn!V|1qS@%QMAq1?Le6w0M z?$&t9E3Uk{SASI{_fS6=bMK?AaplA?tkfd~TTKif@GYiEW*H zHIGN_RL<;!PbhO#_9V7QR=)f?6rO%XvM~D>bSp<*xo*URIkXp!S^T*chbUN_W?qrs zV3_g`y0@!iAG-^~Wtb%YU&OkE$#y^WKGe?5+dBXb35+iYDChWSZ*J#Fw@?T93`;f7kDKgp4J zLb8irq>E=N`_Er*EJ9rjiur#F^S;>!>14vl={0d1X$@Sxgy~*AI77h^i1U%RMbgS2 z0Bc{OYZnwp=xCt7n4$@NmZ(E8T zObx{;(v$Kwb{!{T?)U7VL?HNpMJ9pIJ(YOBVqxW@hGd%{ z6Mm9;UELjBU0r~#?lTF`+aXaF+7ewKZ;YL|J;D$2-qrQ(Z9e+aijJr&vZ*Kiq>HkI zf~GGadYfRRl94me)COya5$4)8K6bL7u< zbg$3*L;lMP!V>yZ6<>@V=(;yEX0)qI(rzT*3=Lij|AkM|wAynOniPs@v?(xD%>$gc zCfOt0n3LXf(NinUziKTq9iw<_`KWiiFibC%$|{xnN9n5x!pFlK1g}~@40LtvzvUyM zfQrj!6h#_9p)`@?cXb9|L{>GRddpN(RH%`k znV7mga2d@H*J;($;Qs2s7-UXkmZ)Lk3Rld1i5#1%@(!O*9bGN^Jp2_AMsy_aB+6z;>tWN)ZdiACp z;*{XDmzY?Xm{_^{z|nKp^L2Wc!AsZ&0BaO3D6C9xk5y&-TLwDn*O^{@p8% zdr^1gxb2&%>Wr??=3gkOAnejod+IyO`nM7>-R=l?uPP)k|uT<{R}k< zIN8sg=L@*(_Uh{Fi;P>42UpyGcx~bwJ4wdL_i^*{;}#*Stk4FAo%c7-eVw|@>Z54tq$n?l`uoGFjkXe%nqsi?c_v1P&4K8TT_W)k*BrYU~T!f zTxS1wNn-`Uho(y4=jK<{q~nKqMp_5ygR}mE`Pg{~ zsiX5?uleLcfF8W1yG1GZ%mk3ev;|3SQ(kY}!QN&7>rG{&WR&K6CH8q;i;9*(l;>lB z1GZqpvO%Y8Fd|_942bd9?n)5zc^jb+Cqktd@a&wVe6=ziYj1{NOq5B{Fs<*~GyYrb ze{hU8)rr&nYY^s?6$usrIbQ*3FI0dWmbRx1+P^=#=YVUBfnmts)K6uKknT)E5H5Q` zHjIX%gaO%*6#~JcDHtVhPIT$0l+cZ)78d@e7LQz@p&JtxtwVLg0#I36?7{K+HBZND zf%n_1szV@RV>MM7u~Yyo`ad@~?V?-i8}IIlS;!Iwjsp1Z`;6Y?)BZ7YFnm1gou0(;=I& zh|`jzjS6X*V4VjE?LMRze)|XmrFn!&dR{wYKHO9@K-E)rDah55a&Ze`9@mE0>r*j- zP1zp)-ND>*1{Xf~p2dYzaU5VYA*&4B@)mzQE`IFM7r%rkYxygL*@(RgX;dLgbXQJG zE~=PYV#dw57IO9j$!^2}eIC_l+e3(lC)5C{vC(FFPxKIM+KTpxkX1!BPCKk?N~w`{ zt1xnX!H_14a<02TeqK>k{-QyPEM~{Tl*5EnQcp+;7bhqj%S`^}{ zyBVR_9pow~;Rs6R!9n*#;+48zX4jfR3;a>&fkxM<0-e-fzOar)D4{HeK21h8+@MoK zs{0Lp&OH#K) zr=Am=Rd0>+)j{%8iUsA4Mp2aV?T!_piEb|+Zm;0OoTKl@N=EHdjAZ9vE$}5op2Ldo09R^VQ_JS{Z(2qGzD-O>n{wrxsM zto~ey_GM+Zs_Il4$0^B|)S<6O8rl)GWBBS7A?D4DJz+WxnpX$(Z5#kcFEP_ zYFlnjQiN2xInc{QMzC3b2y8=ISy7ePKxqFw(^H!@(4kpHn^;eUWht<-2?Dsv<~bGN zD{{2xnHpb2r+|Udo-^c1!>sx{iy0HSmU*%l-_V<5RYK7j$PjE#Di{82>o}MLKZnHO z7hzG`;D2{ZpVmY!qMT2F{xz;Sj;RaU`$Lc%rgigCrKeV@&@eD8@aZs%f^-lV+K9%s zd0@fE8FWtBl>eA(;ONc)3OZR#aa4eJOq*6_)T{w$;s2Qx;v99+zj26)lsSa zcJBYLz}NH#+kAuH9BaBzeYrBiK)tAi6{UwVz;EZDiQ~hOpLzfKcRX6EA zC5gIH@&@gJM4X7CETL}Z_sxN7jA}ReTSXv_S;mkxH45g-r#3xTSlFqfEnNw=VD57Q zrs%w&*m6NAuenomqh3(Pf5My(VWl>I#gqmn4`b@sS@GojaWNnm^A{fKtbWiqUBJ8G zLyz!dZErUwsXCePZyqyFKi<^VlRVkOHegZo3ioX$%&+}5$RMTSj^_c{hSPc7p)B|B zHWBV=s!$bdf*GoBE4DgpLQ+y?f0JRI5+52JINjMA}G?R zxVlH4yO)48gPuPp9|3teWwV&~ok4PdFJm}zpxjEv_?6~>d2&Al%LX{xzo;B|vGVK0 z7KJeTY^Tdvp&HH24uptW5VE!0DRf35gi<11zVZ{(&GPs8@LW)PtqYM08K{ z6~ZaeMx}oq#>mILHbNB3i{*+16=}J1OArIW-L zj$X1=Gpsa^0byh@<>vh!%1>$aX4T3#vE*tzGQ!9%Rtiqec)h~ukW|fq3 zJM)Bib=|Ydm3a@`sEt8eE)R;&^wDz0=z%tqpLhL_#pj9WK0*kr{AL5qPLSh%x|U0A zTX5ql{Ry2Y?sth!$u=qMxbL9lB3_9Z#Dyw!sy@Q85z0$7*b)4MZjU9du&S^cd(93c zX4JBvag6XHeG3WLDWj1&Jz2}rNn|!1Sq=(59DiNj*=ynLZx?uTLc;}r!awF8g(FH` znJT6UQ!yN`k#!XU_M5|tYS42;$?(Yg0%ys7odlJv}U4kjrWywfRT;g|W#!D%7nvohS46FN~S#2g=?Tm4SAZotOBG?mh=H^0yADt5qcGy%3)k;vhBoTY)n@5 zQ|o1wdBlbP3Iiobdzkn`Hm{ec_v3ncr-YJ)f6m>z%~Nmf=5cBwlOfC0ddwr2U?|BA zE%z|uRnM16$96<8SKOSb38*eXZjXi`*nbun1h1Uziwd?LOyU*E+70zkO1yz0@F#Ihdx@viHwH>-!h4i`rL^SR4c8)oK7 zfRqfEOGT*c^kQnsO;qZF#36)fS7?&4$|o_cQVY+NP<#BrdU7@n*!Y$=0J}SPB=hZ+ zacfQNwV!Iy0%2iqwp(PcS$$^G1!ZnQinnVM0B7vKa6vLesWt~U(482eGRaP^+9Ec@ zY{};4Y5JjVfj>z?fT;!~ZEG3t<5FM(WgB|R3HL0H{W0I%Go>baU*4)NWtEjKJYz`K z6;yk>Tc|NJCvTMCKFEzsafD)b@u{<4#)kIc?>Q%sxreK(qv=1u=a>_4lTqo-mciEI zRiMsl1)t4){d=t#PqW&)jwX#uyN;%AM6+_Z>0LWd=m@L{qZw=Gz2YBwM{^8g-ErN5 ziNjiocqCp0nQrX?8eoD#u8?i$b^uClQ4oR>7Aq+Lpe&~fDX+{VQBsKT0NWVQzr5cP zP>+`N4H@^XBY6<`qQ&Z1K~3o~ErrQ)n9g6d&z%fWY0mjB!yjkJFElACgzN6<6Z8+L zN!%wnKv{lSz_t*N)>;t;bpaq+5eq4~@1Y$|#vU5hk&f1WxBQ!5!Jl=n$O1zb@Z^>U zcs2UqOLZizn5$!|V#U%You9Tbqt(S;1B%TIt6pdYeAN!=!1+!E*;0S03{myM-kWOmR6?7FaA%ld%r`91@pi!X{?CH#>tLLp>v~8a zq`Ws5z|Z?O7B`01in3uQqC(Em6ok0|t=lL*_biwplG;5Ymjm2>sI`Z;;QUX^N{jvR z7sgiSJ(|7Qaf)<`_~chW`prKw=xW!5%Z=B){on?qrNL5AbxLTbWL@%Z(|g%5vOX=_ z75d=&V4xb&&b19w;-;XZCt#iDH;*JX6?!{6)9vAoe(+wJ&WA#cz75;1)i5C?Y8*{Rd)_RL;kNKL zv76>QDbHTdCVoywS`)Wjr%VGHT)*z5 zer`Q*UTPzDfBwq^bPdmkjh*{Z@Y+$f{Xz(^eI(Sr@eq;ZR>}Qm2us`$rIOkoBSZj< z*WXQzvpf=p?C>IJA*1T(CrZn_as73_0y+Qoj0aU$l@?5qvI#WY;yU@t<>(h0D>Jt` zM_AsKKS$QWxw zSa!}X&q|~1k>JOO6llX5tZ!Rer0)pV+Rqv2zTx%B|B1+z552134!!p}7LFxg>drK% z7mFMRvmchh#}qd*D~m^rc#7dQx9Kjh=B_BK8bsQ0eBa_QkOFm_tZ#05u%)+-!iK9b z+Q&>5QQR?c)%hksU%C|4$>ZjH&*PcxnG3Xg?zPzkHpUE(Dn4b1VluMKR@>r32f)3< zfOfts+csn0Hhx3Y^mb--ce04(Q(DOHNu*O@c0f0Phkh(@65VUTv^tT4+7_;uOupf5 zrh&SO7W!Dn5a;%o>97YEY~|kPkn5~jos*bMc}+vQA(14ROCd*gbS}r>Eg=EP8X`tz z{Rlrz>CX@eM7S#I-q&?w|CgH3?&_(31aQ)H4w*VASGO4!w8SGFIqHT+i>E0u$dMIWRV5xcpk zU$6odYYa&5GcoZyJm{8|Z$mnKCD%(X~e-{pmm`H|1gN*tU#CASX2erHdtamOFcxCCa9^U)vTai!s={2wUP7I_~w1?zcmb!8?LDTduz7KLPF1($z`sy-=eSZnR$YE{tL4 zAWiJ&dGVb)7s^Jnxv`f@d59~)AdfF<@Z*o03^R}-wCQLZlz~RIw5>X(58@9-#SQ;J z(6B9{U`3XdN6xHgpt3P*F1+B97hmi$B=wjBrnHq)lcUvA5$_a$`?XTWBzJTJm-V)! zCA4qt>-bB5luv+pBuPk_{4KjDQHB%-v%TK?C1nVIYU#IjYKhH=$Uk0t)}P_IZB z8MBZ7TL--mkYa8)i8r0u(D&!BM8z-7^q^gEcChJH?*R^-#Yxwet%3W4LI^Wl1g?%; zSdiFt%MxNMwL;9*K#`lT@1cKS2j{yWXr14W8(eq6x=xOcsJcCJ+Rk++sTN{p`Li#5l3;i}F zc46h~<+TMv-|6hS%!N^q>w#ZvnB+XoH_rsfNoH~EwaV5?RQ0FyfJ0c9^KDN{^_L90 zN5XJWc^w!d98csG1OLQ_e+bAdhswL%6P+$NeViz~+*R55OlJv+; z3c)L?z=o8&5t)EU!t0Xu{Cerc*Qa^f-}&d8o2S+m^}{x0u+i6;X)D@SjlDk$%Gf;V zP(VoV=jSZja6Zk@8F-_GW7hO_YB$t|l7K$Dx>ZUXFfCHo;V|pD9w-@OZJjRuAEKzIqQ(r@VeIPq8gMv*Ik)Wh7 z(#~lZT(}g;h230hpb*hEep;Lz_XPbE`zg_Od8Tf~GVc#+A;+jBd_)~__?6IQ;Z5uU zhzg6AwfEt9(fy`_*+_32yH@z)4pu+m0AYQ(DPo6S{pV5dVw!?lY2{Np{^1>FM)Krkcr3 zBL92xtAA^UA?TKD(xJo~!o`tMEStA)^+P!}u{KiFyy9ueLGO*gDFBS)Y!lg(`*HpW zjj~SG-WKESv)cEl;taE-TZ1m-NGp_;Lffs?h2EM)nx~9NhnxRX3J;Ak;(BDaL~;sd zeYPvle?qV3RFJogKu)s7F*2_v379KM2EBJEbq5Me#H686a&kBfhJrZ_b?lNh=gJ{A zRnWX=9x0G5*t5_z(uyglZ0Du$8nCD-fjM|Vz#ni-l zkge!1M6<1|Fg+%EMi;i(Y9Oa=a>Vp)> zv`e1JJ4k`BoLm+`M971QuvX{<4W>(47QCZSKRdOphsGBWq8j&u8F?Ts`nq-&G&!9&;Umnhp6 za@006ELcU-WocQ&P%APjc2nh-KSF4{-h4}ta zhg^5w7!p@SX2j6wiuxSHARQc=2zHAVbM*u;f%LTl~x%juHMmzvql%2 z>bf)tYj&ZMuA|IwBTIhsH4zlcYeN%BA(e`*mv`EvmDNxxl{b`Rv9jh~zF}b&A6a6b zxminRAYiaXIRW&F1a>YiS7SxSE+sYyO>7K8MT6DMWWrgsAAsezm?6!6!}BANESJ$* zyC1W5S%neZ8}r08iE#+0I7`ESn-zKIgIf=g6QY28U4zxSRi+KF*V#qb_If#?(~diY z?i)^Zh9*(ke#XbdwT2C){LOXdIj;^Q{*4R!ID?RPiRZ;v-US-%S$6R9^JE2i|Nh#! zvtsMsB$}}c3aY?0Y0XLaoSpc{dS^{FZ)+zfkyK|o|HOrh-})~_3N+Lk59IJ_#s-0C zeMed;F90pIHg-9k*0Kd=U-XH^FQu9G1O*BLKZ4M8%P-8}@e@f?HB)ci?uIT~L7Zcx z^zlZrowk)VK8(!O&qHZhva#fX#qxiLAB(4zoTYgGR+R3 zm#9L9wJVSK!&r3G(a$$ERiTjz^p?0tyOb-A_s2C3X}^6g2c(&+>Z%aq0G@}Xo7(5n z0~tutphlHMWS)qcopvHSW$#+56e=wC-#8kQXA4c}hNILLF{HQrai`?Hor%2Xdqt4i zx{_(Y_i~bD?Q!*+q8_~4S}|RPS*1Uz({yua8Eukw>1Nk{$u^58LFORZ#j=a6GPd2} zJODmMG*TOQq^anNqs?Jf(@_`|Br7hrUF4(-99s0526F~>wVLk}n;+wyf=p3@ZANHz zyzo{d8@|adMSV;Id!*#?blaFzOC@y~XsT?d|N68z?3yX65}xBvNh<0$P*(|PaI+ZD zR?gunyDf3PKraqKY#4OvjvLG(7Zkt{}nRwQDc_tPdH2E4udK{kjPBAhq8W|$)< zH>jT~Y@QDKiB%h`nh!I$TYj;RR!gm?7ATJ?`C_cB9lIf-1Y02%)zK(U!%}#``$+Hp z<-1N6tNz3dT_xI9s?5}57n>&?I*R1FGut~l>t(I>PHTz=&%md852IjA={XKZ;CYm4 zpoMJ9Zr{fpfdHg>(oP*I&R>6y$kE!XCtuxtYt&b9L}6a}@~JLj z)Ztzy>Al*=mySeLrO{`Fw)KV}Ei2*zZS}#Vnjal8TtGV{ZbJOgg&XYDzs{MjDrJ%u z8}VAA9R3trLk)ImKFsnZYX=YwO>ZZyb+}L9hT3WN^iPpw_kBLIxJF(oT0c6;wWsxH z*B`T^E0xRB>a`KNEUu_Y8d*+8McnPvoJb~d!Lf#S5bb0-xL@z@v z%u(Ur#K{>;VqH8D4pAH3ulfgSrb^|#_}t}Z7a_6rW+woG=xXv}FjCf(O2d$$QkP`W zCA^gNsWIZ@!Ei069mTVdMS4|#L@R|JLhWf?FiN$%vY$-By{Q|LCjK32t2z?iGX)`Q zMu;gXum5Z3%}E1CrPx}$ zYC$igCetmr0*{>cVLq1Te2t-z<(WoRMmTv^K?Z}xBtGHvHbGWhIOTv#f;%^pa-3rH5!K_ix2W8 z^gU_L(Pgipm1$Hw9%AeigEk&w+D8@rDFllB@@Nf9=l`1ecy*nWuV}WgKmWk%hR4G{ zH!QQw49}9^9M2Ql@W$&36`*x-t&_gt2|?>q{8Ez7;nPhDFKw9?u);s>kRkzi+t!0z z;QKaq2CmG+7_Kt`vh7?oZjj0oXc(x2m>2%Ju$cBqG!lcxK!V|vELZ!? z6azw+FvTpn+5Sd`tz-&6UEf`wAkd%-sH$r1w&|#IAQ~(=aBl7MPb&NwNirdXDwx+A zng182RXd3-&?hcT3Q6HFE%f~5gF4_RHqZ58}xGUC{3fSeF)x8Hx2pcx7k;d7{TEvr@Y$7 zsV;S4qB}`=PQM5XFle)f`10!A{d}(9j*Iz!V)Ff{!zEcVA}iyU_5|6lG_6OS&XijH zunclSDklpPW-$1vym%=h>a1Eh8q~(;_sa(FdRuY6W5ori;@RNk6x>8h(zRJwh6}Ob ziZr9aY(+UEVydDk?}mO-P(YsBO>cZ4o=2o;)EElQ~y508~8*FqnT z2ib4Qk2i}B1SnH*Mr(oSx zL8lWVs`E;%Jj$pPKs?xj(#mTQx}erMyqI*{WGVO9xh1r4(6JBWswK`ZC}&-{L)+aR zBVmi*982J50Z%hFz*|i^?a5c8DxAE@k)d{5yCg7ZcdhOjA;fiJ%`x-Ia^WVd#$3ObWo+4e z^Yah2lVx(jcxm%m6w6*M@#>RbXG+mD{PB5T78}woARsB1Kc(5?4rF<1TI{mXhYgCY zPMnctyUTQ~SYm!bKwELrt+JD#H?<&#R(n?o6mp{SdA1(%{k=NRmWzoy z74UMO1KOl(@=HE`s^Zm7G)v#o4vyl(Lan_kGrW|H9$+?`24Q&*N=z^R?1T<*X`v>;K?y>UgY4vTb3M5kzjp9k2Ee3Xp6kIB-7MNFrYe zUK86cACb#88bYG1hFE`YfZ+n#@&fG%g~E|HHc0apV>pOl){Np~#~ud%G;j}%P^^o$ zH}zP3@i%D**iV%-r#he75L;CQvW%!)zWn#4Cd@UrPB&A@vu%?me^~Xm$CFb@0F829 z-`aj3$sM|RAQNnTk>vipmhoo^bza#h5N1{w z8d)nbbtYO#S+jtuxsuE2tpqM=|Bn2l=)(e5zNJoMm?i>mzh{k#D8BS|0?0m#jL&|k zqFlJ-z$70%GhH?vCvGk;GN6LvzG#NVsQ%cu_j3J$qt> zl90SoMC#0utlr7nD%H%Mo`~$_tilrVtTcB%%q*;27Tbr-$G2g18lCEaTM}zH(fkuf z@R{@dMiQx!>sI>Y$^YCrM(oH9OiW2z&9T^Wy2;eOX?mMxq1bHW;3MUCv~ z#E!S4lS^qnbZK2cqr0h<-eI`9>u5JVp3|S9>Oxs zkr&6dwiK%~M|p!&3Kp%gQ{eAp(|uR#6qfPfQ?e z;kXkJSxLw<%jOfLI7%$6`^ft ze^R6xstgY@9A+AD2@^bnW$9P4&Zb|S^hH7D$z_>)W$j`@gJPN6#<0kD*U_r!~hl42x{dkLwRf~b~WD?DN zQoQiEBRW%qaV3E%<@k^;0;t9vNF0l9Bx&j&XY3)X2C4N#nwJ%KzZXs9b9H`Zamdd? z5~JJ(Jv5YkmLzK-<}H&-l|68pnK7z;<7vYYi?>PF(N>9xs&JDmV+n7p(7p;8w9aTA z#vXEl3{h{m23SZ+J!Yia0?tR5p)Zr8>SNp z2r0t?QeWQ5`URLJ{`Uo~PfxHvAV4-oHskg(`2#`#l4B0^$AuePH?zN7KS?$UKuP%1 zA!Y$X6|Jtd#nrdWq8(*l={H`{6!_X8;%(|gj^zL&;Wv3H$bHLOWsYb7Z_O+DhJSs6 z`-0AG_?6#CN?p&3? zf7~ojb8CNeS`o*s+Syi1y;iufr+eJB3hv?C7`d!XVzkFqF?(oy=9AJ{M zlS-D`^jmC2a9^RPZ?$+PN!)hp0vMFuNb}k9XJ=X5<7}+h*j3?|+rZv}fv&E#jC6nN zbzcS!qM`gclZg@va<$uDubeN|jqsls>vXZB*KR8Py|$7yFH`>mVwu^Vv}3J&W&xYM z1B_}g=oa!xoqtsirM?F~z(3WpHEv8&<`;;(^Zaq+94Y7J4<1aZeK56aV$pM2?gStW z=OaB~w&;K045sd&>+)3j1H*e)7GZCkJboExYv*<6MT4#|Wn{QJN-A^20A|r0CHR^| z5jdHv`GvPeK%T*!8@E(ryVY-EV-=!h58Een^QUZ?ir*h#QO&J!a(*-4Z9A*ou3xM| zo>CC;jlYE|7wno4&WZigD%{nkpnj5;oFIjq4Nx=?(;ged z!t@Rij}ccUSQ74IYRKnMw!{`0VU=IgDn?SNhyCgm%R>RyDJr2IkZ1vcJ&ZpPsqpW~ z)0pJ>7nT10f_zAq??eFBmzbb*7NMfZTdD{oyN;4hj+EP6aGS6&=t9=C=a(AzGw_jK zy}cCzdaGRV5Lg$;{Shqcel;M8n@C$Cn8{S^M9{Db@7V}Z`%Y{MIx9Fua7IUW1dq<+ zah5;QPMaR_a7vjE1qS+(AIUkIBq@dPG>B3%kXDLxUyzww5}|hDxw0u#hLJKsnt}$= zYOJ}|YWFzkN#eJ;zrq6mK@q2wGduP!BwrXbGcx5v<4lLn0+dCQy5d8*%!>|_wBbpU z%WqP@2~^>LJK;fGWbuPqSqKqMJI*7VNq`CABbuQjANS=f5*-z@tT2Tur| z?c$b#VUf28mJ0T)gcE`5iz1f-?@^wW5mQvsX9AFtChEvHbZEFlkA^QW*Zfm(^Y4ZN z_3AFc@bCyh`BNk+M3C?+b6p{_YVn+0T$yz?;y!26fT0910lwivr@=9MKplL6JGp}F zO#yh!Fd|W8D@JGj!E}#4WZYC|T1dm{f+LVm{9Sp4cO6=L$O+m-HD~L_NddUoMS1)l zU@ng6;uZr)416JP)0oamj4$pyk43Oiz;10k=94|Wt(*4f`=fNVj_nf$_d9bAfJjpD zYp%(jm#T^ftY_@#TgrB_`CTtd> zT2uH?tXP^OU)pvuo-2(}PL7}>riKc{{ONVWP_iW6S6@yenw9;L{&MDAIzroZ=SFpu zE>7E#v)Dva1Z0>WFv1l^=I^ny5q7wt)9XTu4=zMRnOWYfbU$QtwFkbvQz?&9$jS8l zu_f6{ZSds}(g>^oZvBjaZLSrao~7|by|SslvLo{Uyds5~r@P8u4P8Of@t7^``m=*l z-djw}p^g+hlhap9qqvb+-8L-K_YR^~UKT@T>j`l_Q4C*(ys7-EKw8 z;*7>Q!*yiykQXO@G4a-mG%8mH!5}bp$d84?D^v=e6*GP#iw&@EXn}CRbqKA%JP}%Cd6n3KA^y#_>Zp$g({iNb zGf2G_?@>la4b1%H*`pfrjxDs9MYE?O*6B6oZ7<*O+8=`AJu69`4slDHd(Rtv->3)@ z;AkH-9)ShK+cpO7JV2)dj|1ig1EpV9GtuKza*hqe%!=7X<ob2)f7Gj1< z3W;j6-Qg?`r92L#tObT9QR=*2n9EDeBng$`HREo}M3`cTX!~{4)Q#CSEkDh|SMdrs zH~EU77E0bGRTZ=*7z&mSErvJ+u#vas$nloDgjp2zgp}!k&#@grtX=HxdAp$qu zpyJh7nAn(RkZb5-ANP;?&syraRJ7zHb20&;Z#jZ-+kf_-`qxli5nU@6N8rPs^TK`I z=Y7g_RYwVcnauyF>t-<4?7TZ-@ehWRl69m|&!_?q;CG@UJ2)r4ea9ZIg5}-DALH}t zXMcGAIev_y#|@ACEs8}2Wf>QU*JL!gchFP}Hnqs&b{R7DgYA*e7vMeL9mm)E`9R!o zT7cfLdi;$)`3VeVYgqyBNfHWQ?v0*r6za zG^r_w?@k?)iQ2(VPx~I&OG~K0Sv7N0F69$%Lf%?Wh^o9Msj`qqvz!R9LRAb) zTYbqOeDwwOMapY*KTvRC<78$QNkY=cH@Qv%a4=l-T&V-mC2rU^Y74}Y06w}9B3czp ze7?e|XmkgQ9IJ{_9m&k-y6o&~nqEx3fG)8?Onc2MNR10BsdNIWp~?aGPbi!5lV9? z6sS%r4`c~BB$0CyOlu9r~d~6AliTb z;{O*15bwYLhM|YEiJp_QqlK-R-aohVFSgAg%9FO63=%U9-RZAa=;y0gND3)*%khTWcg zUYE=QrZR*@Q&bf+l9UexEU6$o9EDVZ>oq}p>9x+X!Lc2z$Dm-e5*m8f1-MrNw*GgB z7qOSbr?#6*02-P?WQ-L;?Lai+DD5IReV?;ikkwD+J(1q1czG14k`z+oU65RE_q+(B zqXMA|Hq?^Sg>Zoc3VYdYHITiDCfur1=&l1qy`JE3rH`Qm!_pi@E|^@(>x}Dl4@^|U z2fMVvi+DU}l{PmZy~ZGxqq@)=Kp(9lrVUX(2~;-CbHdXoIf^UR&KG*tqXfet=aB%<1G99Q5iRFps*=%q={afl?`# z@s5Yw34ElTOWnCc#QlJ?vzfenu=BMQkiOnSzrB4|jIuN0KSt97!-n)ip+PeyHu!5x zab|)W$rVkR4mNg}1MLXZ{?v?z5cg6So+F^&kGBuDmjzq zk%JIPl4 zeR2NgF1J%7?$4mjP(3a|?rhQoXln^$6|#V0RJgQEkm$Ls`%)jQnIIC@)Z*vnuTFlt zQpl)@MXZh`m(jQD>JdnnJM18aAj)gkoCTdEuZ2z%Whm;JTCHM(@uk6fl8W0{w(qe+ z(RG(_=yg%{tA6(?D&`Ca`YpkTI*qx{HJUZj_tP6tqxn8p04-o6%0Ju0pO?~UZ~}pqj?uOsOYnRJFMoxZ@EoitE%Bm#)G`+w6PT=Oj!7BDN#H$9X&ppeBySyfl zE7O2V!+&O321j{77R|MxE)|xr5WN5D~!+1Yd5$0&`> zUpSIx_*mYkSjpgp^rVcvJrV_-}M}5Yx@qUzX6OJ|e_t$Ge#^mC5H_sry z<&iRfkfWO=f@{-&RlqF_A1UT>9Imp|HBQF{s34fd+8|N?cWox z5eT)s$IIjPeB>|1ci))!Ec&;5-hiVf?wf$fviRw4c$!{;Yy)O|ExI@I-ZKhsRA75n zI9{_HS2?Cb#V?10z7q11s((9Zp?a{%cq&t^gEKWmhkk(n$gAPwg9P+1xDN#JKgkRC zzsu{t@O=i3js_n8PG2!9o3a6ra(es2!(+ zJ%Qo&W24uwh|t6b2c}ZOg#^#3RLTJ)B@oE!*I0wd< z2a_ua5buGCUCwe^wZxNvpz@Ibfmu)YI+rkcPm5w>y|1etA{j1UlHx zN>AHx4ShC<-<`(KjXORWnOZ{0o_oIM3MAg?_bVfaopiZ0@kk+QD$}rLC#0pJGiAny zMacEYb-h-#xS%pSh7Ha4w{a7i$S!QOMtV0x&@wfl9j=IAk5Kq^*&#Hr$rG{CfBsj^vU z((Z{v9;z46%&jvHh}3zl%OWnR_Hr3IbYuw!m5wpIS%x#555iIf5d^V#bvRVxEw!w8 z>PJ_mDpe7HK#`|VWiHRIeHtnrfh%)Bpy)}Ar>}+4tQu|6180hcgB9*He1DRHHFiVz zK&L39UIIW1BIwbpG;*CI#3AgeTh=b&vgZN2W`pX3FQVD4JL?E_-FWi}axMaRIm7jc zWs(y4r5?>#IZsw_KYa`d$#woiqK%h^(4T7RK%?B#doj&-dVca){vbQuaCr6pY9fVv z%s!?-ov&v84QZZ!?t|R%=BUTSJbtSFMd-pg8AnM3PnGP0Z>$n{EeWz?U2Qt-ceLWd z2{OIPvmhp_DZ$0MFfi(ZTDpghzu)+xy}1FQ>JBMJv!rI@}geIJv`Oc=CtAjh7 z^o;(Gv;-DW+W`Jb>k0fnNsHiLNXy8;+S<^-$m-vz>smug4x1Cj_ivTyEZ*2^P^bEs)j^Vez{_1WgLI)NGKN_o_?UxfeKe>8 zj&M*sbcCU(&JE&dX(Ggxm`!jg;;WHFc5p$0Jy<=x1x1@wL3o45h;M2Y%R;UKkBkKp zod>KvZ;)>RI+LRWa3T`aRU)OU9rJ*^JQ?o89$PAZF^NJbxpr53y$)Wmx+x7%BPXM) z-5sv$iGH6I(h? zlR~iBP(+Q<`MuYQISUVyP(pB&Kky#GlR;8AH&N9}Cs?4F1=;3skDZP!l}eUP)B6VO zej$n8W0{#!kd%q&3&lGVgivmWy5bC5ttR!-t8d-+B7(N4W5PkYe~2pZnlG4t)Sf}< zCiDuJ(^y}C*`5ex=4kvLgR|l`W+tt(gG0cJ?EL5=kZN?Pd zP)@gI0XJW`OPhSq&o9VjI{0FUe*hxlV0hvxFdE*TLfe=?t74DuhZdJjg+Dd0=h-?e z3OB#SX2d1S@VJCp%XY;qJ=!y~wcby&A%*~#r%qq7L%vWJUN&{@N*`T2H>ceo`Mq92 zs&`~#3y`t_#WRQqq(${>CQW_b*lHbxuAFYFt-udIDVZX5^Rw{5EiuFr!8@VH`qphx zwF)$mi`Po-w)+M5LIUI?Mds6$`tm4P4KDE4nGb;~2(5`GU>rmMHN#stEav#(`zvU(J09#2#$jQJu7VOA*h=Z9XuWVnK&GRg@w5oDeI^% z?kKRJeyx=`3Aobk#vNHGsD@fb0-H0VV$q%II}HiQ!z&dK|ry)(A6wlqm}FZQ+1DT^gmDaQLP6QvLX)_3E_&gQ>p2Hge8U_goOJ660+Ksvbf@xxk%6l4&r+F z5_>5zHDX1F6KiVmwX$Bj%cL}j`c?mwk}OC_Kv@k^u!j5cp;T4#s6>&{2`_r<*>@%b z2=Z<+Q@Y=6d6y^-nNC2W!F0=favMi1!~ru4W)7zwSXh$a1Qt2~?R-h!P%ffPaS++M{*zg>PbCiBwR9t1U}C&7kk)J}AmGwxfG8HP!@5c#5)p#c- z^fj&$C2}0A(nggN4q3yaCx_g3`d!ohxsSnyAUne9)0{uQ4yN>q0E-1F4gdwja^vpF$6sMSH zChKowp^btNQM{IMKS7#0yzO4rc>e6XOho^F;7xVDUvg>yY#`%@Y2_ScAxcmQ-{q;= zj9ZN*K_L@Pua}=%#*7z~@2h{(M=J^QFBixyxY^_8gLmlPn@m%dBcm96bFu4>?$?a_ z%d=H#st~63qwW^3RBx`?z&xPmR--oeqg?2@(dAWLXi#RNm?^!K^I7-$31ED`Rt?h~ zqYY}9U(9#b?{+_$O1{OKcoa#2$jwEAX4A=VZ;4tawQ(g83l^y-5Z5(={6S1MMr|%P z>0e_ZMW7X4Mg%)i@Jk0~qhkNQkElWth@S!c)})mpL`lZ6TCD=6_oB3VVU7V)xs}F* zThdei@M${d04$kuU_&-&XD(CGe}n$Vp|`p|e1-qI^8)UFa%lX2;n4pR+V*b^n&;AW zb7cPK9R>Rwyf0Z0RK5ZCcVP|{6Y69$U9NcNIirt#>YidSED#NVr0GuC>^%C%-*j-_ zfn<1MN=fs9_>rM4H#0Z0*C5^cheI4k|I)Kv+}Owb`VO`Axq0RaCoE>ndNRl_k7*pk zo7eE>-XEpXN41W~Qvr8jGM4DmDWfP>vGkDsz_k}^)LKE)%^@9{AOrNz!y#L>#^F_Sbxo`t~H&|fWZQYVA;G{5h);CVIVH3 zY&c}t>lXsCOfB?c9a7YeoZuxUGh@oc5`# z$T(;)vas1hxPJTy$BXulk*{O(nrIt%E#%EkpaApx2c{dym0n0ah z?%y-^_KybkMtXYN(PNDDa35d3W-J)7OK6o&UvWXLFPlm5`&)`#k&F`c)oM}Dmw5u&70Tvx)!1Lr%(8iEn&akABUcs>Bd-0n1LR`^p(u; z))1j@)yv%CIBZYYLWOwce%gZCxOp2IPigk$bd}9dbnpT42}>Iv%Tq&WA<3&;$7_4P zC3W=@CMyv+J3HP@7>Ej~A6g&XZpY51{k3Ui_AREtzc1{e0n$)^$-q141FMv=_IU`j`$NZI?R zkESwW-tp<@!7#@}aAi*2o1-KkLO4>9fUEQF-3vBsudk2e3DAVfbHb_%e-h*fNwbZ8 zgiovDWq=aiBD4U?h7oi3PlP1{OKvfXK)Wu6G8~PxREtezJr}VN^oSsKn~ZL7!8F;D zF#usYbHVO&A}FF{f$t|=rl+g^5d09x;qyPhdgzBIEQH3p#f?*F*Gt>)3qh2QwSvG_ z5cE3&gFKe@8P(@~D_ipWD#<^fF(UDPL;gv7?}LenCx8ZzyU~HD+%UY78||qWZVJ-B zQ=bnRlHD(|Ql-3+n`@V^O{IA(w#tw~kd=F=?yombeq*`-bh@y5S_^D(g*2UYm&xWF zRK7t5N9cc49~9;BLAcgG)IL0ZZF8`c?|W? zWr!W!-wm1r=8YXiz*hTpc-As|lOO+(k+W!XZ-AOmgY!K*%UT*bBq&6NE3Y7$-AD)L z?`=-uOh9Q8yG$iIi+rBIOu?xU8ev;%h>hT9B1yTRhFZjF#E18RdAAnkhoJL#j8O$c)=N);wtcwUptT!0`U2O zCh%;4_>T^g0i2Eif|EfGAmG#>ClCk+6rru9HQG=}QHu_G-)z;OnM&qQ)z%S%muT!Q zo;ma=G?6L}N;{(J%>1$e>d9w7`XR1S?IX$hYGj$`_|0Sh>v1AcN8(?p^H>>5ZHzwM)rmgUq_^HEmb{LxDST>6CEHaoAkfQWPGpC%WMjE1v2uodJ40Vt|jE=G085;Ys6lyqlkb5V$7?s z^_Ov=@<&{m4$ipL3O%TrWVFW9#)a(CTVvA}=eiC*Qra2C$y$H$$#x%eW=4j~>xW&M4G8m~0t}Yaw4X~JL z=Ab~Z-6)yuL|!rxSl=G4yBM*-!tZLv=z#g_*~)d8*KiaTU~dF9kY)#tm(hClwgdd; ziG89*H_VQ(p#y$~g{aCn)~LFl0L97qzx3iT{Ht9{oLv5!6Re7>58y|ik9oZY3K(u# z#98t54;uiP`i$DFk`5X^M1#N}^$(TUk>4fr?I6PA`BVsrq)`1cB%`B_5`_rVe&De5 zZdy!M60e2RKwgV5^&p@sl$FnH(8!`Hz!*(DS$7r!{?@%_ElfYj)`uBl78 z0eZz)#&s^Ko~xoEJ{3C7p}t%S+ag=4+!j9*xdjp#Ik#3FdA@H(oB}UWqK`_u5{V>8Pu1RQ=76ekqagc($?9>Z0;y`Tx|Uw$jcK((QYU0 zEy--@Mqi54f0j<2`@QwO*kO5p_sf-+UH4~yGECL;OOZx zLHS@v_E%6|i0nKnKB?Q+fw->4TIr~3IVgQ`=kZSDVr(39(%Xmmp*{O$1vGdRfwrD} z?oI>f;~}57y88fPZqp7|D6%^qUI8_Ftb_FLcVH_2g+FmB(06Z`WgTO>xQx{7FxW|D z>h&aIRJoU|Bv^cElx*{3G31{T7Gzw1-}%%JQM>*^CQqBiQUYE^%pJ(b<|a(A&!|_+ zA@v)#Ew-gPucdig`9-_Bl&qLGDK)zKFX?0;6(?&%wExYK*sJkoOn01?iguoOxZOF0 zntW7zRmv6X4MDMhd^7%Aui9e=^!vJ9kIVW`TcV0a){s$yjM7pNPLl~<7mxe0pTk`R zFPxsX0wOj8=as^z%Z``TP$fIqCHvX*TNt+3IKS^S*C<#R>awy6-P~I;HvI;^|{7ehXD!4^}Lk z0F~hSqvwXFWFrPCI;4mwb4vd%zY^Ic{H=H8RWe<$yf(+=W(|qW-k+4_hf1vy_0qt7 z@X$AY-gHReNV)J!9aIKYrEq^9XkpJTbwq6}VyHkZs z1mnr5_DDT4SlQ~&;i}~Y!rgn$7vn@Qcy#er#S+fz^upy0E@Ox3ms16?Iv;ikOJ#HE_8`(r(IT&WaB+cb!VVllg$BvU)u5Gn<5h=3ES?j zF0`1lyUuZ_*lKeh3${tS%^H?^vHvh>VUpXnB~i+J&cSzo0mRubESq;D>-FA4W~X0| z`~8p{%3LgeG8L4Bpo#fx=YTg16xkW~0v;pL7*`C)QS=yAmS<5NR%>WODPb6$7CS$t%ahdz4lg z+bHD9X%}dM$&+A_qUrREaeBf5=O3$W^W^yYRQG4(V=V(7X%2(#`3%y?K7w~3It-Ty zNDK+Pq%c(xRkfKgur%X6sP7*P!H?@}q}PT(WAGl+licj!WKCDqyJU4dVUs;1$T_RL z!r`mexY^Y1swj;B3}m4Zd#+BS*L?;dJv{ha__lwz`O5nczCaf9F8U5svA)9|ZP=s0 z4Nsp>LC<}JjZ;*U%I~3=Ny4^hw73g?KWPqQbK{eGwRmvZ3X6iL>!fdZSm$L!f8Xb} z=J)-8Q)n8-`b8atbI7DGz6DmjY?-4pR^G(lc!r^+I9tx|KAdFE?@_5Y9!0QZD)2N4 zPShUY0G1;B40&x?s&c@>mxVnP!p>-2rO2yrZ!_?rly7^0h^8H4znvf;ARN7K8T?UcV|q#T(Q^1}oCm-FRgduK<^hK7L$*vHgeYN>#3 z9g_&GWT(Rzi#)og!8y?C2)D6gggZb!1bhSAmizA9(E}eXC3a0uXPvI7`boQA7nJ>f z2*W|O!NB%cs;)6~E!#`>@C>44Hm*O~P-tSChvPKcM~f|1AIX0#SV_dL~|G@&H`AaCeHdO$-hR3dm5S+5{s97KdrJpZ{1P$@-on3r}{7 zb!>m&KC;9{RhI8k+4PH>9ruNDn=K6*g6+ab+4C9i5P=cXv5fKGsa!c_@H*l9dNOdT zeXP?OYMMw3AC8Uh+WI1?7R>L4imr6-2YR;OxL9(hJ;j92-19XjOW`UT6GzxT41y7T zl@b`9?95e(t1s@zz~l|XWv^gU2lX3Ncf(#LENKsXlNqcou41V|?t^(v~UyVt#i4xVpE>usFb8CHY= z(uM&QVe8McW=&dbS^;tMOxh<9_8zGi3U~g14a{bMud9Z9`n=1ER+|-NE#T{$fl|GY zH=My=HnMaPm={TJ2Wzi7W`Gev-7j6lHh6z}Qkv2oIe_$feL+iC_YR+7yxYE=*)jsw zS>G-XE}*7n$oEbbpT~{4K+*Bt9oRBv%3h5pzDjfDWXY10@W#LootsoP&UlH8G6WP@n5F+%q^ zh~7ZFnq-&?gGx#+7BtR5LMbAq7K6vk^k5S@i&Dfaa&)}?;(#{jT?Is{l}y#TWRHdB zZcuz&u+E@(3Djf=OOsx6XD58;T4wXjI$4x2I7yu+Y!^0U$+2xkrOeK&F{h?7rpbAo z5g$~9@X-Q`9*t>a5mV?O)-;DQux0}x6z+>62qC~qm8UIi8|({~jEn+mdq!f_grv=h zNCwnVwFKp=PJz7Gpz!C73Ow^ce+p^8j zeOz&YF3li6zDES|AX(Gq%8EV+C53KE&vUG~iB>C>>El6YoaBEj;kEvdSy0EBy|@_a zon9|P-9#bX+!pWmP#pM{W}aBh)f!S{W|FOzRzC`6;4oJ^l?a*4l+9#2AUN(#t5Lx^~mdG_M{O3_-apD@!LtwWDJy85>QUJ9{Dh5;Or0j+ENNBvSfQ6|8Gt$xx-U z7HVkfzF1yo=T-Ra4K8DI*&vfc4k!MIf}S&}1hwxWNvQG^ouD8=Z^n(%U86B5$HK;6 zSJtR>gePmp=4gUswCq)RIjmY{Xj($g2fVvtDq5E~U!_km@@PErpKG?m~3D zS+KRVgd`U>T}4F3$|_pLhoy>`l!-O9R!3*1Xk>I0J)fs(hNPr4zq=0W1ND#7(*44R!k-LzuOxElY$@In(K^2hFV=NQdooZaW9`pb z4u+Uadz&z0v>y$0d$cDLcja-QcA#Q&^1fhLFC-&_Sk-IS%R{%57+T0xU!iPj1wO*6 zo^729yf_S?UA*v=;b^?+|G zXs$j1?FZ(4>~Z8}G>+r#MH1#X1Vlg!m^aAA05Bih1SBy{(41d06c1KTsq(TzTh(L* zNl554`W6=+Y}4uo0iFq!Pe=8f0g|m|y$jN}cBJ_CpM+C}Ya*Up*%H1oPb^E@XO zed_f6IKvJM9t_w@RldMH&^=bxB6tTrfe$N=<`S54ITf9pf+%D|QM6Ia3gF=~&K!Q;0Pf>+)W@-t z!Wj|OblLNj6eSgQ(!u@q%w~dDAiI4c_`@1TH^5a@pk|_`lIu}@=u5)q^9_)mbg%UN zG7Aqrk;Dqhq)j=}cf3|=${cGVlw;sI9|~P2j|+SUx!~T05%M5nM|-=#Bj}$di23t_M-7Am z`+)B^caMmHM*tlk-KoH1Bp)8uT!T6o0WNi{_v2@4G!Q2v@jMA7EbwgGvP>MVPAy0; z)7kPKC@7Hegfb&Am2AX$Vf9<{e zd5TiX-qK!E;f0M730RfYfrR90KpsoS$v0CWXpSgdfJ>h6lLh@)iO%a66J7>qiCao1 ztQ3#CkLPtM&P0+^5;=@?OOiAwO)4w-)$KI(u;cUO9>u=yIE>}KpXfJN&aX_2RCa3Fa6cN{3&MP!YO5^J$DEM|N^ZNjUSya#zhEazcshz- zp%^CO znUtdHxFpr7?>72!%ps5wph;Q zb5?;9e#rXuR$2eebepN6!8Jf;7I|rbp2J(D{?CsvbZWLaYKY^v`30eWp!m=5Mac2m1f@G2vwL z@BetK>BJqdq4>owO4SA%;hmXj5ujK4>$W*P*P zWQOs``gi^$n?|H2N!GVU#ojJBLLkdj4s&GD6)$B>#8%ce$Ghi5D`Fd;f+vfaaWpWL zkqEuSCA+Hf=40XTqV#*VQ7I|{**Ug3T0p0GDreW`i+0zcr6eck%_-$6)|$Pz{KSIz zIfAdT1UOM|ce}oGpI*q7FWS~C?;R>*=g5>YS8>!KwGGH794EB*kp#IZ+%jz9`4myJeM6B<`?S&I^)YKW#6@vB_9p=APw?;V3vBx?^|6fj~o+gz4`%r z(73obhMa-fMTlQAy*t5(!b)%FkH$drHW06RrG@rvL9@A$7AP}B4{&B7KUW4UP=gK$ z6hQ*btaPY4g_upxHz_l|^(U#!AtdoAOaomA+Ey4RqOr!3H)KRT+m`4)+}I8=UUwp4 z=6FnMh3X0S6$RpB&seAe4Zuioa_(dy)3j@s?lOA0NHX?ef0z8o{%Ba16JL+^&Vw^0 zg~~A}v$JU8#Dz!L?}j?>$*B+@4i2sDKbi!lH_bmn#1h4;0y-6Rzxl_}NR@)bkHtT9 z5sl(286oQrPU357798sXR(B^jx&f+G1+zAEf2(Y4%vha$UDHvmbnCtc-OB@LN@woprEEGjgZuMxUncj3+TFpO| z8WaGCzZXZ?r!c_(@%g{9&NeRA&KAz*jwS}i|EAQktY#y(DUR@cqgOA(x0b3H)v6@N zE7HNzCXa%Tx?a_+zB_`+pAIgn}agMv7qDc+e_VeX1jQpa~OMWy@)}l z_(hEcO5_u_SA7~<8#@m@{2C{=#Gga8nN4i@`FozcL8g+YAd4F(l%Xg{dFYFoim2SF znDYX$<)zPl8LVnXP5`)Ia3FMZ%_D8&gdlyv^|yx#{fvpBN^CINk6LZMoRw4`EU2QqPgn1__FV(Viwe zF7spEQ@+C=24ZpZxf#xmQXHu3);Xw!6ReA)h|W0ss%kG&Guu&Zwk6AP3yYI~_759N zH?SPHtEXSPV;hXT!H1meC;0M(4hp4d>7ahb(S>GM8fiQ7cz0su>6_uPf^pvSdN#0( z|DIjYuX@?^F5m8?zG;4gV=Bv~kH$KuurJn93X+U90{%Y1%No-@Y+v-B$2@hWq0~>nIu5xwQwj>Tmy|rs6HnNwwV_&!7sd$ZQgTtb($34ayN==#c)_#_M zoAeGj5>>lISHl#MNOHbMT9sH~1pN z!aeO?CxzFWhW}={k>|tIljSw)&~vv3@wA|$bvtLb98O3C79@q$0lRL}ex}8sIXgZS zg>%?>$=Wr<(FgDCx`$ltmsC@9qZXoRc1?=$lz0=O-&~3pZpjtRRJVXE_fAh36t59< zJMnO3(f6}?7^Wl2?j3pU{R~_5)zb*pcm{rsnoQ9*I;pZ&=)WjB6P$6CsVo6Z-WaOv zTwHombY*?;LO*5{^C$9Tu@ z>GF(rIbrog#fpv=b-U;LYY6FOV*YchK{T?Y5mB9tGN_9C;1xyKSn;kiOdojfR%i3L zWJvOJ$r}&i2$yNKL7t5R1J3=(wZ03oaLqsO_J>?1H4V$@odb2VA9 z@=4Xd+F&9%4-p;r-#r$|-aug!D zR-MSm!Z}WL#Yt2nh=wahGD*Iid2?xHPvM8UZEtX2{9`s|LQP15QW2`ccY{2`9O8cHOH>UXtt0X`#9wRv@MnUSB=G% zQ?bQ!^8Ev{QtyS&3rdw$bW~pLW&G4>V<(fktUFMjVQ>5{0-v0C{wcECf3SDN`Dym; zNT0ZX@sN%hrMc5rJa2_5Q7%Wq4Wm=cA_#Ncb{$pba0^1Ar?9*m3T;O9(K4BrvmwiD zxtoAR&LSDGk?ud*i0iA3F1iK;diwe9;aQ0($I)`qZRGaE6ss}|%!!UC9m-wG<`Zzh z^i0gS;butz;5PigfkE$sAE09e&S1Ne3I-U^K7ILU&1i3(Y_&mvUY8jbKH7`jI{xW$ z1n6>7l=9chA*$W)N#C*?arp@sDqx2k1{jPJLh}Dga`yAVJAP5@(&lhnS~e;M+FA+% z2IGcs%7KS^`w#61FSjj88wK7dnID3>4-CL;qqgTNfYrCWU7dq~stNoTB6w{tUI_8= zelQWd-7@X+fz{Vg<70r@+b$nub+GSAdwT&Js9ZDJ>!yE`zG}CB3~~?zAw{WRAdnyQ z2(YO3FoYKhf_8(U6MeL)-~kOmNOAqb zc2gl^owT9g&K4A!8Mzw@85yAs1!sFuPcV&vK=>6wFo*$Yy Date: Tue, 29 Jul 2014 20:58:05 -0700 Subject: [PATCH 021/170] [SPARK-2054][SQL] Code Generation for Expression Evaluation Adds a new method for evaluating expressions using code that is generated though Scala reflection. This functionality is configured by the SQLConf option `spark.sql.codegen` and is currently turned off by default. Evaluation can be done in several specialized ways: - *Projection* - Given an input row, produce a new row from a set of expressions that define each column in terms of the input row. This can either produce a new Row object or perform the projection in-place on an existing Row (MutableProjection). - *Ordering* - Compares two rows based on a list of `SortOrder` expressions - *Condition* - Returns `true` or `false` given an input row. For each of the above operations there is both a Generated and Interpreted version. When generation for a given expression type is undefined, the code generator falls back on calling the `eval` function of the expression class. Even without custom code, there is still a potential speed up, as loops are unrolled and code can still be inlined by JIT. This PR also contains a new type of Aggregation operator, `GeneratedAggregate`, that performs aggregation by using generated `Projection` code. Currently the required expression rewriting only works for simple aggregations like `SUM` and `COUNT`. This functionality will be extended in a future PR. This PR also performs several clean ups that simplified the implementation: - The notion of `Binding` all expressions in a tree automatically before query execution has been removed. Instead it is the responsibly of an operator to provide the input schema when creating one of the specialized evaluators defined above. In cases when the standard eval method is going to be called, binding can still be done manually using `BindReferences`. There are a few reasons for this change: First, there were many operators where it just didn't work before. For example, operators with more than one child, and operators like aggregation that do significant rewriting of the expression. Second, the semantics of equality with `BoundReferences` are broken. Specifically, we have had a few bugs where partitioning breaks because of the binding. - A copy of the current `SQLContext` is automatically propagated to all `SparkPlan` nodes by the query planner. Before this was done ad-hoc for the nodes that needed this. However, this required a lot of boilerplate as one had to always remember to make it `transient` and also had to modify the `otherCopyArgs`. Author: Michael Armbrust Closes #993 from marmbrus/newCodeGen and squashes the following commits: 96ef82c [Michael Armbrust] Merge remote-tracking branch 'apache/master' into newCodeGen f34122d [Michael Armbrust] Merge remote-tracking branch 'apache/master' into newCodeGen 67b1c48 [Michael Armbrust] Use conf variable in SQLConf object 4bdc42c [Michael Armbrust] Merge remote-tracking branch 'origin/master' into newCodeGen 41a40c9 [Michael Armbrust] Merge remote-tracking branch 'origin/master' into newCodeGen de22aac [Michael Armbrust] Merge remote-tracking branch 'origin/master' into newCodeGen fed3634 [Michael Armbrust] Inspectors are not serializable. ef8d42b [Michael Armbrust] comments 533fdfd [Michael Armbrust] More logging of expression rewriting for GeneratedAggregate. 3cd773e [Michael Armbrust] Allow codegen for Generate. 64b2ee1 [Michael Armbrust] Implement copy 3587460 [Michael Armbrust] Drop unused string builder function. 9cce346 [Michael Armbrust] Merge remote-tracking branch 'origin/master' into newCodeGen 1a61293 [Michael Armbrust] Address review comments. 0672e8a [Michael Armbrust] Address comments. 1ec2d6e [Michael Armbrust] Address comments 033abc6 [Michael Armbrust] off by default 4771fab [Michael Armbrust] Docs, more test coverage. d30fee2 [Michael Armbrust] Merge remote-tracking branch 'origin/master' into newCodeGen d2ad5c5 [Michael Armbrust] Refactor putting SQLContext into SparkPlan. Fix ordering, other test cases. be2cd6b [Michael Armbrust] WIP: Remove old method for reference binding, more work on configuration. bc88ecd [Michael Armbrust] Style 6cc97ca [Michael Armbrust] Merge remote-tracking branch 'origin/master' into newCodeGen 4220f1e [Michael Armbrust] Better config, docs, etc. ca6cc6b [Michael Armbrust] WIP 9d67d85 [Michael Armbrust] Fix hive planner fc522d5 [Michael Armbrust] Hook generated aggregation in to the planner. e742640 [Michael Armbrust] Remove unneeded changes and code. 675e679 [Michael Armbrust] Upgrade paradise. 0093376 [Michael Armbrust] Comment / indenting cleanup. d81f998 [Michael Armbrust] include schema for binding. 0e889e8 [Michael Armbrust] Use typeOf instead tq f623ffd [Michael Armbrust] Quiet logging from test suite. efad14f [Michael Armbrust] Remove some half finished functions. 92e74a4 [Michael Armbrust] add overrides a2b5408 [Michael Armbrust] WIP: Code generation with scala reflection. --- pom.xml | 10 + project/SparkBuild.scala | 11 +- sql/catalyst/pom.xml | 9 + .../spark/sql/catalyst/dsl/package.scala | 2 +- .../catalyst/expressions/BoundAttribute.scala | 50 +- .../sql/catalyst/expressions/Projection.scala | 39 +- .../spark/sql/catalyst/expressions/Row.scala | 40 +- .../sql/catalyst/expressions/ScalaUdf.scala | 1 + .../expressions/codegen/CodeGenerator.scala | 468 ++++++++++++++++++ .../codegen/GenerateMutableProjection.scala | 76 +++ .../codegen/GenerateOrdering.scala | 98 ++++ .../codegen/GeneratePredicate.scala | 48 ++ .../codegen/GenerateProjection.scala | 219 ++++++++ .../expressions/codegen/package.scala | 80 +++ .../sql/catalyst/expressions/package.scala | 28 +- .../sql/catalyst/expressions/predicates.scala | 3 + .../apache/spark/sql/catalyst/package.scala | 27 + .../sql/catalyst/planning/patterns.scala | 71 +++ .../catalyst/plans/logical/LogicalPlan.scala | 2 +- .../sql/catalyst/plans/logical/commands.scala | 12 +- .../sql/catalyst/rules/RuleExecutor.scala | 5 +- .../spark/sql/catalyst/types/dataTypes.scala | 18 +- .../ExpressionEvaluationSuite.scala | 55 +- .../GeneratedEvaluationSuite.scala | 69 +++ .../GeneratedMutableEvaluationSuite.scala | 61 +++ .../optimizer/CombiningLimitsSuite.scala | 4 +- .../scala/org/apache/spark/sql/SQLConf.scala | 19 +- .../org/apache/spark/sql/SQLContext.scala | 25 +- .../spark/sql/api/java/JavaSQLContext.scala | 4 +- .../spark/sql/execution/Aggregate.scala | 13 +- .../apache/spark/sql/execution/Exchange.scala | 8 +- .../apache/spark/sql/execution/Generate.scala | 13 +- .../sql/execution/GeneratedAggregate.scala | 200 ++++++++ .../spark/sql/execution/SparkPlan.scala | 81 ++- .../spark/sql/execution/SparkStrategies.scala | 138 +++--- .../spark/sql/execution/basicOperators.scala | 44 +- .../spark/sql/execution/debug/package.scala | 8 +- .../apache/spark/sql/execution/joins.scala | 44 +- .../spark/sql/parquet/ParquetRelation.scala | 18 +- .../sql/parquet/ParquetTableOperations.scala | 14 +- .../spark/sql/parquet/ParquetTestData.scala | 9 +- .../org/apache/spark/sql/QueryTest.scala | 1 + .../spark/sql/execution/PlannerSuite.scala | 8 +- .../apache/spark/sql/execution/TgfSuite.scala | 2 +- .../spark/sql/parquet/ParquetQuerySuite.scala | 5 +- .../apache/spark/sql/hive/HiveContext.scala | 2 +- .../hive/execution/InsertIntoHiveTable.scala | 2 +- .../hive/execution/ScriptTransformation.scala | 2 +- .../org/apache/spark/sql/hive/hiveUdfs.scala | 6 +- ...se null-0-8ef2f741400830ef889a9dd0c817fe3d | 1 + ...le case-0-f513687d17dcb18546fefa75000a52f2 | 1 + ...le case-0-c264e319c52f1840a32959d552b99e73 | 1 + .../sql/hive/execution/HiveQuerySuite.scala | 11 + 53 files changed, 1889 insertions(+), 297 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/package.scala create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/package.scala create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedEvaluationSuite.scala create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedMutableEvaluationSuite.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala create mode 100644 sql/hive/src/test/resources/golden/case else null-0-8ef2f741400830ef889a9dd0c817fe3d create mode 100644 sql/hive/src/test/resources/golden/double case-0-f513687d17dcb18546fefa75000a52f2 create mode 100644 sql/hive/src/test/resources/golden/single case-0-c264e319c52f1840a32959d552b99e73 diff --git a/pom.xml b/pom.xml index 39538f9660623..ae97bf03c53a2 100644 --- a/pom.xml +++ b/pom.xml @@ -114,6 +114,7 @@ spark 2.10.4 2.10 + 2.0.1 0.18.1 shaded-protobuf org.spark-project.akka @@ -825,6 +826,15 @@ -target ${java.version} + + + + org.scalamacros + paradise_${scala.version} + ${scala.macros.version} + + diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 0a6326e72297a..490fac3cc3646 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -167,6 +167,9 @@ object SparkBuild extends PomBuild { /* Enable unidoc only for the root spark project */ enable(Unidoc.settings)(spark) + /* Catalyst macro settings */ + enable(Catalyst.settings)(catalyst) + /* Spark SQL Core console settings */ enable(SQL.settings)(sql) @@ -189,10 +192,13 @@ object Flume { lazy val settings = sbtavro.SbtAvro.avroSettings } -object SQL { - +object Catalyst { lazy val settings = Seq( + addCompilerPlugin("org.scalamacros" % "paradise" % "2.0.1" cross CrossVersion.full)) +} +object SQL { + lazy val settings = Seq( initialCommands in console := """ |import org.apache.spark.sql.catalyst.analysis._ @@ -207,7 +213,6 @@ object SQL { |import org.apache.spark.sql.test.TestSQLContext._ |import org.apache.spark.sql.parquet.ParquetTestData""".stripMargin ) - } object Hive { diff --git a/sql/catalyst/pom.xml b/sql/catalyst/pom.xml index 531bfddbf237b..54fa96baa1e18 100644 --- a/sql/catalyst/pom.xml +++ b/sql/catalyst/pom.xml @@ -36,10 +36,19 @@ + + org.scala-lang + scala-compiler + org.scala-lang scala-reflect + + org.scalamacros + quasiquotes_${scala.binary.version} + ${scala.macros.version} + org.apache.spark spark-core_${scala.binary.version} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index 5c8c810d9135a..f44521d6381c9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -202,7 +202,7 @@ package object dsl { // Protobuf terminology def required = a.withNullability(false) - def at(ordinal: Int) = BoundReference(ordinal, a) + def at(ordinal: Int) = BoundReference(ordinal, a.dataType, a.nullable) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala index 9ce1f01056462..a3ebec8082cbd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala @@ -17,10 +17,12 @@ package org.apache.spark.sql.catalyst.expressions -import org.apache.spark.sql.catalyst.trees import org.apache.spark.sql.catalyst.errors.attachTree import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.catalyst.types._ +import org.apache.spark.sql.catalyst.trees + import org.apache.spark.sql.Logging /** @@ -28,61 +30,27 @@ import org.apache.spark.sql.Logging * to be retrieved more efficiently. However, since operations like column pruning can change * the layout of intermediate tuples, BindReferences should be run after all such transformations. */ -case class BoundReference(ordinal: Int, baseReference: Attribute) - extends Attribute with trees.LeafNode[Expression] { +case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean) + extends Expression with trees.LeafNode[Expression] { type EvaluatedType = Any - override def nullable = baseReference.nullable - override def dataType = baseReference.dataType - override def exprId = baseReference.exprId - override def qualifiers = baseReference.qualifiers - override def name = baseReference.name + override def references = Set.empty - override def newInstance = BoundReference(ordinal, baseReference.newInstance) - override def withNullability(newNullability: Boolean) = - BoundReference(ordinal, baseReference.withNullability(newNullability)) - override def withQualifiers(newQualifiers: Seq[String]) = - BoundReference(ordinal, baseReference.withQualifiers(newQualifiers)) - - override def toString = s"$baseReference:$ordinal" + override def toString = s"input[$ordinal]" override def eval(input: Row): Any = input(ordinal) } -/** - * Used to denote operators that do their own binding of attributes internally. - */ -trait NoBind { self: trees.TreeNode[_] => } - -class BindReferences[TreeNode <: QueryPlan[TreeNode]] extends Rule[TreeNode] { - import BindReferences._ - - def apply(plan: TreeNode): TreeNode = { - plan.transform { - case n: NoBind => n.asInstanceOf[TreeNode] - case leafNode if leafNode.children.isEmpty => leafNode - case unaryNode if unaryNode.children.size == 1 => unaryNode.transformExpressions { case e => - bindReference(e, unaryNode.children.head.output) - } - } - } -} - object BindReferences extends Logging { def bindReference[A <: Expression](expression: A, input: Seq[Attribute]): A = { expression.transform { case a: AttributeReference => attachTree(a, "Binding attribute") { val ordinal = input.indexWhere(_.exprId == a.exprId) if (ordinal == -1) { - // TODO: This fallback is required because some operators (such as ScriptTransform) - // produce new attributes that can't be bound. Likely the right thing to do is remove - // this rule and require all operators to explicitly bind to the input schema that - // they specify. - logger.debug(s"Couldn't find $a in ${input.mkString("[", ",", "]")}") - a + sys.error(s"Couldn't find $a in ${input.mkString("[", ",", "]")}") } else { - BoundReference(ordinal, a) + BoundReference(ordinal, a.dataType, a.nullable) } } }.asInstanceOf[A] // Kind of a hack, but safe. TODO: Tighten return type when possible. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala index 2c71d2c7b3563..8fc5896974438 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala @@ -17,12 +17,13 @@ package org.apache.spark.sql.catalyst.expressions + /** - * Converts a [[Row]] to another Row given a sequence of expression that define each column of the - * new row. If the schema of the input row is specified, then the given expression will be bound to - * that schema. + * A [[Projection]] that is calculated by calling the `eval` of each of the specified expressions. + * @param expressions a sequence of expressions that determine the value of each column of the + * output row. */ -class Projection(expressions: Seq[Expression]) extends (Row => Row) { +class InterpretedProjection(expressions: Seq[Expression]) extends Projection { def this(expressions: Seq[Expression], inputSchema: Seq[Attribute]) = this(expressions.map(BindReferences.bindReference(_, inputSchema))) @@ -40,25 +41,25 @@ class Projection(expressions: Seq[Expression]) extends (Row => Row) { } /** - * Converts a [[Row]] to another Row given a sequence of expression that define each column of th - * new row. If the schema of the input row is specified, then the given expression will be bound to - * that schema. - * - * In contrast to a normal projection, a MutableProjection reuses the same underlying row object - * each time an input row is added. This significantly reduces the cost of calculating the - * projection, but means that it is not safe to hold on to a reference to a [[Row]] after `next()` - * has been called on the [[Iterator]] that produced it. Instead, the user must call `Row.copy()` - * and hold on to the returned [[Row]] before calling `next()`. + * A [[MutableProjection]] that is calculated by calling `eval` on each of the specified + * expressions. + * @param expressions a sequence of expressions that determine the value of each column of the + * output row. */ -case class MutableProjection(expressions: Seq[Expression]) extends (Row => Row) { +case class InterpretedMutableProjection(expressions: Seq[Expression]) extends MutableProjection { def this(expressions: Seq[Expression], inputSchema: Seq[Attribute]) = this(expressions.map(BindReferences.bindReference(_, inputSchema))) private[this] val exprArray = expressions.toArray - private[this] val mutableRow = new GenericMutableRow(exprArray.size) + private[this] var mutableRow: MutableRow = new GenericMutableRow(exprArray.size) def currentValue: Row = mutableRow - def apply(input: Row): Row = { + override def target(row: MutableRow): MutableProjection = { + mutableRow = row + this + } + + override def apply(input: Row): Row = { var i = 0 while (i < exprArray.length) { mutableRow(i) = exprArray(i).eval(input) @@ -76,6 +77,12 @@ class JoinedRow extends Row { private[this] var row1: Row = _ private[this] var row2: Row = _ + def this(left: Row, right: Row) = { + this() + row1 = left + row2 = right + } + /** Updates this JoinedRow to used point at two new base rows. Returns itself. */ def apply(r1: Row, r2: Row): Row = { row1 = r1 diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala index 74ae723686cfe..7470cb861b83b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala @@ -88,15 +88,6 @@ trait MutableRow extends Row { def setByte(ordinal: Int, value: Byte) def setFloat(ordinal: Int, value: Float) def setString(ordinal: Int, value: String) - - /** - * Experimental - * - * Returns a mutable string builder for the specified column. A given row should return the - * result of any mutations made to the returned buffer next time getString is called for the same - * column. - */ - def getStringBuilder(ordinal: Int): StringBuilder } /** @@ -180,6 +171,35 @@ class GenericRow(protected[catalyst] val values: Array[Any]) extends Row { values(i).asInstanceOf[String] } + // Custom hashCode function that matches the efficient code generated version. + override def hashCode(): Int = { + var result: Int = 37 + + var i = 0 + while (i < values.length) { + val update: Int = + if (isNullAt(i)) { + 0 + } else { + apply(i) match { + case b: Boolean => if (b) 0 else 1 + case b: Byte => b.toInt + case s: Short => s.toInt + case i: Int => i + case l: Long => (l ^ (l >>> 32)).toInt + case f: Float => java.lang.Float.floatToIntBits(f) + case d: Double => + val b = java.lang.Double.doubleToLongBits(d) + (b ^ (b >>> 32)).toInt + case other => other.hashCode() + } + } + result = 37 * result + update + i += 1 + } + result + } + def copy() = this } @@ -187,8 +207,6 @@ class GenericMutableRow(size: Int) extends GenericRow(size) with MutableRow { /** No-arg constructor for serialization. */ def this() = this(0) - def getStringBuilder(ordinal: Int): StringBuilder = ??? - override def setBoolean(ordinal: Int,value: Boolean): Unit = { values(ordinal) = value } override def setByte(ordinal: Int,value: Byte): Unit = { values(ordinal) = value } override def setDouble(ordinal: Int,value: Double): Unit = { values(ordinal) = value } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala index 5e089f7618e0a..acddf5e9c7004 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala @@ -29,6 +29,7 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi override def eval(input: Row): Any = { children.size match { + case 0 => function.asInstanceOf[() => Any]() case 1 => function.asInstanceOf[(Any) => Any](children(0).eval(input)) case 2 => function.asInstanceOf[(Any, Any) => Any]( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala new file mode 100644 index 0000000000000..5b398695bf560 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -0,0 +1,468 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions.codegen + +import com.google.common.cache.{CacheLoader, CacheBuilder} + +import scala.language.existentials + +import org.apache.spark.Logging +import org.apache.spark.sql.catalyst.expressions +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.types._ + +/** + * A base class for generators of byte code to perform expression evaluation. Includes a set of + * helpers for referring to Catalyst types and building trees that perform evaluation of individual + * expressions. + */ +abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Logging { + import scala.reflect.runtime.{universe => ru} + import scala.reflect.runtime.universe._ + + import scala.tools.reflect.ToolBox + + protected val toolBox = runtimeMirror(getClass.getClassLoader).mkToolBox() + + protected val rowType = typeOf[Row] + protected val mutableRowType = typeOf[MutableRow] + protected val genericRowType = typeOf[GenericRow] + protected val genericMutableRowType = typeOf[GenericMutableRow] + + protected val projectionType = typeOf[Projection] + protected val mutableProjectionType = typeOf[MutableProjection] + + private val curId = new java.util.concurrent.atomic.AtomicInteger() + private val javaSeparator = "$" + + /** + * Generates a class for a given input expression. Called when there is not cached code + * already available. + */ + protected def create(in: InType): OutType + + /** + * Canonicalizes an input expression. Used to avoid double caching expressions that differ only + * cosmetically. + */ + protected def canonicalize(in: InType): InType + + /** Binds an input expression to a given input schema */ + protected def bind(in: InType, inputSchema: Seq[Attribute]): InType + + /** + * A cache of generated classes. + * + * From the Guava Docs: A Cache is similar to ConcurrentMap, but not quite the same. The most + * fundamental difference is that a ConcurrentMap persists all elements that are added to it until + * they are explicitly removed. A Cache on the other hand is generally configured to evict entries + * automatically, in order to constrain its memory footprint + */ + protected val cache = CacheBuilder.newBuilder() + .maximumSize(1000) + .build( + new CacheLoader[InType, OutType]() { + override def load(in: InType): OutType = globalLock.synchronized { + create(in) + } + }) + + /** Generates the requested evaluator binding the given expression(s) to the inputSchema. */ + def apply(expressions: InType, inputSchema: Seq[Attribute]): OutType = + apply(bind(expressions, inputSchema)) + + /** Generates the requested evaluator given already bound expression(s). */ + def apply(expressions: InType): OutType = cache.get(canonicalize(expressions)) + + /** + * Returns a term name that is unique within this instance of a `CodeGenerator`. + * + * (Since we aren't in a macro context we do not seem to have access to the built in `freshName` + * function.) + */ + protected def freshName(prefix: String): TermName = { + newTermName(s"$prefix$javaSeparator${curId.getAndIncrement}") + } + + /** + * Scala ASTs for evaluating an [[Expression]] given a [[Row]] of input. + * + * @param code The sequence of statements required to evaluate the expression. + * @param nullTerm A term that holds a boolean value representing whether the expression evaluated + * to null. + * @param primitiveTerm A term for a possible primitive value of the result of the evaluation. Not + * valid if `nullTerm` is set to `false`. + * @param objectTerm A possibly boxed version of the result of evaluating this expression. + */ + protected case class EvaluatedExpression( + code: Seq[Tree], + nullTerm: TermName, + primitiveTerm: TermName, + objectTerm: TermName) + + /** + * Given an expression tree returns an [[EvaluatedExpression]], which contains Scala trees that + * can be used to determine the result of evaluating the expression on an input row. + */ + def expressionEvaluator(e: Expression): EvaluatedExpression = { + val primitiveTerm = freshName("primitiveTerm") + val nullTerm = freshName("nullTerm") + val objectTerm = freshName("objectTerm") + + implicit class Evaluate1(e: Expression) { + def castOrNull(f: TermName => Tree, dataType: DataType): Seq[Tree] = { + val eval = expressionEvaluator(e) + eval.code ++ + q""" + val $nullTerm = ${eval.nullTerm} + val $primitiveTerm = + if($nullTerm) + ${defaultPrimitive(dataType)} + else + ${f(eval.primitiveTerm)} + """.children + } + } + + implicit class Evaluate2(expressions: (Expression, Expression)) { + + /** + * Short hand for generating binary evaluation code, which depends on two sub-evaluations of + * the same type. If either of the sub-expressions is null, the result of this computation + * is assumed to be null. + * + * @param f a function from two primitive term names to a tree that evaluates them. + */ + def evaluate(f: (TermName, TermName) => Tree): Seq[Tree] = + evaluateAs(expressions._1.dataType)(f) + + def evaluateAs(resultType: DataType)(f: (TermName, TermName) => Tree): Seq[Tree] = { + // TODO: Right now some timestamp tests fail if we enforce this... + if (expressions._1.dataType != expressions._2.dataType) { + log.warn(s"${expressions._1.dataType} != ${expressions._2.dataType}") + } + + val eval1 = expressionEvaluator(expressions._1) + val eval2 = expressionEvaluator(expressions._2) + val resultCode = f(eval1.primitiveTerm, eval2.primitiveTerm) + + eval1.code ++ eval2.code ++ + q""" + val $nullTerm = ${eval1.nullTerm} || ${eval2.nullTerm} + val $primitiveTerm: ${termForType(resultType)} = + if($nullTerm) { + ${defaultPrimitive(resultType)} + } else { + $resultCode.asInstanceOf[${termForType(resultType)}] + } + """.children : Seq[Tree] + } + } + + val inputTuple = newTermName(s"i") + + // TODO: Skip generation of null handling code when expression are not nullable. + val primitiveEvaluation: PartialFunction[Expression, Seq[Tree]] = { + case b @ BoundReference(ordinal, dataType, nullable) => + val nullValue = q"$inputTuple.isNullAt($ordinal)" + q""" + val $nullTerm: Boolean = $nullValue + val $primitiveTerm: ${termForType(dataType)} = + if($nullTerm) + ${defaultPrimitive(dataType)} + else + ${getColumn(inputTuple, dataType, ordinal)} + """.children + + case expressions.Literal(null, dataType) => + q""" + val $nullTerm = true + val $primitiveTerm: ${termForType(dataType)} = null.asInstanceOf[${termForType(dataType)}] + """.children + + case expressions.Literal(value: Boolean, dataType) => + q""" + val $nullTerm = ${value == null} + val $primitiveTerm: ${termForType(dataType)} = $value + """.children + + case expressions.Literal(value: String, dataType) => + q""" + val $nullTerm = ${value == null} + val $primitiveTerm: ${termForType(dataType)} = $value + """.children + + case expressions.Literal(value: Int, dataType) => + q""" + val $nullTerm = ${value == null} + val $primitiveTerm: ${termForType(dataType)} = $value + """.children + + case expressions.Literal(value: Long, dataType) => + q""" + val $nullTerm = ${value == null} + val $primitiveTerm: ${termForType(dataType)} = $value + """.children + + case Cast(e @ BinaryType(), StringType) => + val eval = expressionEvaluator(e) + eval.code ++ + q""" + val $nullTerm = ${eval.nullTerm} + val $primitiveTerm = + if($nullTerm) + ${defaultPrimitive(StringType)} + else + new String(${eval.primitiveTerm}.asInstanceOf[Array[Byte]]) + """.children + + case Cast(child @ NumericType(), IntegerType) => + child.castOrNull(c => q"$c.toInt", IntegerType) + + case Cast(child @ NumericType(), LongType) => + child.castOrNull(c => q"$c.toLong", LongType) + + case Cast(child @ NumericType(), DoubleType) => + child.castOrNull(c => q"$c.toDouble", DoubleType) + + case Cast(child @ NumericType(), FloatType) => + child.castOrNull(c => q"$c.toFloat", IntegerType) + + // Special handling required for timestamps in hive test cases since the toString function + // does not match the expected output. + case Cast(e, StringType) if e.dataType != TimestampType => + val eval = expressionEvaluator(e) + eval.code ++ + q""" + val $nullTerm = ${eval.nullTerm} + val $primitiveTerm = + if($nullTerm) + ${defaultPrimitive(StringType)} + else + ${eval.primitiveTerm}.toString + """.children + + case EqualTo(e1, e2) => + (e1, e2).evaluateAs (BooleanType) { case (eval1, eval2) => q"$eval1 == $eval2" } + + /* TODO: Fix null semantics. + case In(e1, list) if !list.exists(!_.isInstanceOf[expressions.Literal]) => + val eval = expressionEvaluator(e1) + + val checks = list.map { + case expressions.Literal(v: String, dataType) => + q"if(${eval.primitiveTerm} == $v) return true" + case expressions.Literal(v: Int, dataType) => + q"if(${eval.primitiveTerm} == $v) return true" + } + + val funcName = newTermName(s"isIn${curId.getAndIncrement()}") + + q""" + def $funcName: Boolean = { + ..${eval.code} + if(${eval.nullTerm}) return false + ..$checks + return false + } + val $nullTerm = false + val $primitiveTerm = $funcName + """.children + */ + + case GreaterThan(e1 @ NumericType(), e2 @ NumericType()) => + (e1, e2).evaluateAs (BooleanType) { case (eval1, eval2) => q"$eval1 > $eval2" } + case GreaterThanOrEqual(e1 @ NumericType(), e2 @ NumericType()) => + (e1, e2).evaluateAs (BooleanType) { case (eval1, eval2) => q"$eval1 >= $eval2" } + case LessThan(e1 @ NumericType(), e2 @ NumericType()) => + (e1, e2).evaluateAs (BooleanType) { case (eval1, eval2) => q"$eval1 < $eval2" } + case LessThanOrEqual(e1 @ NumericType(), e2 @ NumericType()) => + (e1, e2).evaluateAs (BooleanType) { case (eval1, eval2) => q"$eval1 <= $eval2" } + + case And(e1, e2) => + val eval1 = expressionEvaluator(e1) + val eval2 = expressionEvaluator(e2) + + eval1.code ++ eval2.code ++ + q""" + var $nullTerm = false + var $primitiveTerm: ${termForType(BooleanType)} = false + + if ((!${eval1.nullTerm} && !${eval1.primitiveTerm}) || + (!${eval2.nullTerm} && !${eval2.primitiveTerm})) { + $nullTerm = false + $primitiveTerm = false + } else if (${eval1.nullTerm} || ${eval2.nullTerm} ) { + $nullTerm = true + } else { + $nullTerm = false + $primitiveTerm = true + } + """.children + + case Or(e1, e2) => + val eval1 = expressionEvaluator(e1) + val eval2 = expressionEvaluator(e2) + + eval1.code ++ eval2.code ++ + q""" + var $nullTerm = false + var $primitiveTerm: ${termForType(BooleanType)} = false + + if ((!${eval1.nullTerm} && ${eval1.primitiveTerm}) || + (!${eval2.nullTerm} && ${eval2.primitiveTerm})) { + $nullTerm = false + $primitiveTerm = true + } else if (${eval1.nullTerm} || ${eval2.nullTerm} ) { + $nullTerm = true + } else { + $nullTerm = false + $primitiveTerm = false + } + """.children + + case Not(child) => + // Uh, bad function name... + child.castOrNull(c => q"!$c", BooleanType) + + case Add(e1, e2) => (e1, e2) evaluate { case (eval1, eval2) => q"$eval1 + $eval2" } + case Subtract(e1, e2) => (e1, e2) evaluate { case (eval1, eval2) => q"$eval1 - $eval2" } + case Multiply(e1, e2) => (e1, e2) evaluate { case (eval1, eval2) => q"$eval1 * $eval2" } + case Divide(e1, e2) => (e1, e2) evaluate { case (eval1, eval2) => q"$eval1 / $eval2" } + + case IsNotNull(e) => + val eval = expressionEvaluator(e) + q""" + ..${eval.code} + var $nullTerm = false + var $primitiveTerm: ${termForType(BooleanType)} = !${eval.nullTerm} + """.children + + case IsNull(e) => + val eval = expressionEvaluator(e) + q""" + ..${eval.code} + var $nullTerm = false + var $primitiveTerm: ${termForType(BooleanType)} = ${eval.nullTerm} + """.children + + case c @ Coalesce(children) => + q""" + var $nullTerm = true + var $primitiveTerm: ${termForType(c.dataType)} = ${defaultPrimitive(c.dataType)} + """.children ++ + children.map { c => + val eval = expressionEvaluator(c) + q""" + if($nullTerm) { + ..${eval.code} + if(!${eval.nullTerm}) { + $nullTerm = false + $primitiveTerm = ${eval.primitiveTerm} + } + } + """ + } + + case i @ expressions.If(condition, trueValue, falseValue) => + val condEval = expressionEvaluator(condition) + val trueEval = expressionEvaluator(trueValue) + val falseEval = expressionEvaluator(falseValue) + + q""" + var $nullTerm = false + var $primitiveTerm: ${termForType(i.dataType)} = ${defaultPrimitive(i.dataType)} + ..${condEval.code} + if(!${condEval.nullTerm} && ${condEval.primitiveTerm}) { + ..${trueEval.code} + $nullTerm = ${trueEval.nullTerm} + $primitiveTerm = ${trueEval.primitiveTerm} + } else { + ..${falseEval.code} + $nullTerm = ${falseEval.nullTerm} + $primitiveTerm = ${falseEval.primitiveTerm} + } + """.children + } + + // If there was no match in the partial function above, we fall back on calling the interpreted + // expression evaluator. + val code: Seq[Tree] = + primitiveEvaluation.lift.apply(e).getOrElse { + log.debug(s"No rules to generate $e") + val tree = reify { e } + q""" + val $objectTerm = $tree.eval(i) + val $nullTerm = $objectTerm == null + val $primitiveTerm = $objectTerm.asInstanceOf[${termForType(e.dataType)}] + """.children + } + + EvaluatedExpression(code, nullTerm, primitiveTerm, objectTerm) + } + + protected def getColumn(inputRow: TermName, dataType: DataType, ordinal: Int) = { + dataType match { + case dt @ NativeType() => q"$inputRow.${accessorForType(dt)}($ordinal)" + case _ => q"$inputRow.apply($ordinal).asInstanceOf[${termForType(dataType)}]" + } + } + + protected def setColumn( + destinationRow: TermName, + dataType: DataType, + ordinal: Int, + value: TermName) = { + dataType match { + case dt @ NativeType() => q"$destinationRow.${mutatorForType(dt)}($ordinal, $value)" + case _ => q"$destinationRow.update($ordinal, $value)" + } + } + + protected def accessorForType(dt: DataType) = newTermName(s"get${primitiveForType(dt)}") + protected def mutatorForType(dt: DataType) = newTermName(s"set${primitiveForType(dt)}") + + protected def primitiveForType(dt: DataType) = dt match { + case IntegerType => "Int" + case LongType => "Long" + case ShortType => "Short" + case ByteType => "Byte" + case DoubleType => "Double" + case FloatType => "Float" + case BooleanType => "Boolean" + case StringType => "String" + } + + protected def defaultPrimitive(dt: DataType) = dt match { + case BooleanType => ru.Literal(Constant(false)) + case FloatType => ru.Literal(Constant(-1.0.toFloat)) + case StringType => ru.Literal(Constant("")) + case ShortType => ru.Literal(Constant(-1.toShort)) + case LongType => ru.Literal(Constant(1L)) + case ByteType => ru.Literal(Constant(-1.toByte)) + case DoubleType => ru.Literal(Constant(-1.toDouble)) + case DecimalType => ru.Literal(Constant(-1)) // Will get implicity converted as needed. + case IntegerType => ru.Literal(Constant(-1)) + case _ => ru.Literal(Constant(null)) + } + + protected def termForType(dt: DataType) = dt match { + case n: NativeType => n.tag + case _ => typeTag[Any] + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala new file mode 100644 index 0000000000000..a419fd7ecb39b --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions.codegen + +import org.apache.spark.sql.catalyst.expressions._ + +/** + * Generates byte code that produces a [[MutableRow]] object that can update itself based on a new + * input [[Row]] for a fixed set of [[Expression Expressions]]. + */ +object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => MutableProjection] { + import scala.reflect.runtime.{universe => ru} + import scala.reflect.runtime.universe._ + + val mutableRowName = newTermName("mutableRow") + + protected def canonicalize(in: Seq[Expression]): Seq[Expression] = + in.map(ExpressionCanonicalizer(_)) + + protected def bind(in: Seq[Expression], inputSchema: Seq[Attribute]): Seq[Expression] = + in.map(BindReferences.bindReference(_, inputSchema)) + + protected def create(expressions: Seq[Expression]): (() => MutableProjection) = { + val projectionCode = expressions.zipWithIndex.flatMap { case (e, i) => + val evaluationCode = expressionEvaluator(e) + + evaluationCode.code :+ + q""" + if(${evaluationCode.nullTerm}) + mutableRow.setNullAt($i) + else + ${setColumn(mutableRowName, e.dataType, i, evaluationCode.primitiveTerm)} + """ + } + + val code = + q""" + () => { new $mutableProjectionType { + + private[this] var $mutableRowName: $mutableRowType = + new $genericMutableRowType(${expressions.size}) + + def target(row: $mutableRowType): $mutableProjectionType = { + $mutableRowName = row + this + } + + /* Provide immutable access to the last projected row. */ + def currentValue: $rowType = mutableRow + + def apply(i: $rowType): $rowType = { + ..$projectionCode + mutableRow + } + } } + """ + + log.debug(s"code for ${expressions.mkString(",")}:\n$code") + toolBox.eval(code).asInstanceOf[() => MutableProjection] + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala new file mode 100644 index 0000000000000..4211998f7511a --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala @@ -0,0 +1,98 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions.codegen + +import com.typesafe.scalalogging.slf4j.Logging +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.types.{StringType, NumericType} + +/** + * Generates bytecode for an [[Ordering]] of [[Row Rows]] for a given set of + * [[Expression Expressions]]. + */ +object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[Row]] with Logging { + import scala.reflect.runtime.{universe => ru} + import scala.reflect.runtime.universe._ + + protected def canonicalize(in: Seq[SortOrder]): Seq[SortOrder] = + in.map(ExpressionCanonicalizer(_).asInstanceOf[SortOrder]) + + protected def bind(in: Seq[SortOrder], inputSchema: Seq[Attribute]): Seq[SortOrder] = + in.map(BindReferences.bindReference(_, inputSchema)) + + protected def create(ordering: Seq[SortOrder]): Ordering[Row] = { + val a = newTermName("a") + val b = newTermName("b") + val comparisons = ordering.zipWithIndex.map { case (order, i) => + val evalA = expressionEvaluator(order.child) + val evalB = expressionEvaluator(order.child) + + val compare = order.child.dataType match { + case _: NumericType => + q""" + val comp = ${evalA.primitiveTerm} - ${evalB.primitiveTerm} + if(comp != 0) { + return ${if (order.direction == Ascending) q"comp.toInt" else q"-comp.toInt"} + } + """ + case StringType => + if (order.direction == Ascending) { + q"""return ${evalA.primitiveTerm}.compare(${evalB.primitiveTerm})""" + } else { + q"""return ${evalB.primitiveTerm}.compare(${evalA.primitiveTerm})""" + } + } + + q""" + i = $a + ..${evalA.code} + i = $b + ..${evalB.code} + if (${evalA.nullTerm} && ${evalB.nullTerm}) { + // Nothing + } else if (${evalA.nullTerm}) { + return ${if (order.direction == Ascending) q"-1" else q"1"} + } else if (${evalB.nullTerm}) { + return ${if (order.direction == Ascending) q"1" else q"-1"} + } else { + $compare + } + """ + } + + val q"class $orderingName extends $orderingType { ..$body }" = reify { + class SpecificOrdering extends Ordering[Row] { + val o = ordering + } + }.tree.children.head + + val code = q""" + class $orderingName extends $orderingType { + ..$body + def compare(a: $rowType, b: $rowType): Int = { + var i: $rowType = null // Holds current row being evaluated. + ..$comparisons + return 0 + } + } + new $orderingName() + """ + logger.debug(s"Generated Ordering: $code") + toolBox.eval(code).asInstanceOf[Ordering[Row]] + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala new file mode 100644 index 0000000000000..2a0935c790cf3 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions.codegen + +import org.apache.spark.sql.catalyst.expressions._ + +/** + * Generates bytecode that evaluates a boolean [[Expression]] on a given input [[Row]]. + */ +object GeneratePredicate extends CodeGenerator[Expression, (Row) => Boolean] { + import scala.reflect.runtime.{universe => ru} + import scala.reflect.runtime.universe._ + + protected def canonicalize(in: Expression): Expression = ExpressionCanonicalizer(in) + + protected def bind(in: Expression, inputSchema: Seq[Attribute]): Expression = + BindReferences.bindReference(in, inputSchema) + + protected def create(predicate: Expression): ((Row) => Boolean) = { + val cEval = expressionEvaluator(predicate) + + val code = + q""" + (i: $rowType) => { + ..${cEval.code} + if (${cEval.nullTerm}) false else ${cEval.primitiveTerm} + } + """ + + log.debug(s"Generated predicate '$predicate':\n$code") + toolBox.eval(code).asInstanceOf[Row => Boolean] + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala new file mode 100644 index 0000000000000..77fa02c13de30 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala @@ -0,0 +1,219 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions.codegen + +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.types._ + + +/** + * Generates bytecode that produces a new [[Row]] object based on a fixed set of input + * [[Expression Expressions]] and a given input [[Row]]. The returned [[Row]] object is custom + * generated based on the output types of the [[Expression]] to avoid boxing of primitive values. + */ +object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { + import scala.reflect.runtime.{universe => ru} + import scala.reflect.runtime.universe._ + + protected def canonicalize(in: Seq[Expression]): Seq[Expression] = + in.map(ExpressionCanonicalizer(_)) + + protected def bind(in: Seq[Expression], inputSchema: Seq[Attribute]): Seq[Expression] = + in.map(BindReferences.bindReference(_, inputSchema)) + + // Make Mutablility optional... + protected def create(expressions: Seq[Expression]): Projection = { + val tupleLength = ru.Literal(Constant(expressions.length)) + val lengthDef = q"final val length = $tupleLength" + + /* TODO: Configurable... + val nullFunctions = + q""" + private final val nullSet = new org.apache.spark.util.collection.BitSet(length) + final def setNullAt(i: Int) = nullSet.set(i) + final def isNullAt(i: Int) = nullSet.get(i) + """ + */ + + val nullFunctions = + q""" + private[this] var nullBits = new Array[Boolean](${expressions.size}) + final def setNullAt(i: Int) = { nullBits(i) = true } + final def isNullAt(i: Int) = nullBits(i) + """.children + + val tupleElements = expressions.zipWithIndex.flatMap { + case (e, i) => + val elementName = newTermName(s"c$i") + val evaluatedExpression = expressionEvaluator(e) + val iLit = ru.Literal(Constant(i)) + + q""" + var ${newTermName(s"c$i")}: ${termForType(e.dataType)} = _ + { + ..${evaluatedExpression.code} + if(${evaluatedExpression.nullTerm}) + setNullAt($iLit) + else + $elementName = ${evaluatedExpression.primitiveTerm} + } + """.children : Seq[Tree] + } + + val iteratorFunction = { + val allColumns = (0 until expressions.size).map { i => + val iLit = ru.Literal(Constant(i)) + q"if(isNullAt($iLit)) { null } else { ${newTermName(s"c$i")} }" + } + q"final def iterator = Iterator[Any](..$allColumns)" + } + + val accessorFailure = q"""scala.sys.error("Invalid ordinal:" + i)""" + val applyFunction = { + val cases = (0 until expressions.size).map { i => + val ordinal = ru.Literal(Constant(i)) + val elementName = newTermName(s"c$i") + val iLit = ru.Literal(Constant(i)) + + q"if(i == $ordinal) { if(isNullAt($i)) return null else return $elementName }" + } + q"final def apply(i: Int): Any = { ..$cases; $accessorFailure }" + } + + val updateFunction = { + val cases = expressions.zipWithIndex.map {case (e, i) => + val ordinal = ru.Literal(Constant(i)) + val elementName = newTermName(s"c$i") + val iLit = ru.Literal(Constant(i)) + + q""" + if(i == $ordinal) { + if(value == null) { + setNullAt(i) + } else { + $elementName = value.asInstanceOf[${termForType(e.dataType)}] + return + } + }""" + } + q"final def update(i: Int, value: Any): Unit = { ..$cases; $accessorFailure }" + } + + val specificAccessorFunctions = NativeType.all.map { dataType => + val ifStatements = expressions.zipWithIndex.flatMap { + case (e, i) if e.dataType == dataType => + val elementName = newTermName(s"c$i") + // TODO: The string of ifs gets pretty inefficient as the row grows in size. + // TODO: Optional null checks? + q"if(i == $i) return $elementName" :: Nil + case _ => Nil + } + + q""" + final def ${accessorForType(dataType)}(i: Int):${termForType(dataType)} = { + ..$ifStatements; + $accessorFailure + }""" + } + + val specificMutatorFunctions = NativeType.all.map { dataType => + val ifStatements = expressions.zipWithIndex.flatMap { + case (e, i) if e.dataType == dataType => + val elementName = newTermName(s"c$i") + // TODO: The string of ifs gets pretty inefficient as the row grows in size. + // TODO: Optional null checks? + q"if(i == $i) { $elementName = value; return }" :: Nil + case _ => Nil + } + + q""" + final def ${mutatorForType(dataType)}(i: Int, value: ${termForType(dataType)}): Unit = { + ..$ifStatements; + $accessorFailure + }""" + } + + val hashValues = expressions.zipWithIndex.map { case (e,i) => + val elementName = newTermName(s"c$i") + val nonNull = e.dataType match { + case BooleanType => q"if ($elementName) 0 else 1" + case ByteType | ShortType | IntegerType => q"$elementName.toInt" + case LongType => q"($elementName ^ ($elementName >>> 32)).toInt" + case FloatType => q"java.lang.Float.floatToIntBits($elementName)" + case DoubleType => + q"{ val b = java.lang.Double.doubleToLongBits($elementName); (b ^ (b >>>32)).toInt }" + case _ => q"$elementName.hashCode" + } + q"if (isNullAt($i)) 0 else $nonNull" + } + + val hashUpdates: Seq[Tree] = hashValues.map(v => q"""result = 37 * result + $v""": Tree) + + val hashCodeFunction = + q""" + override def hashCode(): Int = { + var result: Int = 37 + ..$hashUpdates + result + } + """ + + val columnChecks = (0 until expressions.size).map { i => + val elementName = newTermName(s"c$i") + q"if (this.$elementName != specificType.$elementName) return false" + } + + val equalsFunction = + q""" + override def equals(other: Any): Boolean = other match { + case specificType: SpecificRow => + ..$columnChecks + return true + case other => super.equals(other) + } + """ + + val copyFunction = + q""" + final def copy() = new $genericRowType(this.toArray) + """ + + val classBody = + nullFunctions ++ ( + lengthDef +: + iteratorFunction +: + applyFunction +: + updateFunction +: + equalsFunction +: + hashCodeFunction +: + copyFunction +: + (tupleElements ++ specificAccessorFunctions ++ specificMutatorFunctions)) + + val code = q""" + final class SpecificRow(i: $rowType) extends $mutableRowType { + ..$classBody + } + + new $projectionType { def apply(r: $rowType) = new SpecificRow(r) } + """ + + log.debug( + s"MutableRow, initExprs: ${expressions.mkString(",")} code:\n${toolBox.typeCheck(code)}") + toolBox.eval(code).asInstanceOf[Projection] + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/package.scala new file mode 100644 index 0000000000000..80c7dfd376c96 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/package.scala @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.sql.catalyst.rules +import org.apache.spark.sql.catalyst.util + +/** + * A collection of generators that build custom bytecode at runtime for performing the evaluation + * of catalyst expression. + */ +package object codegen { + + /** + * A lock to protect invoking the scala compiler at runtime, since it is not thread safe in Scala + * 2.10. + */ + protected[codegen] val globalLock = org.apache.spark.sql.catalyst.ScalaReflectionLock + + /** Canonicalizes an expression so those that differ only by names can reuse the same code. */ + object ExpressionCanonicalizer extends rules.RuleExecutor[Expression] { + val batches = + Batch("CleanExpressions", FixedPoint(20), CleanExpressions) :: Nil + + object CleanExpressions extends rules.Rule[Expression] { + def apply(e: Expression): Expression = e transform { + case Alias(c, _) => c + } + } + } + + /** + * :: DeveloperApi :: + * Dumps the bytecode from a class to the screen using javap. + */ + @DeveloperApi + object DumpByteCode { + import scala.sys.process._ + val dumpDirectory = util.getTempFilePath("sparkSqlByteCode") + dumpDirectory.mkdir() + + def apply(obj: Any): Unit = { + val generatedClass = obj.getClass + val classLoader = + generatedClass + .getClassLoader + .asInstanceOf[scala.tools.nsc.interpreter.AbstractFileClassLoader] + val generatedBytes = classLoader.classBytes(generatedClass.getName) + + val packageDir = new java.io.File(dumpDirectory, generatedClass.getPackage.getName) + if (!packageDir.exists()) { packageDir.mkdir() } + + val classFile = + new java.io.File(packageDir, generatedClass.getName.split("\\.").last + ".class") + + val outfile = new java.io.FileOutputStream(classFile) + outfile.write(generatedBytes) + outfile.close() + + println( + s"javap -p -v -classpath ${dumpDirectory.getCanonicalPath} ${generatedClass.getName}".!!) + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala index b6f2451b52e1f..55d95991c5f11 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala @@ -47,4 +47,30 @@ package org.apache.spark.sql.catalyst * ==Evaluation== * The result of expressions can be evaluated using the `Expression.apply(Row)` method. */ -package object expressions +package object expressions { + + /** + * Converts a [[Row]] to another Row given a sequence of expression that define each column of the + * new row. If the schema of the input row is specified, then the given expression will be bound + * to that schema. + */ + abstract class Projection extends (Row => Row) + + /** + * Converts a [[Row]] to another Row given a sequence of expression that define each column of the + * new row. If the schema of the input row is specified, then the given expression will be bound + * to that schema. + * + * In contrast to a normal projection, a MutableProjection reuses the same underlying row object + * each time an input row is added. This significantly reduces the cost of calculating the + * projection, but means that it is not safe to hold on to a reference to a [[Row]] after `next()` + * has been called on the [[Iterator]] that produced it. Instead, the user must call `Row.copy()` + * and hold on to the returned [[Row]] before calling `next()`. + */ + abstract class MutableProjection extends Projection { + def currentValue: Row + + /** Uses the given row to store the output of the projection. */ + def target(row: MutableRow): MutableProjection + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 06b94a98d3cd0..5976b0ddf3e03 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -23,6 +23,9 @@ import org.apache.spark.sql.catalyst.types.BooleanType object InterpretedPredicate { + def apply(expression: Expression, inputSchema: Seq[Attribute]): (Row => Boolean) = + apply(BindReferences.bindReference(expression, inputSchema)) + def apply(expression: Expression): (Row => Boolean) = { (r: Row) => expression.eval(r).asInstanceOf[Boolean] } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/package.scala new file mode 100644 index 0000000000000..3b3e206055cfc --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/package.scala @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +package object catalyst { + /** + * A JVM-global lock that should be used to prevent thread safety issues when using things in + * scala.reflect.*. Note that Scala Reflection API is made thread-safe in 2.11, but not yet for + * 2.10.* builds. See SI-6240 for more details. + */ + protected[catalyst] object ScalaReflectionLock +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala index 026692abe067d..418f8686bfe5c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala @@ -104,6 +104,77 @@ object PhysicalOperation extends PredicateHelper { } } +/** + * Matches a logical aggregation that can be performed on distributed data in two steps. The first + * operates on the data in each partition performing partial aggregation for each group. The second + * occurs after the shuffle and completes the aggregation. + * + * This pattern will only match if all aggregate expressions can be computed partially and will + * return the rewritten aggregation expressions for both phases. + * + * The returned values for this match are as follows: + * - Grouping attributes for the final aggregation. + * - Aggregates for the final aggregation. + * - Grouping expressions for the partial aggregation. + * - Partial aggregate expressions. + * - Input to the aggregation. + */ +object PartialAggregation { + type ReturnType = + (Seq[Attribute], Seq[NamedExpression], Seq[Expression], Seq[NamedExpression], LogicalPlan) + + def unapply(plan: LogicalPlan): Option[ReturnType] = plan match { + case logical.Aggregate(groupingExpressions, aggregateExpressions, child) => + // Collect all aggregate expressions. + val allAggregates = + aggregateExpressions.flatMap(_ collect { case a: AggregateExpression => a}) + // Collect all aggregate expressions that can be computed partially. + val partialAggregates = + aggregateExpressions.flatMap(_ collect { case p: PartialAggregate => p}) + + // Only do partial aggregation if supported by all aggregate expressions. + if (allAggregates.size == partialAggregates.size) { + // Create a map of expressions to their partial evaluations for all aggregate expressions. + val partialEvaluations: Map[Long, SplitEvaluation] = + partialAggregates.map(a => (a.id, a.asPartial)).toMap + + // We need to pass all grouping expressions though so the grouping can happen a second + // time. However some of them might be unnamed so we alias them allowing them to be + // referenced in the second aggregation. + val namedGroupingExpressions: Map[Expression, NamedExpression] = groupingExpressions.map { + case n: NamedExpression => (n, n) + case other => (other, Alias(other, "PartialGroup")()) + }.toMap + + // Replace aggregations with a new expression that computes the result from the already + // computed partial evaluations and grouping values. + val rewrittenAggregateExpressions = aggregateExpressions.map(_.transformUp { + case e: Expression if partialEvaluations.contains(e.id) => + partialEvaluations(e.id).finalEvaluation + case e: Expression if namedGroupingExpressions.contains(e) => + namedGroupingExpressions(e).toAttribute + }).asInstanceOf[Seq[NamedExpression]] + + val partialComputation = + (namedGroupingExpressions.values ++ + partialEvaluations.values.flatMap(_.partialEvaluations)).toSeq + + val namedGroupingAttributes = namedGroupingExpressions.values.map(_.toAttribute).toSeq + + Some( + (namedGroupingAttributes, + rewrittenAggregateExpressions, + groupingExpressions, + partialComputation, + child)) + } else { + None + } + case _ => None + } +} + + /** * A pattern that finds joins with equality conditions that can be evaluated using equi-join. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index ac85f95b52a2f..888cb08e95f06 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -112,7 +112,7 @@ abstract class LeafNode extends LogicalPlan with trees.LeafNode[LogicalPlan] { self: Product => override lazy val statistics: Statistics = - throw new UnsupportedOperationException("default leaf nodes don't have meaningful Statistics") + throw new UnsupportedOperationException(s"LeafNode $nodeName must implement statistics.") // Leaf nodes by definition cannot reference any input attributes. override def references = Set.empty diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/commands.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/commands.scala index a357c6ffb8977..481a5a4f212b2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/commands.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/commands.scala @@ -35,7 +35,7 @@ abstract class Command extends LeafNode { */ case class NativeCommand(cmd: String) extends Command { override def output = - Seq(BoundReference(0, AttributeReference("result", StringType, nullable = false)())) + Seq(AttributeReference("result", StringType, nullable = false)()) } /** @@ -43,7 +43,7 @@ case class NativeCommand(cmd: String) extends Command { */ case class SetCommand(key: Option[String], value: Option[String]) extends Command { override def output = Seq( - BoundReference(1, AttributeReference("", StringType, nullable = false)())) + AttributeReference("", StringType, nullable = false)()) } /** @@ -52,7 +52,7 @@ case class SetCommand(key: Option[String], value: Option[String]) extends Comman */ case class ExplainCommand(plan: LogicalPlan) extends Command { override def output = - Seq(BoundReference(0, AttributeReference("plan", StringType, nullable = false)())) + Seq(AttributeReference("plan", StringType, nullable = false)()) } /** @@ -71,7 +71,7 @@ case class DescribeCommand( isExtended: Boolean) extends Command { override def output = Seq( // Column names are based on Hive. - BoundReference(0, AttributeReference("col_name", StringType, nullable = false)()), - BoundReference(1, AttributeReference("data_type", StringType, nullable = false)()), - BoundReference(2, AttributeReference("comment", StringType, nullable = false)())) + AttributeReference("col_name", StringType, nullable = false)(), + AttributeReference("data_type", StringType, nullable = false)(), + AttributeReference("comment", StringType, nullable = false)()) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala index e32adb76fe146..e300bdbececbd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala @@ -72,7 +72,10 @@ abstract class RuleExecutor[TreeType <: TreeNode[_]] extends Logging { } iteration += 1 if (iteration > batch.strategy.maxIterations) { - logger.info(s"Max iterations ($iteration) reached for batch ${batch.name}") + // Only log if this is a rule that is supposed to run more than once. + if (iteration != 2) { + logger.info(s"Max iterations (${iteration - 1}) reached for batch ${batch.name}") + } continue = false } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala index cd4b5e9c1b529..71808f76d632b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala @@ -23,16 +23,13 @@ import scala.reflect.ClassTag import scala.reflect.runtime.universe.{typeTag, TypeTag, runtimeMirror} import scala.util.parsing.combinator.RegexParsers +import org.apache.spark.sql.catalyst.ScalaReflectionLock import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression} import org.apache.spark.util.Utils /** - * A JVM-global lock that should be used to prevent thread safety issues when using things in - * scala.reflect.*. Note that Scala Reflection API is made thread-safe in 2.11, but not yet for - * 2.10.* builds. See SI-6240 for more details. + * Utility functions for working with DataTypes. */ -protected[catalyst] object ScalaReflectionLock - object DataType extends RegexParsers { protected lazy val primitiveType: Parser[DataType] = "StringType" ^^^ StringType | @@ -99,6 +96,13 @@ abstract class DataType { case object NullType extends DataType +object NativeType { + def all = Seq( + IntegerType, BooleanType, LongType, DoubleType, FloatType, ShortType, ByteType, StringType) + + def unapply(dt: DataType): Boolean = all.contains(dt) +} + trait PrimitiveType extends DataType { override def isPrimitive = true } @@ -149,6 +153,10 @@ abstract class NumericType extends NativeType with PrimitiveType { val numeric: Numeric[JvmType] } +object NumericType { + def unapply(e: Expression): Boolean = e.dataType.isInstanceOf[NumericType] +} + /** Matcher for any expressions that evaluate to [[IntegralType]]s */ object IntegralType { def unapply(a: Expression): Boolean = a match { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala index 58f8c341e6676..999c9fff38d60 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala @@ -29,7 +29,11 @@ import org.apache.spark.sql.catalyst.dsl.expressions._ class ExpressionEvaluationSuite extends FunSuite { test("literals") { - assert((Literal(1) + Literal(1)).eval(null) === 2) + checkEvaluation(Literal(1), 1) + checkEvaluation(Literal(true), true) + checkEvaluation(Literal(0L), 0L) + checkEvaluation(Literal("test"), "test") + checkEvaluation(Literal(1) + Literal(1), 2) } /** @@ -61,10 +65,8 @@ class ExpressionEvaluationSuite extends FunSuite { test("3VL Not") { notTrueTable.foreach { case (v, answer) => - val expr = ! Literal(v, BooleanType) - val result = expr.eval(null) - if (result != answer) - fail(s"$expr should not evaluate to $result, expected: $answer") } + checkEvaluation(!Literal(v, BooleanType), answer) + } } booleanLogicTest("AND", _ && _, @@ -127,6 +129,13 @@ class ExpressionEvaluationSuite extends FunSuite { } } + test("IN") { + checkEvaluation(In(Literal(1), Seq(Literal(1), Literal(2))), true) + checkEvaluation(In(Literal(2), Seq(Literal(1), Literal(2))), true) + checkEvaluation(In(Literal(3), Seq(Literal(1), Literal(2))), false) + checkEvaluation(In(Literal(1), Seq(Literal(1), Literal(2))) && In(Literal(2), Seq(Literal(1), Literal(2))), true) + } + test("LIKE literal Regular Expression") { checkEvaluation(Literal(null, StringType).like("a"), null) checkEvaluation(Literal("a", StringType).like(Literal(null, StringType)), null) @@ -232,21 +241,21 @@ class ExpressionEvaluationSuite extends FunSuite { checkEvaluation(Literal(false) cast IntegerType, 0) checkEvaluation(Cast(Literal(1) cast BooleanType, IntegerType), 1) checkEvaluation(Cast(Literal(0) cast BooleanType, IntegerType), 0) - checkEvaluation("23" cast DoubleType, 23) + checkEvaluation("23" cast DoubleType, 23d) checkEvaluation("23" cast IntegerType, 23) - checkEvaluation("23" cast FloatType, 23) - checkEvaluation("23" cast DecimalType, 23) - checkEvaluation("23" cast ByteType, 23) - checkEvaluation("23" cast ShortType, 23) + checkEvaluation("23" cast FloatType, 23f) + checkEvaluation("23" cast DecimalType, 23: BigDecimal) + checkEvaluation("23" cast ByteType, 23.toByte) + checkEvaluation("23" cast ShortType, 23.toShort) checkEvaluation("2012-12-11" cast DoubleType, null) checkEvaluation(Literal(123) cast IntegerType, 123) - checkEvaluation(Literal(23d) + Cast(true, DoubleType), 24) + checkEvaluation(Literal(23d) + Cast(true, DoubleType), 24d) checkEvaluation(Literal(23) + Cast(true, IntegerType), 24) - checkEvaluation(Literal(23f) + Cast(true, FloatType), 24) - checkEvaluation(Literal(BigDecimal(23)) + Cast(true, DecimalType), 24) - checkEvaluation(Literal(23.toByte) + Cast(true, ByteType), 24) - checkEvaluation(Literal(23.toShort) + Cast(true, ShortType), 24) + checkEvaluation(Literal(23f) + Cast(true, FloatType), 24f) + checkEvaluation(Literal(BigDecimal(23)) + Cast(true, DecimalType), 24: BigDecimal) + checkEvaluation(Literal(23.toByte) + Cast(true, ByteType), 24.toByte) + checkEvaluation(Literal(23.toShort) + Cast(true, ShortType), 24.toShort) intercept[Exception] {evaluate(Literal(1) cast BinaryType, null)} @@ -391,21 +400,21 @@ class ExpressionEvaluationSuite extends FunSuite { val typeMap = MapType(StringType, StringType) val typeArray = ArrayType(StringType) - checkEvaluation(GetItem(BoundReference(3, AttributeReference("c", typeMap)()), + checkEvaluation(GetItem(BoundReference(3, typeMap, true), Literal("aa")), "bb", row) checkEvaluation(GetItem(Literal(null, typeMap), Literal("aa")), null, row) checkEvaluation(GetItem(Literal(null, typeMap), Literal(null, StringType)), null, row) - checkEvaluation(GetItem(BoundReference(3, AttributeReference("c", typeMap)()), + checkEvaluation(GetItem(BoundReference(3, typeMap, true), Literal(null, StringType)), null, row) - checkEvaluation(GetItem(BoundReference(4, AttributeReference("c", typeArray)()), + checkEvaluation(GetItem(BoundReference(4, typeArray, true), Literal(1)), "bb", row) checkEvaluation(GetItem(Literal(null, typeArray), Literal(1)), null, row) checkEvaluation(GetItem(Literal(null, typeArray), Literal(null, IntegerType)), null, row) - checkEvaluation(GetItem(BoundReference(4, AttributeReference("c", typeArray)()), + checkEvaluation(GetItem(BoundReference(4, typeArray, true), Literal(null, IntegerType)), null, row) - checkEvaluation(GetField(BoundReference(2, AttributeReference("c", typeS)()), "a"), "aa", row) + checkEvaluation(GetField(BoundReference(2, typeS, nullable = true), "a"), "aa", row) checkEvaluation(GetField(Literal(null, typeS), "a"), null, row) val typeS_notNullable = StructType( @@ -413,10 +422,8 @@ class ExpressionEvaluationSuite extends FunSuite { :: StructField("b", StringType, nullable = false) :: Nil ) - assert(GetField(BoundReference(2, - AttributeReference("c", typeS)()), "a").nullable === true) - assert(GetField(BoundReference(2, - AttributeReference("c", typeS_notNullable, nullable = false)()), "a").nullable === false) + assert(GetField(BoundReference(2,typeS, nullable = true), "a").nullable === true) + assert(GetField(BoundReference(2, typeS_notNullable, nullable = false), "a").nullable === false) assert(GetField(Literal(null, typeS), "a").nullable === true) assert(GetField(Literal(null, typeS_notNullable), "a").nullable === true) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedEvaluationSuite.scala new file mode 100644 index 0000000000000..245a2e148030c --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedEvaluationSuite.scala @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.expressions.codegen._ + +/** + * Overrides our expression evaluation tests to use code generation for evaluation. + */ +class GeneratedEvaluationSuite extends ExpressionEvaluationSuite { + override def checkEvaluation( + expression: Expression, + expected: Any, + inputRow: Row = EmptyRow): Unit = { + val plan = try { + GenerateMutableProjection(Alias(expression, s"Optimized($expression)")() :: Nil)() + } catch { + case e: Throwable => + val evaluated = GenerateProjection.expressionEvaluator(expression) + fail( + s""" + |Code generation of $expression failed: + |${evaluated.code.mkString("\n")} + |$e + """.stripMargin) + } + + val actual = plan(inputRow).apply(0) + if(actual != expected) { + val input = if(inputRow == EmptyRow) "" else s", input: $inputRow" + fail(s"Incorrect Evaluation: $expression, actual: $actual, expected: $expected$input") + } + } + + + test("multithreaded eval") { + import scala.concurrent._ + import ExecutionContext.Implicits.global + import scala.concurrent.duration._ + + val futures = (1 to 20).map { _ => + future { + GeneratePredicate(EqualTo(Literal(1), Literal(1))) + GenerateProjection(EqualTo(Literal(1), Literal(1)) :: Nil) + GenerateMutableProjection(EqualTo(Literal(1), Literal(1)) :: Nil) + GenerateOrdering(Add(Literal(1), Literal(1)).asc :: Nil) + } + } + + futures.foreach(Await.result(_, 10.seconds)) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedMutableEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedMutableEvaluationSuite.scala new file mode 100644 index 0000000000000..887aabb1d5fb4 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedMutableEvaluationSuite.scala @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.expressions.codegen._ + +/** + * Overrides our expression evaluation tests to use generated code on mutable rows. + */ +class GeneratedMutableEvaluationSuite extends ExpressionEvaluationSuite { + override def checkEvaluation( + expression: Expression, + expected: Any, + inputRow: Row = EmptyRow): Unit = { + lazy val evaluated = GenerateProjection.expressionEvaluator(expression) + + val plan = try { + GenerateProjection(Alias(expression, s"Optimized($expression)")() :: Nil) + } catch { + case e: Throwable => + fail( + s""" + |Code generation of $expression failed: + |${evaluated.code.mkString("\n")} + |$e + """.stripMargin) + } + + val actual = plan(inputRow) + val expectedRow = new GenericRow(Array[Any](expected)) + if (actual.hashCode() != expectedRow.hashCode()) { + fail( + s""" + |Mismatched hashCodes for values: $actual, $expectedRow + |Hash Codes: ${actual.hashCode()} != ${expectedRow.hashCode()} + |${evaluated.code.mkString("\n")} + """.stripMargin) + } + if (actual != expectedRow) { + val input = if(inputRow == EmptyRow) "" else s", input: $inputRow" + fail(s"Incorrect Evaluation: $expression, actual: $actual, expected: $expected$input") + } + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala index 4896f1b955f01..e2ae0d25db1a5 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala @@ -27,9 +27,9 @@ class CombiningLimitsSuite extends PlanTest { object Optimize extends RuleExecutor[LogicalPlan] { val batches = - Batch("Combine Limit", FixedPoint(2), + Batch("Combine Limit", FixedPoint(10), CombineLimits) :: - Batch("Constant Folding", FixedPoint(3), + Batch("Constant Folding", FixedPoint(10), NullPropagation, ConstantFolding, BooleanSimplification) :: Nil diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala index 5d85a0fd4eebb..2d407077be303 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala @@ -24,8 +24,11 @@ import scala.collection.JavaConverters._ object SQLConf { val COMPRESS_CACHED = "spark.sql.inMemoryColumnarStorage.compressed" val AUTO_BROADCASTJOIN_THRESHOLD = "spark.sql.autoBroadcastJoinThreshold" - val SHUFFLE_PARTITIONS = "spark.sql.shuffle.partitions" val DEFAULT_SIZE_IN_BYTES = "spark.sql.defaultSizeInBytes" + val AUTO_CONVERT_JOIN_SIZE = "spark.sql.auto.convert.join.size" + val SHUFFLE_PARTITIONS = "spark.sql.shuffle.partitions" + val JOIN_BROADCAST_TABLES = "spark.sql.join.broadcastTables" + val CODEGEN_ENABLED = "spark.sql.codegen" object Deprecated { val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks" @@ -56,6 +59,18 @@ trait SQLConf { /** Number of partitions to use for shuffle operators. */ private[spark] def numShufflePartitions: Int = get(SHUFFLE_PARTITIONS, "200").toInt + /** + * When set to true, Spark SQL will use the Scala compiler at runtime to generate custom bytecode + * that evaluates expressions found in queries. In general this custom code runs much faster + * than interpreted evaluation, but there are significant start-up costs due to compilation. + * As a result codegen is only benificial when queries run for a long time, or when the same + * expressions are used multiple times. + * + * Defaults to false as this feature is currently experimental. + */ + private[spark] def codegenEnabled: Boolean = + if (get(CODEGEN_ENABLED, "false") == "true") true else false + /** * Upper bound on the sizes (in bytes) of the tables qualified for the auto conversion to * a broadcast value during the physical executions of join operations. Setting this to -1 @@ -111,5 +126,5 @@ trait SQLConf { private[spark] def clear() { settings.clear() } - } + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index c2bdef732372c..e4b6810180994 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -94,7 +94,7 @@ class SQLContext(@transient val sparkContext: SparkContext) * @group userf */ def parquetFile(path: String): SchemaRDD = - new SchemaRDD(this, parquet.ParquetRelation(path, Some(sparkContext.hadoopConfiguration))) + new SchemaRDD(this, parquet.ParquetRelation(path, Some(sparkContext.hadoopConfiguration), this)) /** * Loads a JSON file (one object per line), returning the result as a [[SchemaRDD]]. @@ -160,7 +160,8 @@ class SQLContext(@transient val sparkContext: SparkContext) conf: Configuration = new Configuration()): SchemaRDD = { new SchemaRDD( this, - ParquetRelation.createEmpty(path, ScalaReflection.attributesFor[A], allowExisting, conf)) + ParquetRelation.createEmpty( + path, ScalaReflection.attributesFor[A], allowExisting, conf, this)) } /** @@ -228,12 +229,14 @@ class SQLContext(@transient val sparkContext: SparkContext) val sqlContext: SQLContext = self + def codegenEnabled = self.codegenEnabled + def numPartitions = self.numShufflePartitions val strategies: Seq[Strategy] = CommandStrategy(self) :: TakeOrdered :: - PartialAggregation :: + HashAggregation :: LeftSemiJoin :: HashJoin :: InMemoryScans :: @@ -291,27 +294,30 @@ class SQLContext(@transient val sparkContext: SparkContext) protected[sql] lazy val emptyResult = sparkContext.parallelize(Seq.empty[Row], 1) /** - * Prepares a planned SparkPlan for execution by binding references to specific ordinals, and - * inserting shuffle operations as needed. + * Prepares a planned SparkPlan for execution by inserting shuffle operations as needed. */ @transient protected[sql] val prepareForExecution = new RuleExecutor[SparkPlan] { val batches = - Batch("Add exchange", Once, AddExchange(self)) :: - Batch("Prepare Expressions", Once, new BindReferences[SparkPlan]) :: Nil + Batch("Add exchange", Once, AddExchange(self)) :: Nil } /** + * :: DeveloperApi :: * The primary workflow for executing relational queries using Spark. Designed to allow easy * access to the intermediate phases of query execution for developers. */ + @DeveloperApi protected abstract class QueryExecution { def logical: LogicalPlan lazy val analyzed = analyzer(logical) lazy val optimizedPlan = optimizer(analyzed) // TODO: Don't just pick the first one... - lazy val sparkPlan = planner(optimizedPlan).next() + lazy val sparkPlan = { + SparkPlan.currentContext.set(self) + planner(optimizedPlan).next() + } // executedPlan should not be used to initialize any SparkPlan. It should be // only used for execution. lazy val executedPlan: SparkPlan = prepareForExecution(sparkPlan) @@ -331,6 +337,9 @@ class SQLContext(@transient val sparkContext: SparkContext) |${stringOrError(optimizedPlan)} |== Physical Plan == |${stringOrError(executedPlan)} + |Code Generation: ${executedPlan.codegenEnabled} + |== RDD == + |${stringOrError(toRdd.toDebugString)} """.stripMargin.trim } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala index 806097c917b91..85726bae54911 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala @@ -72,7 +72,7 @@ class JavaSQLContext(val sqlContext: SQLContext) { conf: Configuration = new Configuration()): JavaSchemaRDD = { new JavaSchemaRDD( sqlContext, - ParquetRelation.createEmpty(path, getSchema(beanClass), allowExisting, conf)) + ParquetRelation.createEmpty(path, getSchema(beanClass), allowExisting, conf, sqlContext)) } /** @@ -101,7 +101,7 @@ class JavaSQLContext(val sqlContext: SQLContext) { def parquetFile(path: String): JavaSchemaRDD = new JavaSchemaRDD( sqlContext, - ParquetRelation(path, Some(sqlContext.sparkContext.hadoopConfiguration))) + ParquetRelation(path, Some(sqlContext.sparkContext.hadoopConfiguration), sqlContext)) /** * Loads a JSON file (one object per line), returning the result as a [[JavaSchemaRDD]]. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala index c1ced8bfa404a..463a1d32d7fd7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala @@ -42,8 +42,8 @@ case class Aggregate( partial: Boolean, groupingExpressions: Seq[Expression], aggregateExpressions: Seq[NamedExpression], - child: SparkPlan)(@transient sqlContext: SQLContext) - extends UnaryNode with NoBind { + child: SparkPlan) + extends UnaryNode { override def requiredChildDistribution = if (partial) { @@ -56,8 +56,6 @@ case class Aggregate( } } - override def otherCopyArgs = sqlContext :: Nil - // HACK: Generators don't correctly preserve their output through serializations so we grab // out child's output attributes statically here. private[this] val childOutput = child.output @@ -138,7 +136,7 @@ case class Aggregate( i += 1 } } - val resultProjection = new Projection(resultExpressions, computedSchema) + val resultProjection = new InterpretedProjection(resultExpressions, computedSchema) val aggregateResults = new GenericMutableRow(computedAggregates.length) var i = 0 @@ -152,7 +150,7 @@ case class Aggregate( } else { child.execute().mapPartitions { iter => val hashTable = new HashMap[Row, Array[AggregateFunction]] - val groupingProjection = new MutableProjection(groupingExpressions, childOutput) + val groupingProjection = new InterpretedMutableProjection(groupingExpressions, childOutput) var currentRow: Row = null while (iter.hasNext) { @@ -175,7 +173,8 @@ case class Aggregate( private[this] val hashTableIter = hashTable.entrySet().iterator() private[this] val aggregateResults = new GenericMutableRow(computedAggregates.length) private[this] val resultProjection = - new MutableProjection(resultExpressions, computedSchema ++ namedGroups.map(_._2)) + new InterpretedMutableProjection( + resultExpressions, computedSchema ++ namedGroups.map(_._2)) private[this] val joinedRow = new JoinedRow override final def hasNext: Boolean = hashTableIter.hasNext diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala index 00010ef6e798a..392a7f3be3904 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala @@ -22,7 +22,7 @@ import org.apache.spark.{HashPartitioner, RangePartitioner, SparkConf} import org.apache.spark.rdd.ShuffledRDD import org.apache.spark.sql.{SQLContext, Row} import org.apache.spark.sql.catalyst.errors.attachTree -import org.apache.spark.sql.catalyst.expressions.{NoBind, MutableProjection, RowOrdering} +import org.apache.spark.sql.catalyst.expressions.RowOrdering import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.util.MutablePair @@ -31,7 +31,7 @@ import org.apache.spark.util.MutablePair * :: DeveloperApi :: */ @DeveloperApi -case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends UnaryNode with NoBind { +case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends UnaryNode { override def outputPartitioning = newPartitioning @@ -42,7 +42,9 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una case HashPartitioning(expressions, numPartitions) => // TODO: Eliminate redundant expressions in grouping key and value. val rdd = child.execute().mapPartitions { iter => - val hashExpressions = new MutableProjection(expressions, child.output) + @transient val hashExpressions = + newMutableProjection(expressions, child.output)() + val mutablePair = new MutablePair[Row, Row]() iter.map(r => mutablePair.update(hashExpressions(r), r)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala index 47b3d00262dbb..c386fd121c5de 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala @@ -47,23 +47,26 @@ case class Generate( } } - override def output = + // This must be a val since the generator output expr ids are not preserved by serialization. + override val output = if (join) child.output ++ generatorOutput else generatorOutput + val boundGenerator = BindReferences.bindReference(generator, child.output) + override def execute() = { if (join) { child.execute().mapPartitions { iter => val nullValues = Seq.fill(generator.output.size)(Literal(null)) // Used to produce rows with no matches when outer = true. val outerProjection = - new Projection(child.output ++ nullValues, child.output) + newProjection(child.output ++ nullValues, child.output) val joinProjection = - new Projection(child.output ++ generator.output, child.output ++ generator.output) + newProjection(child.output ++ generator.output, child.output ++ generator.output) val joinedRow = new JoinedRow iter.flatMap {row => - val outputRows = generator.eval(row) + val outputRows = boundGenerator.eval(row) if (outer && outputRows.isEmpty) { outerProjection(row) :: Nil } else { @@ -72,7 +75,7 @@ case class Generate( } } } else { - child.execute().mapPartitions(iter => iter.flatMap(row => generator.eval(row))) + child.execute().mapPartitions(iter => iter.flatMap(row => boundGenerator.eval(row))) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala new file mode 100644 index 0000000000000..4a26934c49c93 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala @@ -0,0 +1,200 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.physical._ +import org.apache.spark.sql.catalyst.types._ + +case class AggregateEvaluation( + schema: Seq[Attribute], + initialValues: Seq[Expression], + update: Seq[Expression], + result: Expression) + +/** + * :: DeveloperApi :: + * Alternate version of aggregation that leverages projection and thus code generation. + * Aggregations are converted into a set of projections from a aggregation buffer tuple back onto + * itself. Currently only used for simple aggregations like SUM, COUNT, or AVERAGE are supported. + * + * @param partial if true then aggregation is done partially on local data without shuffling to + * ensure all values where `groupingExpressions` are equal are present. + * @param groupingExpressions expressions that are evaluated to determine grouping. + * @param aggregateExpressions expressions that are computed for each group. + * @param child the input data source. + */ +@DeveloperApi +case class GeneratedAggregate( + partial: Boolean, + groupingExpressions: Seq[Expression], + aggregateExpressions: Seq[NamedExpression], + child: SparkPlan) + extends UnaryNode { + + override def requiredChildDistribution = + if (partial) { + UnspecifiedDistribution :: Nil + } else { + if (groupingExpressions == Nil) { + AllTuples :: Nil + } else { + ClusteredDistribution(groupingExpressions) :: Nil + } + } + + override def output = aggregateExpressions.map(_.toAttribute) + + override def execute() = { + val aggregatesToCompute = aggregateExpressions.flatMap { a => + a.collect { case agg: AggregateExpression => agg} + } + + val computeFunctions = aggregatesToCompute.map { + case c @ Count(expr) => + val currentCount = AttributeReference("currentCount", LongType, nullable = false)() + val initialValue = Literal(0L) + val updateFunction = If(IsNotNull(expr), Add(currentCount, Literal(1L)), currentCount) + val result = currentCount + + AggregateEvaluation(currentCount :: Nil, initialValue :: Nil, updateFunction :: Nil, result) + + case Sum(expr) => + val currentSum = AttributeReference("currentSum", expr.dataType, nullable = false)() + val initialValue = Cast(Literal(0L), expr.dataType) + + // Coalasce avoids double calculation... + // but really, common sub expression elimination would be better.... + val updateFunction = Coalesce(Add(expr, currentSum) :: currentSum :: Nil) + val result = currentSum + + AggregateEvaluation(currentSum :: Nil, initialValue :: Nil, updateFunction :: Nil, result) + + case a @ Average(expr) => + val currentCount = AttributeReference("currentCount", LongType, nullable = false)() + val currentSum = AttributeReference("currentSum", expr.dataType, nullable = false)() + val initialCount = Literal(0L) + val initialSum = Cast(Literal(0L), expr.dataType) + val updateCount = If(IsNotNull(expr), Add(currentCount, Literal(1L)), currentCount) + val updateSum = Coalesce(Add(expr, currentSum) :: currentSum :: Nil) + + val result = Divide(Cast(currentSum, DoubleType), Cast(currentCount, DoubleType)) + + AggregateEvaluation( + currentCount :: currentSum :: Nil, + initialCount :: initialSum :: Nil, + updateCount :: updateSum :: Nil, + result + ) + } + + val computationSchema = computeFunctions.flatMap(_.schema) + + val resultMap: Map[Long, Expression] = aggregatesToCompute.zip(computeFunctions).map { + case (agg, func) => agg.id -> func.result + }.toMap + + val namedGroups = groupingExpressions.zipWithIndex.map { + case (ne: NamedExpression, _) => (ne, ne) + case (e, i) => (e, Alias(e, s"GroupingExpr$i")()) + } + + val groupMap: Map[Expression, Attribute] = + namedGroups.map { case (k, v) => k -> v.toAttribute}.toMap + + // The set of expressions that produce the final output given the aggregation buffer and the + // grouping expressions. + val resultExpressions = aggregateExpressions.map(_.transform { + case e: Expression if resultMap.contains(e.id) => resultMap(e.id) + case e: Expression if groupMap.contains(e) => groupMap(e) + }) + + child.execute().mapPartitions { iter => + // Builds a new custom class for holding the results of aggregation for a group. + val initialValues = computeFunctions.flatMap(_.initialValues) + val newAggregationBuffer = newProjection(initialValues, child.output) + log.info(s"Initial values: ${initialValues.mkString(",")}") + + // A projection that computes the group given an input tuple. + val groupProjection = newProjection(groupingExpressions, child.output) + log.info(s"Grouping Projection: ${groupingExpressions.mkString(",")}") + + // A projection that is used to update the aggregate values for a group given a new tuple. + // This projection should be targeted at the current values for the group and then applied + // to a joined row of the current values with the new input row. + val updateExpressions = computeFunctions.flatMap(_.update) + val updateSchema = computeFunctions.flatMap(_.schema) ++ child.output + val updateProjection = newMutableProjection(updateExpressions, updateSchema)() + log.info(s"Update Expressions: ${updateExpressions.mkString(",")}") + + // A projection that produces the final result, given a computation. + val resultProjectionBuilder = + newMutableProjection( + resultExpressions, + (namedGroups.map(_._2.toAttribute) ++ computationSchema).toSeq) + log.info(s"Result Projection: ${resultExpressions.mkString(",")}") + + val joinedRow = new JoinedRow + + if (groupingExpressions.isEmpty) { + // TODO: Codegening anything other than the updateProjection is probably over kill. + val buffer = newAggregationBuffer(EmptyRow).asInstanceOf[MutableRow] + var currentRow: Row = null + updateProjection.target(buffer) + + while (iter.hasNext) { + currentRow = iter.next() + updateProjection(joinedRow(buffer, currentRow)) + } + + val resultProjection = resultProjectionBuilder() + Iterator(resultProjection(buffer)) + } else { + val buffers = new java.util.HashMap[Row, MutableRow]() + + var currentRow: Row = null + while (iter.hasNext) { + currentRow = iter.next() + val currentGroup = groupProjection(currentRow) + var currentBuffer = buffers.get(currentGroup) + if (currentBuffer == null) { + currentBuffer = newAggregationBuffer(EmptyRow).asInstanceOf[MutableRow] + buffers.put(currentGroup, currentBuffer) + } + // Target the projection at the current aggregation buffer and then project the updated + // values. + updateProjection.target(currentBuffer)(joinedRow(currentBuffer, currentRow)) + } + + new Iterator[Row] { + private[this] val resultIterator = buffers.entrySet.iterator() + private[this] val resultProjection = resultProjectionBuilder() + + def hasNext = resultIterator.hasNext + + def next() = { + val currentGroup = resultIterator.next() + resultProjection(joinedRow(currentGroup.getKey, currentGroup.getValue)) + } + } + } + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index 77c874d0315ee..21cbbc9772a00 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -18,22 +18,55 @@ package org.apache.spark.sql.execution import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.Logging import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{Logging, Row, SQLContext} + + +import org.apache.spark.sql.SQLContext import org.apache.spark.sql.catalyst.trees import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation -import org.apache.spark.sql.catalyst.expressions.GenericRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.plans.physical._ + +object SparkPlan { + protected[sql] val currentContext = new ThreadLocal[SQLContext]() +} + /** * :: DeveloperApi :: */ @DeveloperApi -abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging { +abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializable { self: Product => + /** + * A handle to the SQL Context that was used to create this plan. Since many operators need + * access to the sqlContext for RDD operations or configuration this field is automatically + * populated by the query planning infrastructure. + */ + @transient + protected val sqlContext = SparkPlan.currentContext.get() + + protected def sparkContext = sqlContext.sparkContext + + // sqlContext will be null when we are being deserialized on the slaves. In this instance + // the value of codegenEnabled will be set by the desserializer after the constructor has run. + val codegenEnabled: Boolean = if (sqlContext != null) { + sqlContext.codegenEnabled + } else { + false + } + + /** Overridden make copy also propogates sqlContext to copied plan. */ + override def makeCopy(newArgs: Array[AnyRef]): this.type = { + SparkPlan.currentContext.set(sqlContext) + super.makeCopy(newArgs) + } + // TODO: Move to `DistributedPlan` /** Specifies how data is partitioned across different nodes in the cluster. */ def outputPartitioning: Partitioning = UnknownPartitioning(0) // TODO: WRONG WIDTH! @@ -51,8 +84,46 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging { */ def executeCollect(): Array[Row] = execute().map(_.copy()).collect() - protected def buildRow(values: Seq[Any]): Row = - new GenericRow(values.toArray) + protected def newProjection( + expressions: Seq[Expression], inputSchema: Seq[Attribute]): Projection = { + log.debug( + s"Creating Projection: $expressions, inputSchema: $inputSchema, codegen:$codegenEnabled") + if (codegenEnabled) { + GenerateProjection(expressions, inputSchema) + } else { + new InterpretedProjection(expressions, inputSchema) + } + } + + protected def newMutableProjection( + expressions: Seq[Expression], + inputSchema: Seq[Attribute]): () => MutableProjection = { + log.debug( + s"Creating MutableProj: $expressions, inputSchema: $inputSchema, codegen:$codegenEnabled") + if(codegenEnabled) { + GenerateMutableProjection(expressions, inputSchema) + } else { + () => new InterpretedMutableProjection(expressions, inputSchema) + } + } + + + protected def newPredicate( + expression: Expression, inputSchema: Seq[Attribute]): (Row) => Boolean = { + if (codegenEnabled) { + GeneratePredicate(expression, inputSchema) + } else { + InterpretedPredicate(expression, inputSchema) + } + } + + protected def newOrdering(order: Seq[SortOrder], inputSchema: Seq[Attribute]): Ordering[Row] = { + if (codegenEnabled) { + GenerateOrdering(order, inputSchema) + } else { + new RowOrdering(order, inputSchema) + } + } } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 404d48ae05b45..5f1fe99f75c9d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql.execution -import scala.util.Try - import org.apache.spark.sql.{SQLContext, execution} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.planning._ @@ -41,7 +39,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { // no predicate can be evaluated by matching hash keys case logical.Join(left, right, LeftSemi, condition) => execution.LeftSemiJoinBNL( - planLater(left), planLater(right), condition)(sqlContext) :: Nil + planLater(left), planLater(right), condition) :: Nil case _ => Nil } } @@ -60,6 +58,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { * will instead be used to decide the build side in a [[execution.ShuffledHashJoin]]. */ object HashJoin extends Strategy with PredicateHelper { + private[this] def makeBroadcastHashJoin( leftKeys: Seq[Expression], rightKeys: Seq[Expression], @@ -68,24 +67,24 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { condition: Option[Expression], side: BuildSide) = { val broadcastHashJoin = execution.BroadcastHashJoin( - leftKeys, rightKeys, side, planLater(left), planLater(right))(sqlContext) + leftKeys, rightKeys, side, planLater(left), planLater(right)) condition.map(Filter(_, broadcastHashJoin)).getOrElse(broadcastHashJoin) :: Nil } def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right) - if Try(sqlContext.autoBroadcastJoinThreshold > 0 && - right.statistics.sizeInBytes <= sqlContext.autoBroadcastJoinThreshold).getOrElse(false) => + if sqlContext.autoBroadcastJoinThreshold > 0 && + right.statistics.sizeInBytes <= sqlContext.autoBroadcastJoinThreshold => makeBroadcastHashJoin(leftKeys, rightKeys, left, right, condition, BuildRight) case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right) - if Try(sqlContext.autoBroadcastJoinThreshold > 0 && - left.statistics.sizeInBytes <= sqlContext.autoBroadcastJoinThreshold).getOrElse(false) => + if sqlContext.autoBroadcastJoinThreshold > 0 && + left.statistics.sizeInBytes <= sqlContext.autoBroadcastJoinThreshold => makeBroadcastHashJoin(leftKeys, rightKeys, left, right, condition, BuildLeft) case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right) => val buildSide = - if (Try(right.statistics.sizeInBytes <= left.statistics.sizeInBytes).getOrElse(false)) { + if (right.statistics.sizeInBytes <= left.statistics.sizeInBytes) { BuildRight } else { BuildLeft @@ -99,65 +98,65 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { } } - object PartialAggregation extends Strategy { + object HashAggregation extends Strategy { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { - case logical.Aggregate(groupingExpressions, aggregateExpressions, child) => - // Collect all aggregate expressions. - val allAggregates = - aggregateExpressions.flatMap(_ collect { case a: AggregateExpression => a }) - // Collect all aggregate expressions that can be computed partially. - val partialAggregates = - aggregateExpressions.flatMap(_ collect { case p: PartialAggregate => p }) - - // Only do partial aggregation if supported by all aggregate expressions. - if (allAggregates.size == partialAggregates.size) { - // Create a map of expressions to their partial evaluations for all aggregate expressions. - val partialEvaluations: Map[Long, SplitEvaluation] = - partialAggregates.map(a => (a.id, a.asPartial)).toMap - - // We need to pass all grouping expressions though so the grouping can happen a second - // time. However some of them might be unnamed so we alias them allowing them to be - // referenced in the second aggregation. - val namedGroupingExpressions: Map[Expression, NamedExpression] = groupingExpressions.map { - case n: NamedExpression => (n, n) - case other => (other, Alias(other, "PartialGroup")()) - }.toMap + // Aggregations that can be performed in two phases, before and after the shuffle. - // Replace aggregations with a new expression that computes the result from the already - // computed partial evaluations and grouping values. - val rewrittenAggregateExpressions = aggregateExpressions.map(_.transformUp { - case e: Expression if partialEvaluations.contains(e.id) => - partialEvaluations(e.id).finalEvaluation - case e: Expression if namedGroupingExpressions.contains(e) => - namedGroupingExpressions(e).toAttribute - }).asInstanceOf[Seq[NamedExpression]] - - val partialComputation = - (namedGroupingExpressions.values ++ - partialEvaluations.values.flatMap(_.partialEvaluations)).toSeq - - // Construct two phased aggregation. - execution.Aggregate( + // Cases where all aggregates can be codegened. + case PartialAggregation( + namedGroupingAttributes, + rewrittenAggregateExpressions, + groupingExpressions, + partialComputation, + child) + if canBeCodeGened( + allAggregates(partialComputation) ++ + allAggregates(rewrittenAggregateExpressions)) && + codegenEnabled => + execution.GeneratedAggregate( partial = false, - namedGroupingExpressions.values.map(_.toAttribute).toSeq, + namedGroupingAttributes, rewrittenAggregateExpressions, - execution.Aggregate( + execution.GeneratedAggregate( partial = true, groupingExpressions, partialComputation, - planLater(child))(sqlContext))(sqlContext) :: Nil - } else { - Nil - } + planLater(child))) :: Nil + + // Cases where some aggregate can not be codegened + case PartialAggregation( + namedGroupingAttributes, + rewrittenAggregateExpressions, + groupingExpressions, + partialComputation, + child) => + execution.Aggregate( + partial = false, + namedGroupingAttributes, + rewrittenAggregateExpressions, + execution.Aggregate( + partial = true, + groupingExpressions, + partialComputation, + planLater(child))) :: Nil + case _ => Nil } + + def canBeCodeGened(aggs: Seq[AggregateExpression]) = !aggs.exists { + case _: Sum | _: Count => false + case _ => true + } + + def allAggregates(exprs: Seq[Expression]) = + exprs.flatMap(_.collect { case a: AggregateExpression => a }) } object BroadcastNestedLoopJoin extends Strategy { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case logical.Join(left, right, joinType, condition) => execution.BroadcastNestedLoopJoin( - planLater(left), planLater(right), joinType, condition)(sqlContext) :: Nil + planLater(left), planLater(right), joinType, condition) :: Nil case _ => Nil } } @@ -176,16 +175,10 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { protected lazy val singleRowRdd = sparkContext.parallelize(Seq(new GenericRow(Array[Any]()): Row), 1) - def convertToCatalyst(a: Any): Any = a match { - case s: Seq[Any] => s.map(convertToCatalyst) - case p: Product => new GenericRow(p.productIterator.map(convertToCatalyst).toArray) - case other => other - } - object TakeOrdered extends Strategy { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case logical.Limit(IntegerLiteral(limit), logical.Sort(order, child)) => - execution.TakeOrdered(limit, order, planLater(child))(sqlContext) :: Nil + execution.TakeOrdered(limit, order, planLater(child)) :: Nil case _ => Nil } } @@ -195,11 +188,11 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { // TODO: need to support writing to other types of files. Unify the below code paths. case logical.WriteToFile(path, child) => val relation = - ParquetRelation.create(path, child, sparkContext.hadoopConfiguration) + ParquetRelation.create(path, child, sparkContext.hadoopConfiguration, sqlContext) // Note: overwrite=false because otherwise the metadata we just created will be deleted - InsertIntoParquetTable(relation, planLater(child), overwrite=false)(sqlContext) :: Nil + InsertIntoParquetTable(relation, planLater(child), overwrite = false) :: Nil case logical.InsertIntoTable(table: ParquetRelation, partition, child, overwrite) => - InsertIntoParquetTable(table, planLater(child), overwrite)(sqlContext) :: Nil + InsertIntoParquetTable(table, planLater(child), overwrite) :: Nil case PhysicalOperation(projectList, filters: Seq[Expression], relation: ParquetRelation) => val prunePushedDownFilters = if (sparkContext.conf.getBoolean(ParquetFilters.PARQUET_FILTER_PUSHDOWN_ENABLED, true)) { @@ -228,7 +221,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { projectList, filters, prunePushedDownFilters, - ParquetTableScan(_, relation, filters)(sqlContext)) :: Nil + ParquetTableScan(_, relation, filters)) :: Nil case _ => Nil } @@ -266,20 +259,19 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case logical.Filter(condition, child) => execution.Filter(condition, planLater(child)) :: Nil case logical.Aggregate(group, agg, child) => - execution.Aggregate(partial = false, group, agg, planLater(child))(sqlContext) :: Nil + execution.Aggregate(partial = false, group, agg, planLater(child)) :: Nil case logical.Sample(fraction, withReplacement, seed, child) => execution.Sample(fraction, withReplacement, seed, planLater(child)) :: Nil case logical.LocalRelation(output, data) => - val dataAsRdd = - sparkContext.parallelize(data.map(r => - new GenericRow(r.productIterator.map(convertToCatalyst).toArray): Row)) - execution.ExistingRdd(output, dataAsRdd) :: Nil + ExistingRdd( + output, + ExistingRdd.productToRowRdd(sparkContext.parallelize(data, numPartitions))) :: Nil case logical.Limit(IntegerLiteral(limit), child) => - execution.Limit(limit, planLater(child))(sqlContext) :: Nil + execution.Limit(limit, planLater(child)) :: Nil case Unions(unionChildren) => - execution.Union(unionChildren.map(planLater))(sqlContext) :: Nil - case logical.Except(left,right) => - execution.Except(planLater(left),planLater(right)) :: Nil + execution.Union(unionChildren.map(planLater)) :: Nil + case logical.Except(left, right) => + execution.Except(planLater(left), planLater(right)) :: Nil case logical.Intersect(left, right) => execution.Intersect(planLater(left), planLater(right)) :: Nil case logical.Generate(generator, join, outer, _, child) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index 966d8f95fc83c..174eda8f1a72c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -37,9 +37,11 @@ import org.apache.spark.util.MutablePair case class Project(projectList: Seq[NamedExpression], child: SparkPlan) extends UnaryNode { override def output = projectList.map(_.toAttribute) - override def execute() = child.execute().mapPartitions { iter => - @transient val reusableProjection = new MutableProjection(projectList) - iter.map(reusableProjection) + @transient lazy val buildProjection = newMutableProjection(projectList, child.output) + + def execute() = child.execute().mapPartitions { iter => + val resuableProjection = buildProjection() + iter.map(resuableProjection) } } @@ -50,8 +52,10 @@ case class Project(projectList: Seq[NamedExpression], child: SparkPlan) extends case class Filter(condition: Expression, child: SparkPlan) extends UnaryNode { override def output = child.output - override def execute() = child.execute().mapPartitions { iter => - iter.filter(condition.eval(_).asInstanceOf[Boolean]) + @transient lazy val conditionEvaluator = newPredicate(condition, child.output) + + def execute() = child.execute().mapPartitions { iter => + iter.filter(conditionEvaluator) } } @@ -72,12 +76,10 @@ case class Sample(fraction: Double, withReplacement: Boolean, seed: Long, child: * :: DeveloperApi :: */ @DeveloperApi -case class Union(children: Seq[SparkPlan])(@transient sqlContext: SQLContext) extends SparkPlan { +case class Union(children: Seq[SparkPlan]) extends SparkPlan { // TODO: attributes output by union should be distinct for nullability purposes override def output = children.head.output - override def execute() = sqlContext.sparkContext.union(children.map(_.execute())) - - override def otherCopyArgs = sqlContext :: Nil + override def execute() = sparkContext.union(children.map(_.execute())) } /** @@ -89,13 +91,11 @@ case class Union(children: Seq[SparkPlan])(@transient sqlContext: SQLContext) ex * repartition all the data to a single partition to compute the global limit. */ @DeveloperApi -case class Limit(limit: Int, child: SparkPlan)(@transient sqlContext: SQLContext) +case class Limit(limit: Int, child: SparkPlan) extends UnaryNode { // TODO: Implement a partition local limit, and use a strategy to generate the proper limit plan: // partition local limit -> exchange into one partition -> partition local limit again - override def otherCopyArgs = sqlContext :: Nil - override def output = child.output /** @@ -161,20 +161,18 @@ case class Limit(limit: Int, child: SparkPlan)(@transient sqlContext: SQLContext * Spark's top operator does the opposite in ordering so we name it TakeOrdered to avoid confusion. */ @DeveloperApi -case class TakeOrdered(limit: Int, sortOrder: Seq[SortOrder], child: SparkPlan) - (@transient sqlContext: SQLContext) extends UnaryNode { - override def otherCopyArgs = sqlContext :: Nil +case class TakeOrdered(limit: Int, sortOrder: Seq[SortOrder], child: SparkPlan) extends UnaryNode { override def output = child.output - @transient - lazy val ordering = new RowOrdering(sortOrder) + val ordering = new RowOrdering(sortOrder, child.output) + // TODO: Is this copying for no reason? override def executeCollect() = child.execute().map(_.copy()).takeOrdered(limit)(ordering) // TODO: Terminal split should be implemented differently from non-terminal split. // TODO: Pick num splits based on |limit|. - override def execute() = sqlContext.sparkContext.makeRDD(executeCollect(), 1) + override def execute() = sparkContext.makeRDD(executeCollect(), 1) } /** @@ -189,15 +187,13 @@ case class Sort( override def requiredChildDistribution = if (global) OrderedDistribution(sortOrder) :: Nil else UnspecifiedDistribution :: Nil - @transient - lazy val ordering = new RowOrdering(sortOrder) override def execute() = attachTree(this, "sort") { - // TODO: Optimize sorting operation? child.execute() - .mapPartitions( - iterator => iterator.map(_.copy()).toArray.sorted(ordering).iterator, - preservesPartitioning = true) + .mapPartitions( { iterator => + val ordering = newOrdering(sortOrder, child.output) + iterator.map(_.copy()).toArray.sorted(ordering).iterator + }, preservesPartitioning = true) } override def output = child.output diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala index c6fbd6d2f6930..5ef46c32d44bc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala @@ -41,13 +41,13 @@ package object debug { */ @DeveloperApi implicit class DebugQuery(query: SchemaRDD) { - def debug(implicit sc: SparkContext): Unit = { + def debug(): Unit = { val plan = query.queryExecution.executedPlan val visited = new collection.mutable.HashSet[Long]() val debugPlan = plan transform { case s: SparkPlan if !visited.contains(s.id) => visited += s.id - DebugNode(sc, s) + DebugNode(s) } println(s"Results returned: ${debugPlan.execute().count()}") debugPlan.foreach { @@ -57,9 +57,7 @@ package object debug { } } - private[sql] case class DebugNode( - @transient sparkContext: SparkContext, - child: SparkPlan) extends UnaryNode { + private[sql] case class DebugNode(child: SparkPlan) extends UnaryNode { def references = Set.empty def output = child.output diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala index 7d1f11caae838..2750ddbce896f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala @@ -38,6 +38,8 @@ case object BuildLeft extends BuildSide case object BuildRight extends BuildSide trait HashJoin { + self: SparkPlan => + val leftKeys: Seq[Expression] val rightKeys: Seq[Expression] val buildSide: BuildSide @@ -56,9 +58,9 @@ trait HashJoin { def output = left.output ++ right.output - @transient lazy val buildSideKeyGenerator = new Projection(buildKeys, buildPlan.output) + @transient lazy val buildSideKeyGenerator = newProjection(buildKeys, buildPlan.output) @transient lazy val streamSideKeyGenerator = - () => new MutableProjection(streamedKeys, streamedPlan.output) + newMutableProjection(streamedKeys, streamedPlan.output) def joinIterators(buildIter: Iterator[Row], streamIter: Iterator[Row]): Iterator[Row] = { // TODO: Use Spark's HashMap implementation. @@ -217,9 +219,8 @@ case class BroadcastHashJoin( rightKeys: Seq[Expression], buildSide: BuildSide, left: SparkPlan, - right: SparkPlan)(@transient sqlContext: SQLContext) extends BinaryNode with HashJoin { + right: SparkPlan) extends BinaryNode with HashJoin { - override def otherCopyArgs = sqlContext :: Nil override def outputPartitioning: Partitioning = left.outputPartitioning @@ -228,7 +229,7 @@ case class BroadcastHashJoin( @transient lazy val broadcastFuture = future { - sqlContext.sparkContext.broadcast(buildPlan.executeCollect()) + sparkContext.broadcast(buildPlan.executeCollect()) } def execute() = { @@ -248,14 +249,11 @@ case class BroadcastHashJoin( @DeveloperApi case class LeftSemiJoinBNL( streamed: SparkPlan, broadcast: SparkPlan, condition: Option[Expression]) - (@transient sqlContext: SQLContext) extends BinaryNode { // TODO: Override requiredChildDistribution. override def outputPartitioning: Partitioning = streamed.outputPartitioning - override def otherCopyArgs = sqlContext :: Nil - def output = left.output /** The Streamed Relation */ @@ -271,7 +269,7 @@ case class LeftSemiJoinBNL( def execute() = { val broadcastedRelation = - sqlContext.sparkContext.broadcast(broadcast.execute().map(_.copy()).collect().toIndexedSeq) + sparkContext.broadcast(broadcast.execute().map(_.copy()).collect().toIndexedSeq) streamed.execute().mapPartitions { streamedIter => val joinedRow = new JoinedRow @@ -300,8 +298,14 @@ case class LeftSemiJoinBNL( case class CartesianProduct(left: SparkPlan, right: SparkPlan) extends BinaryNode { def output = left.output ++ right.output - def execute() = left.execute().map(_.copy()).cartesian(right.execute().map(_.copy())).map { - case (l: Row, r: Row) => buildRow(l ++ r) + def execute() = { + val leftResults = left.execute().map(_.copy()) + val rightResults = right.execute().map(_.copy()) + + leftResults.cartesian(rightResults).mapPartitions { iter => + val joinedRow = new JoinedRow + iter.map(r => joinedRow(r._1, r._2)) + } } } @@ -311,14 +315,11 @@ case class CartesianProduct(left: SparkPlan, right: SparkPlan) extends BinaryNod @DeveloperApi case class BroadcastNestedLoopJoin( streamed: SparkPlan, broadcast: SparkPlan, joinType: JoinType, condition: Option[Expression]) - (@transient sqlContext: SQLContext) extends BinaryNode { // TODO: Override requiredChildDistribution. override def outputPartitioning: Partitioning = streamed.outputPartitioning - override def otherCopyArgs = sqlContext :: Nil - override def output = { joinType match { case LeftOuter => @@ -345,13 +346,14 @@ case class BroadcastNestedLoopJoin( def execute() = { val broadcastedRelation = - sqlContext.sparkContext.broadcast(broadcast.execute().map(_.copy()).collect().toIndexedSeq) + sparkContext.broadcast(broadcast.execute().map(_.copy()).collect().toIndexedSeq) val streamedPlusMatches = streamed.execute().mapPartitions { streamedIter => val matchedRows = new ArrayBuffer[Row] // TODO: Use Spark's BitSet. val includedBroadcastTuples = new BitSet(broadcastedRelation.value.size) val joinedRow = new JoinedRow + val rightNulls = new GenericMutableRow(right.output.size) streamedIter.foreach { streamedRow => var i = 0 @@ -361,7 +363,7 @@ case class BroadcastNestedLoopJoin( // TODO: One bitset per partition instead of per row. val broadcastedRow = broadcastedRelation.value(i) if (boundCondition(joinedRow(streamedRow, broadcastedRow))) { - matchedRows += buildRow(streamedRow ++ broadcastedRow) + matchedRows += joinedRow(streamedRow, broadcastedRow).copy() matched = true includedBroadcastTuples += i } @@ -369,7 +371,7 @@ case class BroadcastNestedLoopJoin( } if (!matched && (joinType == LeftOuter || joinType == FullOuter)) { - matchedRows += buildRow(streamedRow ++ Array.fill(right.output.size)(null)) + matchedRows += joinedRow(streamedRow, rightNulls).copy() } } Iterator((matchedRows, includedBroadcastTuples)) @@ -383,20 +385,20 @@ case class BroadcastNestedLoopJoin( streamedPlusMatches.map(_._2).reduce(_ ++ _) } + val leftNulls = new GenericMutableRow(left.output.size) val rightOuterMatches: Seq[Row] = if (joinType == RightOuter || joinType == FullOuter) { broadcastedRelation.value.zipWithIndex.filter { case (row, i) => !allIncludedBroadcastTuples.contains(i) }.map { - // TODO: Use projection. - case (row, _) => buildRow(Vector.fill(left.output.size)(null) ++ row) + case (row, _) => new JoinedRow(leftNulls, row) } } else { Vector() } // TODO: Breaks lineage. - sqlContext.sparkContext.union( - streamedPlusMatches.flatMap(_._1), sqlContext.sparkContext.makeRDD(rightOuterMatches)) + sparkContext.union( + streamedPlusMatches.flatMap(_._1), sparkContext.makeRDD(rightOuterMatches)) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala index 8c7dbd5eb4a09..b3bae5db0edbc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala @@ -46,7 +46,8 @@ import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, LeafNode} */ private[sql] case class ParquetRelation( path: String, - @transient conf: Option[Configuration] = None) + @transient conf: Option[Configuration], + @transient sqlContext: SQLContext) extends LeafNode with MultiInstanceRelation { self: Product => @@ -61,7 +62,7 @@ private[sql] case class ParquetRelation( /** Attributes */ override val output = ParquetTypesConverter.readSchemaFromFile(new Path(path), conf) - override def newInstance = ParquetRelation(path).asInstanceOf[this.type] + override def newInstance = ParquetRelation(path, conf, sqlContext).asInstanceOf[this.type] // Equals must also take into account the output attributes so that we can distinguish between // different instances of the same relation, @@ -70,6 +71,9 @@ private[sql] case class ParquetRelation( p.path == path && p.output == output case _ => false } + + // TODO: Use data from the footers. + override lazy val statistics = Statistics(sizeInBytes = sqlContext.defaultSizeInBytes) } private[sql] object ParquetRelation { @@ -106,13 +110,14 @@ private[sql] object ParquetRelation { */ def create(pathString: String, child: LogicalPlan, - conf: Configuration): ParquetRelation = { + conf: Configuration, + sqlContext: SQLContext): ParquetRelation = { if (!child.resolved) { throw new UnresolvedException[LogicalPlan]( child, "Attempt to create Parquet table from unresolved child (when schema is not available)") } - createEmpty(pathString, child.output, false, conf) + createEmpty(pathString, child.output, false, conf, sqlContext) } /** @@ -127,14 +132,15 @@ private[sql] object ParquetRelation { def createEmpty(pathString: String, attributes: Seq[Attribute], allowExisting: Boolean, - conf: Configuration): ParquetRelation = { + conf: Configuration, + sqlContext: SQLContext): ParquetRelation = { val path = checkPath(pathString, allowExisting, conf) if (conf.get(ParquetOutputFormat.COMPRESSION) == null) { conf.set(ParquetOutputFormat.COMPRESSION, ParquetRelation.defaultCompression.name()) } ParquetRelation.enableLogForwarding() ParquetTypesConverter.writeMetaData(attributes, path, conf) - new ParquetRelation(path.toString, Some(conf)) { + new ParquetRelation(path.toString, Some(conf), sqlContext) { override val output = attributes } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala index ea74320d06c86..912a9f002b7d1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala @@ -55,8 +55,7 @@ case class ParquetTableScan( // https://issues.apache.org/jira/browse/SPARK-1367 output: Seq[Attribute], relation: ParquetRelation, - columnPruningPred: Seq[Expression])( - @transient val sqlContext: SQLContext) + columnPruningPred: Seq[Expression]) extends LeafNode { override def execute(): RDD[Row] = { @@ -99,8 +98,6 @@ case class ParquetTableScan( .filter(_ != null) // Parquet's record filters may produce null values } - override def otherCopyArgs = sqlContext :: Nil - /** * Applies a (candidate) projection. * @@ -110,7 +107,7 @@ case class ParquetTableScan( def pruneColumns(prunedAttributes: Seq[Attribute]): ParquetTableScan = { val success = validateProjection(prunedAttributes) if (success) { - ParquetTableScan(prunedAttributes, relation, columnPruningPred)(sqlContext) + ParquetTableScan(prunedAttributes, relation, columnPruningPred) } else { sys.error("Warning: Could not validate Parquet schema projection in pruneColumns") this @@ -150,8 +147,7 @@ case class ParquetTableScan( case class InsertIntoParquetTable( relation: ParquetRelation, child: SparkPlan, - overwrite: Boolean = false)( - @transient val sqlContext: SQLContext) + overwrite: Boolean = false) extends UnaryNode with SparkHadoopMapReduceUtil { /** @@ -171,7 +167,7 @@ case class InsertIntoParquetTable( val writeSupport = if (child.output.map(_.dataType).forall(_.isPrimitive)) { - logger.debug("Initializing MutableRowWriteSupport") + log.debug("Initializing MutableRowWriteSupport") classOf[org.apache.spark.sql.parquet.MutableRowWriteSupport] } else { classOf[org.apache.spark.sql.parquet.RowWriteSupport] @@ -203,8 +199,6 @@ case class InsertIntoParquetTable( override def output = child.output - override def otherCopyArgs = sqlContext :: Nil - /** * Stores the given Row RDD as a Hadoop file. * diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTestData.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTestData.scala index d4599da711254..837ea7695dbb3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTestData.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTestData.scala @@ -22,6 +22,7 @@ import java.io.File import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileStatus, FileSystem, Path} import org.apache.hadoop.mapreduce.Job +import org.apache.spark.sql.test.TestSQLContext import parquet.example.data.{GroupWriter, Group} import parquet.example.data.simple.SimpleGroup @@ -103,7 +104,7 @@ private[sql] object ParquetTestData { val testDir = Utils.createTempDir() val testFilterDir = Utils.createTempDir() - lazy val testData = new ParquetRelation(testDir.toURI.toString) + lazy val testData = new ParquetRelation(testDir.toURI.toString, None, TestSQLContext) val testNestedSchema1 = // based on blogpost example, source: @@ -202,8 +203,10 @@ private[sql] object ParquetTestData { val testNestedDir3 = Utils.createTempDir() val testNestedDir4 = Utils.createTempDir() - lazy val testNestedData1 = new ParquetRelation(testNestedDir1.toURI.toString) - lazy val testNestedData2 = new ParquetRelation(testNestedDir2.toURI.toString) + lazy val testNestedData1 = + new ParquetRelation(testNestedDir1.toURI.toString, None, TestSQLContext) + lazy val testNestedData2 = + new ParquetRelation(testNestedDir2.toURI.toString, None, TestSQLContext) def writeFile() = { testDir.delete() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala index 8e1e1971d968b..1fd8d27b34c59 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -45,6 +45,7 @@ class QueryTest extends PlanTest { |${rdd.queryExecution} |== Exception == |$e + |${org.apache.spark.sql.catalyst.util.stackTraceToString(e)} """.stripMargin) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index 215618e852eb2..76b1724471442 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -39,22 +39,22 @@ class PlannerSuite extends FunSuite { test("count is partially aggregated") { val query = testData.groupBy('value)(Count('key)).queryExecution.analyzed - val planned = PartialAggregation(query).head - val aggregations = planned.collect { case a: Aggregate => a } + val planned = HashAggregation(query).head + val aggregations = planned.collect { case n if n.nodeName contains "Aggregate" => n } assert(aggregations.size === 2) } test("count distinct is not partially aggregated") { val query = testData.groupBy('value)(CountDistinct('key :: Nil)).queryExecution.analyzed - val planned = PartialAggregation(query) + val planned = HashAggregation(query) assert(planned.isEmpty) } test("mixed aggregates are not partially aggregated") { val query = testData.groupBy('value)(Count('value), CountDistinct('key :: Nil)).queryExecution.analyzed - val planned = PartialAggregation(query) + val planned = HashAggregation(query) assert(planned.isEmpty) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/TgfSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/TgfSuite.scala index e55648b8ed15a..2cab5e0c44d92 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/TgfSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/TgfSuite.scala @@ -33,7 +33,7 @@ import org.apache.spark.sql.test.TestSQLContext._ * Note: this is only a rough example of how TGFs can be expressed, the final version will likely * involve a lot more sugar for cleaner use in Scala/Java/etc. */ -case class ExampleTGF(input: Seq[Attribute] = Seq('name, 'age)) extends Generator { +case class ExampleTGF(input: Seq[Expression] = Seq('name, 'age)) extends Generator { def children = input protected def makeOutput() = 'nameAndAge.string :: Nil diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala index 3c911e9a4e7b1..561f5b4a49965 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala @@ -25,6 +25,7 @@ import parquet.schema.MessageTypeParser import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.hadoop.mapreduce.Job + import org.apache.spark.SparkContext import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.{SqlLexical, SqlParser} @@ -32,6 +33,7 @@ import org.apache.spark.sql.catalyst.analysis.{Star, UnresolvedAttribute} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.types.{BooleanType, IntegerType} import org.apache.spark.sql.catalyst.util.getTempFilePath +import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.test.TestSQLContext import org.apache.spark.sql.test.TestSQLContext._ import org.apache.spark.util.Utils @@ -207,10 +209,11 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA } test("Projection of simple Parquet file") { + SparkPlan.currentContext.set(TestSQLContext) val scanner = new ParquetTableScan( ParquetTestData.testData.output, ParquetTestData.testData, - Seq())(TestSQLContext) + Seq()) val projected = scanner.pruneColumns(ParquetTypesConverter .convertToAttributes(MessageTypeParser .parseMessageType(ParquetTestData.subTestSchema))) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index 84d43eaeea51d..f0a61270daf05 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -231,7 +231,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { HiveTableScans, DataSinks, Scripts, - PartialAggregation, + HashAggregation, LeftSemiJoin, HashJoin, BasicOperators, diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala index c2b0b00aa5852..39033bdeac4b0 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala @@ -131,7 +131,7 @@ case class InsertIntoHiveTable( conf, SparkHiveHadoopWriter.createPathFromString(fileSinkConf.getDirName, conf)) - logger.debug("Saving as hadoop file of type " + valueClass.getSimpleName) + log.debug("Saving as hadoop file of type " + valueClass.getSimpleName) val writer = new SparkHiveHadoopWriter(conf, fileSinkConf) writer.preSetup() diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala index 8258ee5fef0eb..0c8f676e9c5c8 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala @@ -67,7 +67,7 @@ case class ScriptTransformation( } } readerThread.start() - val outputProjection = new Projection(input) + val outputProjection = new InterpretedProjection(input, child.output) iter .map(outputProjection) // TODO: Use SerDe diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala index 057eb60a02612..7582b4743d404 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala @@ -251,8 +251,10 @@ private[hive] case class HiveGenericUdtf( @transient protected lazy val function: GenericUDTF = createFunction() + @transient protected lazy val inputInspectors = children.map(_.dataType).map(toInspector) + @transient protected lazy val outputInspectors = { val structInspector = function.initialize(inputInspectors.toArray) structInspector.getAllStructFieldRefs.map(_.getFieldObjectInspector) @@ -278,7 +280,7 @@ private[hive] case class HiveGenericUdtf( override def eval(input: Row): TraversableOnce[Row] = { outputInspectors // Make sure initialized. - val inputProjection = new Projection(children) + val inputProjection = new InterpretedProjection(children) val collector = new UDTFCollector function.setCollector(collector) @@ -332,7 +334,7 @@ private[hive] case class HiveUdafFunction( override def eval(input: Row): Any = unwrapData(function.evaluate(buffer), returnInspector) @transient - val inputProjection = new Projection(exprs) + val inputProjection = new InterpretedProjection(exprs) def update(input: Row): Unit = { val inputs = inputProjection(input).asInstanceOf[Seq[AnyRef]].toArray diff --git a/sql/hive/src/test/resources/golden/case else null-0-8ef2f741400830ef889a9dd0c817fe3d b/sql/hive/src/test/resources/golden/case else null-0-8ef2f741400830ef889a9dd0c817fe3d new file mode 100644 index 0000000000000..00750edc07d64 --- /dev/null +++ b/sql/hive/src/test/resources/golden/case else null-0-8ef2f741400830ef889a9dd0c817fe3d @@ -0,0 +1 @@ +3 diff --git a/sql/hive/src/test/resources/golden/double case-0-f513687d17dcb18546fefa75000a52f2 b/sql/hive/src/test/resources/golden/double case-0-f513687d17dcb18546fefa75000a52f2 new file mode 100644 index 0000000000000..00750edc07d64 --- /dev/null +++ b/sql/hive/src/test/resources/golden/double case-0-f513687d17dcb18546fefa75000a52f2 @@ -0,0 +1 @@ +3 diff --git a/sql/hive/src/test/resources/golden/single case-0-c264e319c52f1840a32959d552b99e73 b/sql/hive/src/test/resources/golden/single case-0-c264e319c52f1840a32959d552b99e73 new file mode 100644 index 0000000000000..d00491fd7e5bb --- /dev/null +++ b/sql/hive/src/test/resources/golden/single case-0-c264e319c52f1840a32959d552b99e73 @@ -0,0 +1 @@ +1 diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index aadfd2e900151..89cc589fb8001 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql.hive.execution import scala.util.Try +import org.apache.spark.sql.{SchemaRDD, Row} +import org.apache.spark.sql.hive._ import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.hive.test.TestHive._ import org.apache.spark.sql.{Row, SchemaRDD} @@ -30,6 +32,15 @@ case class TestData(a: Int, b: String) */ class HiveQuerySuite extends HiveComparisonTest { + createQueryTest("single case", + """SELECT case when true then 1 else 2 end FROM src LIMIT 1""") + + createQueryTest("double case", + """SELECT case when 1 = 2 then 1 when 2 = 2 then 3 else 2 end FROM src LIMIT 1""") + + createQueryTest("case else null", + """SELECT case when 1 = 2 then 1 when 2 = 2 then 3 else null end FROM src LIMIT 1""") + createQueryTest("having no references", "SELECT key FROM src GROUP BY key HAVING COUNT(*) > 1") From 2e6efcacea19bddbdae1d655ef54186f2e52747f Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Tue, 29 Jul 2014 22:16:20 -0700 Subject: [PATCH 022/170] [SPARK-2568] RangePartitioner should run only one job if data is balanced As of Spark 1.0, RangePartitioner goes through data twice: once to compute the count and once to do sampling. As a result, to do sortByKey, Spark goes through data 3 times (once to count, once to sample, and once to sort). `RangePartitioner` should go through data only once, collecting samples from input partitions as well as counting. If the data is balanced, this should give us a good sketch. If we see big partitions, we re-sample from them in order to collect enough items. The downside is that we need to collect more from each partition in the first pass. An alternative solution is caching the intermediate result and decide whether to fetch the data after. Author: Xiangrui Meng Author: Reynold Xin Closes #1562 from mengxr/range-partitioner and squashes the following commits: 6cc2551 [Xiangrui Meng] change foreach to for eb39b08 [Xiangrui Meng] Merge branch 'master' into range-partitioner eb95dd8 [Xiangrui Meng] separate sketching and determining bounds impl c436d30 [Xiangrui Meng] fix binary metrics unit tests db58a55 [Xiangrui Meng] add unit tests a6e35d6 [Xiangrui Meng] minor update 60be09e [Xiangrui Meng] remove importance sampler 9ee9992 [Xiangrui Meng] update range partitioner to run only one job on roughly balanced data cc12f47 [Xiangrui Meng] Merge remote-tracking branch 'apache/master' into range-part 06ac2ec [Xiangrui Meng] Merge remote-tracking branch 'apache/master' into range-part 17bcbf3 [Reynold Xin] Added seed. badf20d [Reynold Xin] Renamed the method. 6940010 [Reynold Xin] Reservoir sampling implementation. --- .../scala/org/apache/spark/Partitioner.scala | 121 +++++++++++++++--- .../org/apache/spark/PartitioningSuite.scala | 64 ++++++++- .../scala/org/apache/spark/rdd/RDDSuite.scala | 5 + 3 files changed, 171 insertions(+), 19 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/Partitioner.scala b/core/src/main/scala/org/apache/spark/Partitioner.scala index 52c018baa5f7b..37053bb6f37ad 100644 --- a/core/src/main/scala/org/apache/spark/Partitioner.scala +++ b/core/src/main/scala/org/apache/spark/Partitioner.scala @@ -19,11 +19,15 @@ package org.apache.spark import java.io.{IOException, ObjectInputStream, ObjectOutputStream} -import scala.reflect.ClassTag +import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer +import scala.reflect.{ClassTag, classTag} +import scala.util.hashing.byteswap32 -import org.apache.spark.rdd.RDD +import org.apache.spark.rdd.{PartitionPruningRDD, RDD} import org.apache.spark.serializer.JavaSerializer import org.apache.spark.util.{CollectionsUtils, Utils} +import org.apache.spark.util.random.{XORShiftRandom, SamplingUtils} /** * An object that defines how the elements in a key-value pair RDD are partitioned by key. @@ -103,26 +107,49 @@ class RangePartitioner[K : Ordering : ClassTag, V]( private var ascending: Boolean = true) extends Partitioner { + // We allow partitions = 0, which happens when sorting an empty RDD under the default settings. + require(partitions >= 0, s"Number of partitions cannot be negative but found $partitions.") + private var ordering = implicitly[Ordering[K]] // An array of upper bounds for the first (partitions - 1) partitions private var rangeBounds: Array[K] = { - if (partitions == 1) { - Array() + if (partitions <= 1) { + Array.empty } else { - val rddSize = rdd.count() - val maxSampleSize = partitions * 20.0 - val frac = math.min(maxSampleSize / math.max(rddSize, 1), 1.0) - val rddSample = rdd.sample(false, frac, 1).map(_._1).collect().sorted - if (rddSample.length == 0) { - Array() + // This is the sample size we need to have roughly balanced output partitions, capped at 1M. + val sampleSize = math.min(20.0 * partitions, 1e6) + // Assume the input partitions are roughly balanced and over-sample a little bit. + val sampleSizePerPartition = math.ceil(3.0 * sampleSize / rdd.partitions.size).toInt + val (numItems, sketched) = RangePartitioner.sketch(rdd.map(_._1), sampleSizePerPartition) + if (numItems == 0L) { + Array.empty } else { - val bounds = new Array[K](partitions - 1) - for (i <- 0 until partitions - 1) { - val index = (rddSample.length - 1) * (i + 1) / partitions - bounds(i) = rddSample(index) + // If a partition contains much more than the average number of items, we re-sample from it + // to ensure that enough items are collected from that partition. + val fraction = math.min(sampleSize / math.max(numItems, 1L), 1.0) + val candidates = ArrayBuffer.empty[(K, Float)] + val imbalancedPartitions = mutable.Set.empty[Int] + sketched.foreach { case (idx, n, sample) => + if (fraction * n > sampleSizePerPartition) { + imbalancedPartitions += idx + } else { + // The weight is 1 over the sampling probability. + val weight = (n.toDouble / sample.size).toFloat + for (key <- sample) { + candidates += ((key, weight)) + } + } + } + if (imbalancedPartitions.nonEmpty) { + // Re-sample imbalanced partitions with the desired sampling probability. + val imbalanced = new PartitionPruningRDD(rdd.map(_._1), imbalancedPartitions.contains) + val seed = byteswap32(-rdd.id - 1) + val reSampled = imbalanced.sample(withReplacement = false, fraction, seed).collect() + val weight = (1.0 / fraction).toFloat + candidates ++= reSampled.map(x => (x, weight)) } - bounds + RangePartitioner.determineBounds(candidates, partitions) } } } @@ -212,3 +239,67 @@ class RangePartitioner[K : Ordering : ClassTag, V]( } } } + +private[spark] object RangePartitioner { + + /** + * Sketches the input RDD via reservoir sampling on each partition. + * + * @param rdd the input RDD to sketch + * @param sampleSizePerPartition max sample size per partition + * @return (total number of items, an array of (partitionId, number of items, sample)) + */ + def sketch[K:ClassTag]( + rdd: RDD[K], + sampleSizePerPartition: Int): (Long, Array[(Int, Int, Array[K])]) = { + val shift = rdd.id + // val classTagK = classTag[K] // to avoid serializing the entire partitioner object + val sketched = rdd.mapPartitionsWithIndex { (idx, iter) => + val seed = byteswap32(idx ^ (shift << 16)) + val (sample, n) = SamplingUtils.reservoirSampleAndCount( + iter, sampleSizePerPartition, seed) + Iterator((idx, n, sample)) + }.collect() + val numItems = sketched.map(_._2.toLong).sum + (numItems, sketched) + } + + /** + * Determines the bounds for range partitioning from candidates with weights indicating how many + * items each represents. Usually this is 1 over the probability used to sample this candidate. + * + * @param candidates unordered candidates with weights + * @param partitions number of partitions + * @return selected bounds + */ + def determineBounds[K:Ordering:ClassTag]( + candidates: ArrayBuffer[(K, Float)], + partitions: Int): Array[K] = { + val ordering = implicitly[Ordering[K]] + val ordered = candidates.sortBy(_._1) + val numCandidates = ordered.size + val sumWeights = ordered.map(_._2.toDouble).sum + val step = sumWeights / partitions + var cumWeight = 0.0 + var target = step + val bounds = ArrayBuffer.empty[K] + var i = 0 + var j = 0 + var previousBound = Option.empty[K] + while ((i < numCandidates) && (j < partitions - 1)) { + val (key, weight) = ordered(i) + cumWeight += weight + if (cumWeight > target) { + // Skip duplicate values. + if (previousBound.isEmpty || ordering.gt(key, previousBound.get)) { + bounds += key + target += step + j += 1 + previousBound = Some(key) + } + } + i += 1 + } + bounds.toArray + } +} diff --git a/core/src/test/scala/org/apache/spark/PartitioningSuite.scala b/core/src/test/scala/org/apache/spark/PartitioningSuite.scala index 4658a08064280..fc0cee3e8749d 100644 --- a/core/src/test/scala/org/apache/spark/PartitioningSuite.scala +++ b/core/src/test/scala/org/apache/spark/PartitioningSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark +import scala.collection.mutable.ArrayBuffer import scala.math.abs import org.scalatest.{FunSuite, PrivateMethodTester} @@ -52,14 +53,12 @@ class PartitioningSuite extends FunSuite with SharedSparkContext with PrivateMet assert(p2 === p2) assert(p4 === p4) - assert(p2 != p4) - assert(p4 != p2) + assert(p2 === p4) assert(p4 === anotherP4) assert(anotherP4 === p4) assert(descendingP2 === descendingP2) assert(descendingP4 === descendingP4) - assert(descendingP2 != descendingP4) - assert(descendingP4 != descendingP2) + assert(descendingP2 === descendingP4) assert(p2 != descendingP2) assert(p4 != descendingP4) assert(descendingP2 != p2) @@ -102,6 +101,63 @@ class PartitioningSuite extends FunSuite with SharedSparkContext with PrivateMet partitioner.getPartition(Row(100)) } + test("RangPartitioner.sketch") { + val rdd = sc.makeRDD(0 until 20, 20).flatMap { i => + val random = new java.util.Random(i) + Iterator.fill(i)(random.nextDouble()) + }.cache() + val sampleSizePerPartition = 10 + val (count, sketched) = RangePartitioner.sketch(rdd, sampleSizePerPartition) + assert(count === rdd.count()) + sketched.foreach { case (idx, n, sample) => + assert(n === idx) + assert(sample.size === math.min(n, sampleSizePerPartition)) + } + } + + test("RangePartitioner.determineBounds") { + assert(RangePartitioner.determineBounds(ArrayBuffer.empty[(Int, Float)], 10).isEmpty, + "Bounds on an empty candidates set should be empty.") + val candidates = ArrayBuffer( + (0.7, 2.0f), (0.1, 1.0f), (0.4, 1.0f), (0.3, 1.0f), (0.2, 1.0f), (0.5, 1.0f), (1.0, 3.0f)) + assert(RangePartitioner.determineBounds(candidates, 3) === Array(0.4, 0.7)) + } + + test("RangePartitioner should run only one job if data is roughly balanced") { + val rdd = sc.makeRDD(0 until 20, 20).flatMap { i => + val random = new java.util.Random(i) + Iterator.fill(5000 * i)((random.nextDouble() + i, i)) + }.cache() + for (numPartitions <- Seq(10, 20, 40)) { + val partitioner = new RangePartitioner(numPartitions, rdd) + assert(partitioner.numPartitions === numPartitions) + val counts = rdd.keys.map(key => partitioner.getPartition(key)).countByValue().values + assert(counts.max < 3.0 * counts.min) + } + } + + test("RangePartitioner should work well on unbalanced data") { + val rdd = sc.makeRDD(0 until 20, 20).flatMap { i => + val random = new java.util.Random(i) + Iterator.fill(20 * i * i * i)((random.nextDouble() + i, i)) + }.cache() + for (numPartitions <- Seq(2, 4, 8)) { + val partitioner = new RangePartitioner(numPartitions, rdd) + assert(partitioner.numPartitions === numPartitions) + val counts = rdd.keys.map(key => partitioner.getPartition(key)).countByValue().values + assert(counts.max < 3.0 * counts.min) + } + } + + test("RangePartitioner should return a single partition for empty RDDs") { + val empty1 = sc.emptyRDD[(Int, Double)] + val partitioner1 = new RangePartitioner(0, empty1) + assert(partitioner1.numPartitions === 1) + val empty2 = sc.makeRDD(0 until 2, 2).flatMap(i => Seq.empty[(Int, Double)]) + val partitioner2 = new RangePartitioner(2, empty2) + assert(partitioner2.numPartitions === 1) + } + test("HashPartitioner not equal to RangePartitioner") { val rdd = sc.parallelize(1 to 10).map(x => (x, x)) val rangeP2 = new RangePartitioner(2, rdd) diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala index 6654ec2d7c656..fdc83bc0a5f8e 100644 --- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala @@ -613,6 +613,11 @@ class RDDSuite extends FunSuite with SharedSparkContext { } } + test("sort an empty RDD") { + val data = sc.emptyRDD[Int] + assert(data.sortBy(x => x).collect() === Array.empty) + } + test("sortByKey") { val data = sc.parallelize(Seq("5|50|A","4|60|C", "6|40|B")) From 077f633b4720422c5efbf0382e869ead3dc49612 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Tue, 29 Jul 2014 22:42:54 -0700 Subject: [PATCH 023/170] [SQL] Handle null values in debug() Author: Michael Armbrust Closes #1646 from marmbrus/nullDebug and squashes the following commits: 49050a8 [Michael Armbrust] Handle null values in debug() --- .../scala/org/apache/spark/sql/execution/debug/package.scala | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala index 5ef46c32d44bc..f31df051824d7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala @@ -105,7 +105,9 @@ package object debug { var i = 0 while (i < numColumns) { val value = currentRow(i) - columnStats(i).elementTypes += HashSet(value.getClass.getName) + if (value != null) { + columnStats(i).elementTypes += HashSet(value.getClass.getName) + } i += 1 } currentRow From 4ce92ccaf761e48a10fc4fe4927dbfca858ca22b Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Tue, 29 Jul 2014 23:52:09 -0700 Subject: [PATCH 024/170] [SPARK-2260] Fix standalone-cluster mode, which was broken The main thing was that spark configs were not propagated to the driver, and so applications that do not specify `master` or `appName` automatically failed. This PR fixes that and a couple of miscellaneous things that are related. One thing that may or may not be an issue is that the jars must be available on the driver node. In `standalone-cluster` mode, this effectively means these jars must be available on all the worker machines, since the driver is launched on one of them. The semantics here are not the same as `yarn-cluster` mode, where all the relevant jars are uploaded to a distributed cache automatically and shipped to the containers. This is probably not a concern, but still worth a mention. Author: Andrew Or Closes #1538 from andrewor14/standalone-cluster and squashes the following commits: 8c11a0d [Andrew Or] Clean up imports / comments (minor) 2678d13 [Andrew Or] Handle extraJavaOpts properly 7660547 [Andrew Or] Merge branch 'master' of github.com:apache/spark into standalone-cluster 6f64a9b [Andrew Or] Revert changes in YARN 2f2908b [Andrew Or] Fix tests ed01491 [Andrew Or] Don't go overboard with escaping 8e105e1 [Andrew Or] Merge branch 'master' of github.com:apache/spark into standalone-cluster b890949 [Andrew Or] Abstract usages of converting spark opts to java opts 79f63a3 [Andrew Or] Move sparkProps into javaOpts 78752f8 [Andrew Or] Fix tests 5a9c6c7 [Andrew Or] Fix line too long c141a00 [Andrew Or] Don't display "unknown app" on driver log pages d7e2728 [Andrew Or] Avoid deprecation warning in standalone Client 6ceb14f [Andrew Or] Allow relevant configs to propagate to standalone Driver 7f854bc [Andrew Or] Fix test 855256e [Andrew Or] Fix standalone-cluster mode fd9da51 [Andrew Or] Formatting changes (minor) --- .../scala/org/apache/spark/SparkConf.scala | 22 ++++++++++++++++++- .../org/apache/spark/deploy/Client.scala | 21 +++++++++--------- .../org/apache/spark/deploy/Command.scala | 2 +- .../org/apache/spark/deploy/SparkSubmit.scala | 12 +++++----- .../spark/deploy/client/TestClient.scala | 6 ++--- .../spark/deploy/worker/CommandUtils.scala | 7 +++--- .../spark/deploy/worker/DriverRunner.scala | 3 ++- .../spark/deploy/worker/ExecutorRunner.scala | 14 +++++++----- .../spark/deploy/worker/ui/LogPage.scala | 11 +++++----- .../CoarseGrainedExecutorBackend.scala | 9 ++++++-- .../cluster/SparkDeploySchedulerBackend.scala | 11 ++++++---- .../scala/org/apache/spark/util/Utils.scala | 9 ++++++++ .../spark/deploy/JsonProtocolSuite.scala | 6 ++--- .../spark/deploy/SparkSubmitSuite.scala | 7 ++++-- .../deploy/worker/DriverRunnerTest.scala | 2 +- .../deploy/worker/ExecutorRunnerTest.scala | 2 +- 16 files changed, 93 insertions(+), 51 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala index 8ce4b91cae8ae..38700847c80f4 100644 --- a/core/src/main/scala/org/apache/spark/SparkConf.scala +++ b/core/src/main/scala/org/apache/spark/SparkConf.scala @@ -40,6 +40,8 @@ import scala.collection.mutable.HashMap */ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging { + import SparkConf._ + /** Create a SparkConf that loads defaults from system properties and the classpath */ def this() = this(true) @@ -198,7 +200,7 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging { * * E.g. spark.akka.option.x.y.x = "value" */ - getAll.filter {case (k, v) => k.startsWith("akka.")} + getAll.filter { case (k, _) => isAkkaConf(k) } /** Does the configuration contain a given parameter? */ def contains(key: String): Boolean = settings.contains(key) @@ -292,3 +294,21 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging { settings.toArray.sorted.map{case (k, v) => k + "=" + v}.mkString("\n") } } + +private[spark] object SparkConf { + /** + * Return whether the given config is an akka config (e.g. akka.actor.provider). + * Note that this does not include spark-specific akka configs (e.g. spark.akka.timeout). + */ + def isAkkaConf(name: String): Boolean = name.startsWith("akka.") + + /** + * Return whether the given config should be passed to an executor on start-up. + * + * Certain akka and authentication configs are required of the executor when it connects to + * the scheduler, while the rest of the spark configs can be inherited from the driver later. + */ + def isExecutorStartupConf(name: String): Boolean = { + isAkkaConf(name) || name.startsWith("spark.akka") || name.startsWith("spark.auth") + } +} diff --git a/core/src/main/scala/org/apache/spark/deploy/Client.scala b/core/src/main/scala/org/apache/spark/deploy/Client.scala index c371dc3a51c73..17c507af2652d 100644 --- a/core/src/main/scala/org/apache/spark/deploy/Client.scala +++ b/core/src/main/scala/org/apache/spark/deploy/Client.scala @@ -17,8 +17,6 @@ package org.apache.spark.deploy -import scala.collection.JavaConversions._ -import scala.collection.mutable.Map import scala.concurrent._ import akka.actor._ @@ -50,9 +48,6 @@ private class ClientActor(driverArgs: ClientArguments, conf: SparkConf) extends // TODO: We could add an env variable here and intercept it in `sc.addJar` that would // truncate filesystem paths similar to what YARN does. For now, we just require // people call `addJar` assuming the jar is in the same directory. - val env = Map[String, String]() - System.getenv().foreach{case (k, v) => env(k) = v} - val mainClass = "org.apache.spark.deploy.worker.DriverWrapper" val classPathConf = "spark.driver.extraClassPath" @@ -65,10 +60,13 @@ private class ClientActor(driverArgs: ClientArguments, conf: SparkConf) extends cp.split(java.io.File.pathSeparator) } - val javaOptionsConf = "spark.driver.extraJavaOptions" - val javaOpts = sys.props.get(javaOptionsConf) + val extraJavaOptsConf = "spark.driver.extraJavaOptions" + val extraJavaOpts = sys.props.get(extraJavaOptsConf) + .map(Utils.splitCommandString).getOrElse(Seq.empty) + val sparkJavaOpts = Utils.sparkJavaOpts(conf) + val javaOpts = sparkJavaOpts ++ extraJavaOpts val command = new Command(mainClass, Seq("{{WORKER_URL}}", driverArgs.mainClass) ++ - driverArgs.driverOptions, env, classPathEntries, libraryPathEntries, javaOpts) + driverArgs.driverOptions, sys.env, classPathEntries, libraryPathEntries, javaOpts) val driverDescription = new DriverDescription( driverArgs.jarUrl, @@ -109,6 +107,7 @@ private class ClientActor(driverArgs: ClientArguments, conf: SparkConf) extends // Exception, if present statusResponse.exception.map { e => println(s"Exception from cluster was: $e") + e.printStackTrace() System.exit(-1) } System.exit(0) @@ -141,8 +140,10 @@ private class ClientActor(driverArgs: ClientArguments, conf: SparkConf) extends */ object Client { def main(args: Array[String]) { - println("WARNING: This client is deprecated and will be removed in a future version of Spark.") - println("Use ./bin/spark-submit with \"--master spark://host:port\"") + if (!sys.props.contains("SPARK_SUBMIT")) { + println("WARNING: This client is deprecated and will be removed in a future version of Spark") + println("Use ./bin/spark-submit with \"--master spark://host:port\"") + } val conf = new SparkConf() val driverArgs = new ClientArguments(args) diff --git a/core/src/main/scala/org/apache/spark/deploy/Command.scala b/core/src/main/scala/org/apache/spark/deploy/Command.scala index 32f3ba385084f..a2b263544c6a2 100644 --- a/core/src/main/scala/org/apache/spark/deploy/Command.scala +++ b/core/src/main/scala/org/apache/spark/deploy/Command.scala @@ -25,5 +25,5 @@ private[spark] case class Command( environment: Map[String, String], classPathEntries: Seq[String], libraryPathEntries: Seq[String], - extraJavaOptions: Option[String] = None) { + javaOpts: Seq[String]) { } diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index c9cec33ebaa66..3df811c4ac5df 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -136,8 +136,6 @@ object SparkSubmit { (clusterManager, deployMode) match { case (MESOS, CLUSTER) => printErrorAndExit("Cluster deploy mode is currently not supported for Mesos clusters.") - case (STANDALONE, CLUSTER) => - printErrorAndExit("Cluster deploy mode is currently not supported for Standalone clusters.") case (_, CLUSTER) if args.isPython => printErrorAndExit("Cluster deploy mode is currently not supported for python applications.") case (_, CLUSTER) if isShell(args.primaryResource) => @@ -170,9 +168,9 @@ object SparkSubmit { val options = List[OptionAssigner]( // All cluster managers - OptionAssigner(args.master, ALL_CLUSTER_MGRS, CLIENT, sysProp = "spark.master"), - OptionAssigner(args.name, ALL_CLUSTER_MGRS, CLIENT, sysProp = "spark.app.name"), - OptionAssigner(args.jars, ALL_CLUSTER_MGRS, CLIENT, sysProp = "spark.jars"), + OptionAssigner(args.master, ALL_CLUSTER_MGRS, ALL_DEPLOY_MODES, sysProp = "spark.master"), + OptionAssigner(args.name, ALL_CLUSTER_MGRS, ALL_DEPLOY_MODES, sysProp = "spark.app.name"), + OptionAssigner(args.jars, ALL_CLUSTER_MGRS, ALL_DEPLOY_MODES, sysProp = "spark.jars"), // Standalone cluster only OptionAssigner(args.driverMemory, STANDALONE, CLUSTER, clOption = "--memory"), @@ -203,9 +201,9 @@ object SparkSubmit { sysProp = "spark.driver.extraJavaOptions"), OptionAssigner(args.driverExtraLibraryPath, STANDALONE | YARN, CLUSTER, sysProp = "spark.driver.extraLibraryPath"), - OptionAssigner(args.executorMemory, STANDALONE | MESOS | YARN, CLIENT, + OptionAssigner(args.executorMemory, STANDALONE | MESOS | YARN, ALL_DEPLOY_MODES, sysProp = "spark.executor.memory"), - OptionAssigner(args.totalExecutorCores, STANDALONE | MESOS, CLIENT, + OptionAssigner(args.totalExecutorCores, STANDALONE | MESOS, ALL_DEPLOY_MODES, sysProp = "spark.cores.max"), OptionAssigner(args.files, LOCAL | STANDALONE | MESOS, ALL_DEPLOY_MODES, sysProp = "spark.files") diff --git a/core/src/main/scala/org/apache/spark/deploy/client/TestClient.scala b/core/src/main/scala/org/apache/spark/deploy/client/TestClient.scala index e15a87bd38fda..b8ffa9afb69cb 100644 --- a/core/src/main/scala/org/apache/spark/deploy/client/TestClient.scala +++ b/core/src/main/scala/org/apache/spark/deploy/client/TestClient.scala @@ -46,11 +46,11 @@ private[spark] object TestClient { def main(args: Array[String]) { val url = args(0) val conf = new SparkConf - val (actorSystem, port) = AkkaUtils.createActorSystem("spark", Utils.localIpAddress, 0, + val (actorSystem, _) = AkkaUtils.createActorSystem("spark", Utils.localIpAddress, 0, conf = conf, securityManager = new SecurityManager(conf)) val desc = new ApplicationDescription( - "TestClient", Some(1), 512, Command("spark.deploy.client.TestExecutor", Seq(), Map(), Seq(), - Seq()), Some("dummy-spark-home"), "ignored") + "TestClient", Some(1), 512, Command("spark.deploy.client.TestExecutor", Seq(), Map(), + Seq(), Seq(), Seq()), Some("dummy-spark-home"), "ignored") val listener = new TestListener val client = new AppClient(actorSystem, Array(url), desc, listener, new SparkConf) client.start() diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/CommandUtils.scala b/core/src/main/scala/org/apache/spark/deploy/worker/CommandUtils.scala index 4af5bc3afad6c..687e492a0d6fc 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/CommandUtils.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/CommandUtils.scala @@ -47,7 +47,6 @@ object CommandUtils extends Logging { */ def buildJavaOpts(command: Command, memory: Int, sparkHome: String): Seq[String] = { val memoryOpts = Seq(s"-Xms${memory}M", s"-Xmx${memory}M") - val extraOpts = command.extraJavaOptions.map(Utils.splitCommandString).getOrElse(Seq()) // Exists for backwards compatibility with older Spark versions val workerLocalOpts = Option(getenv("SPARK_JAVA_OPTS")).map(Utils.splitCommandString) @@ -62,7 +61,7 @@ object CommandUtils extends Logging { val joined = command.libraryPathEntries.mkString(File.pathSeparator) Seq(s"-Djava.library.path=$joined") } else { - Seq() + Seq() } val permGenOpt = Seq("-XX:MaxPermSize=128m") @@ -71,11 +70,11 @@ object CommandUtils extends Logging { val ext = if (System.getProperty("os.name").startsWith("Windows")) ".cmd" else ".sh" val classPath = Utils.executeAndGetOutput( Seq(sparkHome + "/bin/compute-classpath" + ext), - extraEnvironment=command.environment) + extraEnvironment = command.environment) val userClassPath = command.classPathEntries ++ Seq(classPath) Seq("-cp", userClassPath.filterNot(_.isEmpty).mkString(File.pathSeparator)) ++ - permGenOpt ++ libraryOpts ++ extraOpts ++ workerLocalOpts ++ memoryOpts + permGenOpt ++ libraryOpts ++ workerLocalOpts ++ command.javaOpts ++ memoryOpts } /** Spawn a thread that will redirect a given stream to a file */ diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala b/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala index 662d37871e7a6..5caaf6bea3575 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala @@ -36,6 +36,7 @@ import org.apache.spark.deploy.master.DriverState.DriverState /** * Manages the execution of one driver, including automatically restarting the driver on failure. + * This is currently only used in standalone cluster deploy mode. */ private[spark] class DriverRunner( val driverId: String, @@ -81,7 +82,7 @@ private[spark] class DriverRunner( driverDesc.command.environment, classPath, driverDesc.command.libraryPathEntries, - driverDesc.command.extraJavaOptions) + driverDesc.command.javaOpts) val command = CommandUtils.buildCommandSeq(newCommand, driverDesc.mem, sparkHome.getAbsolutePath) launchDriver(command, driverDesc.command.environment, driverDir, driverDesc.supervise) diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala index 467317dd9b44c..7be89f9aff0f3 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala @@ -30,6 +30,7 @@ import org.apache.spark.util.logging.FileAppender /** * Manages the execution of one executor process. + * This is currently only used in standalone mode. */ private[spark] class ExecutorRunner( val appId: String, @@ -72,7 +73,7 @@ private[spark] class ExecutorRunner( } /** - * kill executor process, wait for exit and notify worker to update resource status + * Kill executor process, wait for exit and notify worker to update resource status. * * @param message the exception message which caused the executor's death */ @@ -114,10 +115,13 @@ private[spark] class ExecutorRunner( } def getCommandSeq = { - val command = Command(appDesc.command.mainClass, - appDesc.command.arguments.map(substituteVariables) ++ Seq(appId), appDesc.command.environment, - appDesc.command.classPathEntries, appDesc.command.libraryPathEntries, - appDesc.command.extraJavaOptions) + val command = Command( + appDesc.command.mainClass, + appDesc.command.arguments.map(substituteVariables) ++ Seq(appId), + appDesc.command.environment, + appDesc.command.classPathEntries, + appDesc.command.libraryPathEntries, + appDesc.command.javaOpts) CommandUtils.buildCommandSeq(command, memory, sparkHome.getAbsolutePath) } diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ui/LogPage.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ui/LogPage.scala index b389cb546de6c..ecb358c399819 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ui/LogPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ui/LogPage.scala @@ -17,7 +17,6 @@ package org.apache.spark.deploy.worker.ui -import java.io.File import javax.servlet.http.HttpServletRequest import scala.xml.Node @@ -25,7 +24,7 @@ import scala.xml.Node import org.apache.spark.ui.{WebUIPage, UIUtils} import org.apache.spark.util.Utils import org.apache.spark.Logging -import org.apache.spark.util.logging.{FileAppender, RollingFileAppender} +import org.apache.spark.util.logging.RollingFileAppender private[spark] class LogPage(parent: WorkerWebUI) extends WebUIPage("logPage") with Logging { private val worker = parent.worker @@ -64,11 +63,11 @@ private[spark] class LogPage(parent: WorkerWebUI) extends WebUIPage("logPage") w val offset = Option(request.getParameter("offset")).map(_.toLong) val byteLength = Option(request.getParameter("byteLength")).map(_.toInt).getOrElse(defaultBytes) - val (logDir, params) = (appId, executorId, driverId) match { + val (logDir, params, pageName) = (appId, executorId, driverId) match { case (Some(a), Some(e), None) => - (s"${workDir.getPath}/$a/$e/", s"appId=$a&executorId=$e") + (s"${workDir.getPath}/$a/$e/", s"appId=$a&executorId=$e", s"$a/$e") case (None, None, Some(d)) => - (s"${workDir.getPath}/$d/", s"driverId=$d") + (s"${workDir.getPath}/$d/", s"driverId=$d", d) case _ => throw new Exception("Request must specify either application or driver identifiers") } @@ -120,7 +119,7 @@ private[spark] class LogPage(parent: WorkerWebUI) extends WebUIPage("logPage") w - UIUtils.basicSparkPage(content, logType + " log page for " + appId.getOrElse("unknown app")) + UIUtils.basicSparkPage(content, logType + " log page for " + pageName) } /** Get the part of the log files given the offset and desired length of bytes */ diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala index b455c9fcf4bd6..860b47e056451 100644 --- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala @@ -98,8 +98,13 @@ private[spark] class CoarseGrainedExecutorBackend( } private[spark] object CoarseGrainedExecutorBackend extends Logging { - def run(driverUrl: String, executorId: String, hostname: String, cores: Int, - workerUrl: Option[String]) { + + private def run( + driverUrl: String, + executorId: String, + hostname: String, + cores: Int, + workerUrl: Option[String]) { SignalLogger.register(log) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala index bf2dc88e29048..48aaaa54bdb35 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala @@ -17,7 +17,7 @@ package org.apache.spark.scheduler.cluster -import org.apache.spark.{Logging, SparkContext} +import org.apache.spark.{Logging, SparkConf, SparkContext} import org.apache.spark.deploy.{ApplicationDescription, Command} import org.apache.spark.deploy.client.{AppClient, AppClientListener} import org.apache.spark.scheduler.{ExecutorExited, ExecutorLossReason, SlaveLost, TaskSchedulerImpl} @@ -46,6 +46,7 @@ private[spark] class SparkDeploySchedulerBackend( CoarseGrainedSchedulerBackend.ACTOR_NAME) val args = Seq(driverUrl, "{{EXECUTOR_ID}}", "{{HOSTNAME}}", "{{CORES}}", "{{WORKER_URL}}") val extraJavaOpts = sc.conf.getOption("spark.executor.extraJavaOptions") + .map(Utils.splitCommandString).getOrElse(Seq.empty) val classPathEntries = sc.conf.getOption("spark.executor.extraClassPath").toSeq.flatMap { cp => cp.split(java.io.File.pathSeparator) } @@ -54,9 +55,11 @@ private[spark] class SparkDeploySchedulerBackend( cp.split(java.io.File.pathSeparator) } - val command = Command( - "org.apache.spark.executor.CoarseGrainedExecutorBackend", args, sc.executorEnvs, - classPathEntries, libraryPathEntries, extraJavaOpts) + // Start executors with a few necessary configs for registering with the scheduler + val sparkJavaOpts = Utils.sparkJavaOpts(conf, SparkConf.isExecutorStartupConf) + val javaOpts = sparkJavaOpts ++ extraJavaOpts + val command = Command("org.apache.spark.executor.CoarseGrainedExecutorBackend", + args, sc.executorEnvs, classPathEntries, libraryPathEntries, javaOpts) val sparkHome = sc.getSparkHome() val appDesc = new ApplicationDescription(sc.appName, maxCores, sc.executorMemory, command, sparkHome, sc.ui.appUIAddress, sc.eventLogger.map(_.logDir)) diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 8cbb9050f393b..69f65b4bdccb1 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -1313,4 +1313,13 @@ private[spark] object Utils extends Logging { s"$className: $desc\n$st" } + /** + * Convert all spark properties set in the given SparkConf to a sequence of java options. + */ + def sparkJavaOpts(conf: SparkConf, filterKey: (String => Boolean) = _ => true): Seq[String] = { + conf.getAll + .filter { case (k, _) => filterKey(k) } + .map { case (k, v) => s"-D$k=$v" } + } + } diff --git a/core/src/test/scala/org/apache/spark/deploy/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/deploy/JsonProtocolSuite.scala index 01ab2d549325c..093394ad6d142 100644 --- a/core/src/test/scala/org/apache/spark/deploy/JsonProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/JsonProtocolSuite.scala @@ -88,7 +88,7 @@ class JsonProtocolSuite extends FunSuite { } def createAppDesc(): ApplicationDescription = { - val cmd = new Command("mainClass", List("arg1", "arg2"), Map(), Seq(), Seq()) + val cmd = new Command("mainClass", List("arg1", "arg2"), Map(), Seq(), Seq(), Seq()) new ApplicationDescription("name", Some(4), 1234, cmd, Some("sparkHome"), "appUiUrl") } @@ -101,7 +101,7 @@ class JsonProtocolSuite extends FunSuite { def createDriverCommand() = new Command( "org.apache.spark.FakeClass", Seq("some arg --and-some options -g foo"), - Map(("K1", "V1"), ("K2", "V2")), Seq("cp1", "cp2"), Seq("lp1", "lp2"), Some("-Dfoo") + Map(("K1", "V1"), ("K2", "V2")), Seq("cp1", "cp2"), Seq("lp1", "lp2"), Seq("-Dfoo") ) def createDriverDesc() = new DriverDescription("hdfs://some-dir/some.jar", 100, 3, @@ -170,7 +170,7 @@ object JsonConstants { """ |{"name":"name","cores":4,"memoryperslave":1234, |"user":"%s","sparkhome":"sparkHome", - |"command":"Command(mainClass,List(arg1, arg2),Map(),List(),List(),None)"} + |"command":"Command(mainClass,List(arg1, arg2),Map(),List(),List(),List())"} """.format(System.getProperty("user.name", "")).stripMargin val executorRunnerJsonStr = diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala index f497a5e0a14f0..a301cbd48a0c3 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala @@ -200,9 +200,12 @@ class SparkSubmitSuite extends FunSuite with Matchers { childArgsStr should include regex ("launch spark://h:p .*thejar.jar org.SomeClass arg1 arg2") mainClass should be ("org.apache.spark.deploy.Client") classpath should have size (0) - sysProps should have size (3) - sysProps.keys should contain ("spark.jars") + sysProps should have size (5) sysProps.keys should contain ("SPARK_SUBMIT") + sysProps.keys should contain ("spark.master") + sysProps.keys should contain ("spark.app.name") + sysProps.keys should contain ("spark.jars") + sysProps.keys should contain ("spark.shuffle.spill") sysProps("spark.shuffle.spill") should be ("false") } diff --git a/core/src/test/scala/org/apache/spark/deploy/worker/DriverRunnerTest.scala b/core/src/test/scala/org/apache/spark/deploy/worker/DriverRunnerTest.scala index 4633bc3f7f25e..c930839b47f11 100644 --- a/core/src/test/scala/org/apache/spark/deploy/worker/DriverRunnerTest.scala +++ b/core/src/test/scala/org/apache/spark/deploy/worker/DriverRunnerTest.scala @@ -29,7 +29,7 @@ import org.apache.spark.deploy.{Command, DriverDescription} class DriverRunnerTest extends FunSuite { private def createDriverRunner() = { - val command = new Command("mainClass", Seq(), Map(), Seq(), Seq()) + val command = new Command("mainClass", Seq(), Map(), Seq(), Seq(), Seq()) val driverDescription = new DriverDescription("jarUrl", 512, 1, true, command) new DriverRunner("driverId", new File("workDir"), new File("sparkHome"), driverDescription, null, "akka://1.2.3.4/worker/") diff --git a/core/src/test/scala/org/apache/spark/deploy/worker/ExecutorRunnerTest.scala b/core/src/test/scala/org/apache/spark/deploy/worker/ExecutorRunnerTest.scala index e5f748d55500d..ca4d987619c91 100644 --- a/core/src/test/scala/org/apache/spark/deploy/worker/ExecutorRunnerTest.scala +++ b/core/src/test/scala/org/apache/spark/deploy/worker/ExecutorRunnerTest.scala @@ -29,7 +29,7 @@ class ExecutorRunnerTest extends FunSuite { def f(s:String) = new File(s) val sparkHome = sys.env.get("SPARK_HOME").orElse(sys.props.get("spark.home")) val appDesc = new ApplicationDescription("app name", Some(8), 500, - Command("foo", Seq(), Map(), Seq(), Seq()), + Command("foo", Seq(), Map(), Seq(), Seq(), Seq()), sparkHome, "appUiUrl") val appId = "12345-worker321-9876" val er = new ExecutorRunner(appId, 1, appDesc, 8, 500, null, "blah", "worker321", f(sparkHome.getOrElse(".")), From 7003c163dbb46bb7313aab130a33486a356435a8 Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Wed, 30 Jul 2014 00:15:31 -0700 Subject: [PATCH 025/170] [SPARK-2179][SQL] Public API for DataTypes and Schema The current PR contains the following changes: * Expose `DataType`s in the sql package (internal details are private to sql). * Users can create Rows. * Introduce `applySchema` to create a `SchemaRDD` by applying a `schema: StructType` to an `RDD[Row]`. * Add a function `simpleString` to every `DataType`. Also, the schema represented by a `StructType` can be visualized by `printSchema`. * `ScalaReflection.typeOfObject` provides a way to infer the Catalyst data type based on an object. Also, we can compose `typeOfObject` with some custom logics to form a new function to infer the data type (for different use cases). * `JsonRDD` has been refactored to use changes introduced by this PR. * Add a field `containsNull` to `ArrayType`. So, we can explicitly mark if an `ArrayType` can contain null values. The default value of `containsNull` is `false`. New APIs are introduced in the sql package object and SQLContext. You can find the scaladoc at [sql package object](http://yhuai.github.io/site/api/scala/index.html#org.apache.spark.sql.package) and [SQLContext](http://yhuai.github.io/site/api/scala/index.html#org.apache.spark.sql.SQLContext). An example of using `applySchema` is shown below. ```scala import org.apache.spark.sql._ val sqlContext = new org.apache.spark.sql.SQLContext(sc) val schema = StructType( StructField("name", StringType, false) :: StructField("age", IntegerType, true) :: Nil) val people = sc.textFile("examples/src/main/resources/people.txt").map(_.split(",")).map(p => Row(p(0), p(1).trim.toInt)) val peopleSchemaRDD = sqlContext. applySchema(people, schema) peopleSchemaRDD.printSchema // root // |-- name: string (nullable = false) // |-- age: integer (nullable = true) peopleSchemaRDD.registerAsTable("people") sqlContext.sql("select name from people").collect.foreach(println) ``` I will add new contents to the SQL programming guide later. JIRA: https://issues.apache.org/jira/browse/SPARK-2179 Author: Yin Huai Closes #1346 from yhuai/dataTypeAndSchema and squashes the following commits: 1d45977 [Yin Huai] Clean up. a6e08b4 [Yin Huai] Merge remote-tracking branch 'upstream/master' into dataTypeAndSchema c712fbf [Yin Huai] Converts types of values based on defined schema. 4ceeb66 [Yin Huai] Merge remote-tracking branch 'upstream/master' into dataTypeAndSchema e5f8df5 [Yin Huai] Scaladoc. 122d1e7 [Yin Huai] Address comments. 03bfd95 [Yin Huai] Merge remote-tracking branch 'upstream/master' into dataTypeAndSchema 2476ed0 [Yin Huai] Minor updates. ab71f21 [Yin Huai] Format. fc2bed1 [Yin Huai] Merge remote-tracking branch 'upstream/master' into dataTypeAndSchema bd40a33 [Yin Huai] Address comments. 991f860 [Yin Huai] Move "asJavaDataType" and "asScalaDataType" to DataTypeConversions.scala. 1cb35fe [Yin Huai] Add "valueContainsNull" to MapType. 3edb3ae [Yin Huai] Python doc. 692c0b9 [Yin Huai] Merge remote-tracking branch 'upstream/master' into dataTypeAndSchema 1d93395 [Yin Huai] Python APIs. 246da96 [Yin Huai] Add java data type APIs to javadoc index. 1db9531 [Yin Huai] Merge remote-tracking branch 'upstream/master' into dataTypeAndSchema d48fc7b [Yin Huai] Minor updates. 33c4fec [Yin Huai] Merge remote-tracking branch 'upstream/master' into dataTypeAndSchema b9f3071 [Yin Huai] Java API for applySchema. 1c9f33c [Yin Huai] Java APIs for DataTypes and Row. 624765c [Yin Huai] Tests for applySchema. aa92e84 [Yin Huai] Update data type tests. 8da1a17 [Yin Huai] Add Row.fromSeq. 9c99bc0 [Yin Huai] Several minor updates. 1d9c13a [Yin Huai] Update applySchema API. 85e9b51 [Yin Huai] Merge remote-tracking branch 'upstream/master' into dataTypeAndSchema e495e4e [Yin Huai] More comments. 42d47a3 [Yin Huai] Merge remote-tracking branch 'upstream/master' into dataTypeAndSchema c3f4a02 [Yin Huai] Merge remote-tracking branch 'upstream/master' into dataTypeAndSchema 2e58dbd [Yin Huai] Merge remote-tracking branch 'upstream/master' into dataTypeAndSchema b8b7db4 [Yin Huai] 1. Move sql package object and package-info to sql-core. 2. Minor updates on APIs. 3. Update scala doc. 68525a2 [Yin Huai] Update JSON unit test. 3209108 [Yin Huai] Add unit tests. dcaf22f [Yin Huai] Add a field containsNull to ArrayType to indicate if an array can contain null values or not. If an ArrayType is constructed by "ArrayType(elementType)" (the existing constructor), the value of containsNull is false. 9168b83 [Yin Huai] Update comments. fc649d7 [Yin Huai] Merge remote-tracking branch 'upstream/master' into dataTypeAndSchema eca7d04 [Yin Huai] Add two apply methods which will be used to extract StructField(s) from a StructType. 949d6bb [Yin Huai] When creating a SchemaRDD for a JSON dataset, users can apply an existing schema. 7a6a7e5 [Yin Huai] Fix bug introduced by the change made on SQLContext.inferSchema. 43a45e1 [Yin Huai] Remove sql.util.package introduced in a previous commit. 0266761 [Yin Huai] Format 03eec4c [Yin Huai] Merge remote-tracking branch 'upstream/master' into dataTypeAndSchema 90460ac [Yin Huai] Infer the Catalyst data type from an object and cast a data value to the expected type. 3fa0df5 [Yin Huai] Provide easier ways to construct a StructType. 16be3e5 [Yin Huai] This commit contains three changes: * Expose `DataType`s in the sql package (internal details are private to sql). * Introduce `createSchemaRDD` to create a `SchemaRDD` from an `RDD` with a provided schema (represented by a `StructType`) and a provided function to construct `Row`, * Add a function `simpleString` to every `DataType`. Also, the schema represented by a `StructType` can be visualized by `printSchema`. --- .../apache/spark/api/python/PythonRDD.scala | 3 +- project/SparkBuild.scala | 2 +- python/pyspark/sql.py | 567 +++++++++++++++++- .../spark/sql/catalyst/ScalaReflection.scala | 20 + .../catalyst/expressions/BoundAttribute.scala | 5 +- .../spark/sql/catalyst/expressions/Row.scala | 10 + .../catalyst/expressions/WrapDynamic.scala | 15 +- .../catalyst/expressions/complexTypes.scala | 4 +- .../sql/catalyst/expressions/generators.scala | 8 +- .../apache/spark/sql/catalyst/package.scala | 2 + .../sql/catalyst/planning/QueryPlanner.scala | 2 +- .../sql/catalyst/planning/patterns.scala | 3 +- .../spark/sql/catalyst/plans/QueryPlan.scala | 45 +- .../plans/logical/basicOperators.scala | 2 +- .../spark/sql/catalyst/rules/Rule.scala | 2 +- .../sql/catalyst/rules/RuleExecutor.scala | 5 +- .../spark/sql/catalyst/trees/package.scala | 5 +- .../spark/sql/catalyst/types/dataTypes.scala | 268 +++++++-- .../sql/catalyst/ScalaReflectionSuite.scala | 66 +- .../spark/sql/api/java/types/ArrayType.java | 68 +++ .../spark/sql/api/java/types/BinaryType.java} | 19 +- .../spark/sql/api/java/types/BooleanType.java | 27 + .../spark/sql/api/java/types/ByteType.java | 27 + .../spark/sql/api/java/types/DataType.java | 190 ++++++ .../spark/sql/api/java/types/DecimalType.java | 27 + .../spark/sql/api/java/types/DoubleType.java | 27 + .../spark/sql/api/java/types/FloatType.java | 27 + .../spark/sql/api/java/types/IntegerType.java | 27 + .../spark/sql/api/java/types/LongType.java | 27 + .../spark/sql/api/java/types/MapType.java | 78 +++ .../spark/sql/api/java/types/ShortType.java | 27 + .../spark/sql/api/java/types/StringType.java | 27 + .../spark/sql/api/java/types/StructField.java | 76 +++ .../spark/sql/api/java/types/StructType.java | 59 ++ .../sql/api/java/types/TimestampType.java | 27 + .../sql/api/java/types/package-info.java | 22 + .../org/apache/spark/sql/SQLContext.scala | 230 +++++-- .../org/apache/spark/sql/SchemaRDD.scala | 10 +- .../org/apache/spark/sql/SchemaRDDLike.scala | 12 +- .../spark/sql/api/java/JavaSQLContext.scala | 65 +- .../spark/sql/api/java/JavaSchemaRDD.scala | 7 + .../org/apache/spark/sql/api/java/Row.scala | 59 +- .../org/apache/spark/sql/json/JsonRDD.scala | 118 ++-- .../org/apache/spark/sql/package-info.java | 0 .../scala/org/apache/spark/sql/package.scala | 409 +++++++++++++ .../spark/sql/parquet/ParquetConverter.scala | 8 +- .../sql/parquet/ParquetTableSupport.scala | 4 +- .../spark/sql/parquet/ParquetTypes.scala | 18 +- .../sql/types/util/DataTypeConversions.scala | 110 ++++ .../sql/api/java/JavaApplySchemaSuite.java | 166 +++++ .../spark/sql/api/java/JavaRowSuite.java | 170 ++++++ .../java/JavaSideDataTypeConversionSuite.java | 150 +++++ .../org/apache/spark/sql/DataTypeSuite.scala | 58 ++ .../scala/org/apache/spark/sql/RowSuite.scala | 46 ++ .../org/apache/spark/sql/SQLQuerySuite.scala | 64 +- .../scala/org/apache/spark/sql/TestData.scala | 7 + .../ScalaSideDataTypeConversionSuite.scala | 81 +++ .../org/apache/spark/sql/json/JsonSuite.scala | 198 +++--- .../apache/spark/sql/hive/HiveContext.scala | 9 +- .../spark/sql/hive/HiveInspectors.scala | 5 +- .../spark/sql/hive/HiveMetastoreCatalog.scala | 8 +- 61 files changed, 3442 insertions(+), 386 deletions(-) create mode 100644 sql/core/src/main/java/org/apache/spark/sql/api/java/types/ArrayType.java rename sql/{catalyst/src/main/scala/org/apache/spark/sql/package.scala => core/src/main/java/org/apache/spark/sql/api/java/types/BinaryType.java} (59%) create mode 100644 sql/core/src/main/java/org/apache/spark/sql/api/java/types/BooleanType.java create mode 100644 sql/core/src/main/java/org/apache/spark/sql/api/java/types/ByteType.java create mode 100644 sql/core/src/main/java/org/apache/spark/sql/api/java/types/DataType.java create mode 100644 sql/core/src/main/java/org/apache/spark/sql/api/java/types/DecimalType.java create mode 100644 sql/core/src/main/java/org/apache/spark/sql/api/java/types/DoubleType.java create mode 100644 sql/core/src/main/java/org/apache/spark/sql/api/java/types/FloatType.java create mode 100644 sql/core/src/main/java/org/apache/spark/sql/api/java/types/IntegerType.java create mode 100644 sql/core/src/main/java/org/apache/spark/sql/api/java/types/LongType.java create mode 100644 sql/core/src/main/java/org/apache/spark/sql/api/java/types/MapType.java create mode 100644 sql/core/src/main/java/org/apache/spark/sql/api/java/types/ShortType.java create mode 100644 sql/core/src/main/java/org/apache/spark/sql/api/java/types/StringType.java create mode 100644 sql/core/src/main/java/org/apache/spark/sql/api/java/types/StructField.java create mode 100644 sql/core/src/main/java/org/apache/spark/sql/api/java/types/StructType.java create mode 100644 sql/core/src/main/java/org/apache/spark/sql/api/java/types/TimestampType.java create mode 100644 sql/core/src/main/java/org/apache/spark/sql/api/java/types/package-info.java rename sql/{catalyst => core}/src/main/scala/org/apache/spark/sql/package-info.java (100%) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/package.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/types/util/DataTypeConversions.scala create mode 100644 sql/core/src/test/java/org/apache/spark/sql/api/java/JavaApplySchemaSuite.java create mode 100644 sql/core/src/test/java/org/apache/spark/sql/api/java/JavaRowSuite.java create mode 100644 sql/core/src/test/java/org/apache/spark/sql/api/java/JavaSideDataTypeConversionSuite.java create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/DataTypeSuite.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/api/java/ScalaSideDataTypeConversionSuite.scala diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index 0d8453fb184a3..f551a59ee3fe8 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -544,7 +544,8 @@ private[spark] object PythonRDD extends Logging { } /** - * Convert an RDD of serialized Python dictionaries to Scala Maps + * Convert an RDD of serialized Python dictionaries to Scala Maps (no recursive conversions). + * It is only used by pyspark.sql. * TODO: Support more Python types. */ def pythonToJavaMap(pyRDD: JavaRDD[Array[Byte]]): JavaRDD[Map[String, _]] = { diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 490fac3cc3646..e2dab0f9f79ea 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -312,7 +312,7 @@ object Unidoc { "mllib.regression", "mllib.stat", "mllib.tree", "mllib.tree.configuration", "mllib.tree.impurity", "mllib.tree.model", "mllib.util" ), - "-group", "Spark SQL", packageList("sql.api.java", "sql.hive.api.java"), + "-group", "Spark SQL", packageList("sql.api.java", "sql.api.java.types", "sql.hive.api.java"), "-noqualifier", "java.lang" ) ) diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py index a6b3277db3266..13f0ed4e35490 100644 --- a/python/pyspark/sql.py +++ b/python/pyspark/sql.py @@ -20,7 +20,451 @@ from py4j.protocol import Py4JError -__all__ = ["SQLContext", "HiveContext", "LocalHiveContext", "TestHiveContext", "SchemaRDD", "Row"] +__all__ = [ + "StringType", "BinaryType", "BooleanType", "TimestampType", "DecimalType", + "DoubleType", "FloatType", "ByteType", "IntegerType", "LongType", + "ShortType", "ArrayType", "MapType", "StructField", "StructType", + "SQLContext", "HiveContext", "LocalHiveContext", "TestHiveContext", "SchemaRDD", "Row"] + + +class PrimitiveTypeSingleton(type): + _instances = {} + + def __call__(cls): + if cls not in cls._instances: + cls._instances[cls] = super(PrimitiveTypeSingleton, cls).__call__() + return cls._instances[cls] + + +class StringType(object): + """Spark SQL StringType + + The data type representing string values. + + """ + __metaclass__ = PrimitiveTypeSingleton + + def __repr__(self): + return "StringType" + + +class BinaryType(object): + """Spark SQL BinaryType + + The data type representing bytearray values. + + """ + __metaclass__ = PrimitiveTypeSingleton + + def __repr__(self): + return "BinaryType" + + +class BooleanType(object): + """Spark SQL BooleanType + + The data type representing bool values. + + """ + __metaclass__ = PrimitiveTypeSingleton + + def __repr__(self): + return "BooleanType" + + +class TimestampType(object): + """Spark SQL TimestampType + + The data type representing datetime.datetime values. + + """ + __metaclass__ = PrimitiveTypeSingleton + + def __repr__(self): + return "TimestampType" + + +class DecimalType(object): + """Spark SQL DecimalType + + The data type representing decimal.Decimal values. + + """ + __metaclass__ = PrimitiveTypeSingleton + + def __repr__(self): + return "DecimalType" + + +class DoubleType(object): + """Spark SQL DoubleType + + The data type representing float values. + + """ + __metaclass__ = PrimitiveTypeSingleton + + def __repr__(self): + return "DoubleType" + + +class FloatType(object): + """Spark SQL FloatType + + The data type representing single precision floating-point values. + + """ + __metaclass__ = PrimitiveTypeSingleton + + def __repr__(self): + return "FloatType" + + +class ByteType(object): + """Spark SQL ByteType + + The data type representing int values with 1 singed byte. + + """ + __metaclass__ = PrimitiveTypeSingleton + + def __repr__(self): + return "ByteType" + + +class IntegerType(object): + """Spark SQL IntegerType + + The data type representing int values. + + """ + __metaclass__ = PrimitiveTypeSingleton + + def __repr__(self): + return "IntegerType" + + +class LongType(object): + """Spark SQL LongType + + The data type representing long values. If the any value is beyond the range of + [-9223372036854775808, 9223372036854775807], please use DecimalType. + + """ + __metaclass__ = PrimitiveTypeSingleton + + def __repr__(self): + return "LongType" + + +class ShortType(object): + """Spark SQL ShortType + + The data type representing int values with 2 signed bytes. + + """ + __metaclass__ = PrimitiveTypeSingleton + + def __repr__(self): + return "ShortType" + + +class ArrayType(object): + """Spark SQL ArrayType + + The data type representing list values. + An ArrayType object comprises two fields, elementType (a DataType) and containsNull (a bool). + The field of elementType is used to specify the type of array elements. + The field of containsNull is used to specify if the array has None values. + + """ + def __init__(self, elementType, containsNull=False): + """Creates an ArrayType + + :param elementType: the data type of elements. + :param containsNull: indicates whether the list contains None values. + + >>> ArrayType(StringType) == ArrayType(StringType, False) + True + >>> ArrayType(StringType, True) == ArrayType(StringType) + False + """ + self.elementType = elementType + self.containsNull = containsNull + + def __repr__(self): + return "ArrayType(" + self.elementType.__repr__() + "," + \ + str(self.containsNull).lower() + ")" + + def __eq__(self, other): + return (isinstance(other, self.__class__) and + self.elementType == other.elementType and + self.containsNull == other.containsNull) + + def __ne__(self, other): + return not self.__eq__(other) + + +class MapType(object): + """Spark SQL MapType + + The data type representing dict values. + A MapType object comprises three fields, + keyType (a DataType), valueType (a DataType) and valueContainsNull (a bool). + The field of keyType is used to specify the type of keys in the map. + The field of valueType is used to specify the type of values in the map. + The field of valueContainsNull is used to specify if values of this map has None values. + For values of a MapType column, keys are not allowed to have None values. + + """ + def __init__(self, keyType, valueType, valueContainsNull=True): + """Creates a MapType + :param keyType: the data type of keys. + :param valueType: the data type of values. + :param valueContainsNull: indicates whether values contains null values. + + >>> MapType(StringType, IntegerType) == MapType(StringType, IntegerType, True) + True + >>> MapType(StringType, IntegerType, False) == MapType(StringType, FloatType) + False + """ + self.keyType = keyType + self.valueType = valueType + self.valueContainsNull = valueContainsNull + + def __repr__(self): + return "MapType(" + self.keyType.__repr__() + "," + \ + self.valueType.__repr__() + "," + \ + str(self.valueContainsNull).lower() + ")" + + def __eq__(self, other): + return (isinstance(other, self.__class__) and + self.keyType == other.keyType and + self.valueType == other.valueType and + self.valueContainsNull == other.valueContainsNull) + + def __ne__(self, other): + return not self.__eq__(other) + + +class StructField(object): + """Spark SQL StructField + + Represents a field in a StructType. + A StructField object comprises three fields, name (a string), dataType (a DataType), + and nullable (a bool). The field of name is the name of a StructField. The field of + dataType specifies the data type of a StructField. + The field of nullable specifies if values of a StructField can contain None values. + + """ + def __init__(self, name, dataType, nullable): + """Creates a StructField + :param name: the name of this field. + :param dataType: the data type of this field. + :param nullable: indicates whether values of this field can be null. + + >>> StructField("f1", StringType, True) == StructField("f1", StringType, True) + True + >>> StructField("f1", StringType, True) == StructField("f2", StringType, True) + False + """ + self.name = name + self.dataType = dataType + self.nullable = nullable + + def __repr__(self): + return "StructField(" + self.name + "," + \ + self.dataType.__repr__() + "," + \ + str(self.nullable).lower() + ")" + + def __eq__(self, other): + return (isinstance(other, self.__class__) and + self.name == other.name and + self.dataType == other.dataType and + self.nullable == other.nullable) + + def __ne__(self, other): + return not self.__eq__(other) + + +class StructType(object): + """Spark SQL StructType + + The data type representing namedtuple values. + A StructType object comprises a list of L{StructField}s. + + """ + def __init__(self, fields): + """Creates a StructType + + >>> struct1 = StructType([StructField("f1", StringType, True)]) + >>> struct2 = StructType([StructField("f1", StringType, True)]) + >>> struct1 == struct2 + True + >>> struct1 = StructType([StructField("f1", StringType, True)]) + >>> struct2 = StructType([StructField("f1", StringType, True), + ... [StructField("f2", IntegerType, False)]]) + >>> struct1 == struct2 + False + """ + self.fields = fields + + def __repr__(self): + return "StructType(List(" + \ + ",".join([field.__repr__() for field in self.fields]) + "))" + + def __eq__(self, other): + return (isinstance(other, self.__class__) and + self.fields == other.fields) + + def __ne__(self, other): + return not self.__eq__(other) + + +def _parse_datatype_list(datatype_list_string): + """Parses a list of comma separated data types.""" + index = 0 + datatype_list = [] + start = 0 + depth = 0 + while index < len(datatype_list_string): + if depth == 0 and datatype_list_string[index] == ",": + datatype_string = datatype_list_string[start:index].strip() + datatype_list.append(_parse_datatype_string(datatype_string)) + start = index + 1 + elif datatype_list_string[index] == "(": + depth += 1 + elif datatype_list_string[index] == ")": + depth -= 1 + + index += 1 + + # Handle the last data type + datatype_string = datatype_list_string[start:index].strip() + datatype_list.append(_parse_datatype_string(datatype_string)) + return datatype_list + + +def _parse_datatype_string(datatype_string): + """Parses the given data type string. + + >>> def check_datatype(datatype): + ... scala_datatype = sqlCtx._ssql_ctx.parseDataType(datatype.__repr__()) + ... python_datatype = _parse_datatype_string(scala_datatype.toString()) + ... return datatype == python_datatype + >>> check_datatype(StringType()) + True + >>> check_datatype(BinaryType()) + True + >>> check_datatype(BooleanType()) + True + >>> check_datatype(TimestampType()) + True + >>> check_datatype(DecimalType()) + True + >>> check_datatype(DoubleType()) + True + >>> check_datatype(FloatType()) + True + >>> check_datatype(ByteType()) + True + >>> check_datatype(IntegerType()) + True + >>> check_datatype(LongType()) + True + >>> check_datatype(ShortType()) + True + >>> # Simple ArrayType. + >>> simple_arraytype = ArrayType(StringType(), True) + >>> check_datatype(simple_arraytype) + True + >>> # Simple MapType. + >>> simple_maptype = MapType(StringType(), LongType()) + >>> check_datatype(simple_maptype) + True + >>> # Simple StructType. + >>> simple_structtype = StructType([ + ... StructField("a", DecimalType(), False), + ... StructField("b", BooleanType(), True), + ... StructField("c", LongType(), True), + ... StructField("d", BinaryType(), False)]) + >>> check_datatype(simple_structtype) + True + >>> # Complex StructType. + >>> complex_structtype = StructType([ + ... StructField("simpleArray", simple_arraytype, True), + ... StructField("simpleMap", simple_maptype, True), + ... StructField("simpleStruct", simple_structtype, True), + ... StructField("boolean", BooleanType(), False)]) + >>> check_datatype(complex_structtype) + True + >>> # Complex ArrayType. + >>> complex_arraytype = ArrayType(complex_structtype, True) + >>> check_datatype(complex_arraytype) + True + >>> # Complex MapType. + >>> complex_maptype = MapType(complex_structtype, complex_arraytype, False) + >>> check_datatype(complex_maptype) + True + """ + left_bracket_index = datatype_string.find("(") + if left_bracket_index == -1: + # It is a primitive type. + left_bracket_index = len(datatype_string) + type_or_field = datatype_string[:left_bracket_index] + rest_part = datatype_string[left_bracket_index+1:len(datatype_string)-1].strip() + if type_or_field == "StringType": + return StringType() + elif type_or_field == "BinaryType": + return BinaryType() + elif type_or_field == "BooleanType": + return BooleanType() + elif type_or_field == "TimestampType": + return TimestampType() + elif type_or_field == "DecimalType": + return DecimalType() + elif type_or_field == "DoubleType": + return DoubleType() + elif type_or_field == "FloatType": + return FloatType() + elif type_or_field == "ByteType": + return ByteType() + elif type_or_field == "IntegerType": + return IntegerType() + elif type_or_field == "LongType": + return LongType() + elif type_or_field == "ShortType": + return ShortType() + elif type_or_field == "ArrayType": + last_comma_index = rest_part.rfind(",") + containsNull = True + if rest_part[last_comma_index+1:].strip().lower() == "false": + containsNull = False + elementType = _parse_datatype_string(rest_part[:last_comma_index].strip()) + return ArrayType(elementType, containsNull) + elif type_or_field == "MapType": + last_comma_index = rest_part.rfind(",") + valueContainsNull = True + if rest_part[last_comma_index+1:].strip().lower() == "false": + valueContainsNull = False + keyType, valueType = _parse_datatype_list(rest_part[:last_comma_index].strip()) + return MapType(keyType, valueType, valueContainsNull) + elif type_or_field == "StructField": + first_comma_index = rest_part.find(",") + name = rest_part[:first_comma_index].strip() + last_comma_index = rest_part.rfind(",") + nullable = True + if rest_part[last_comma_index+1:].strip().lower() == "false": + nullable = False + dataType = _parse_datatype_string( + rest_part[first_comma_index+1:last_comma_index].strip()) + return StructField(name, dataType, nullable) + elif type_or_field == "StructType": + # rest_part should be in the format like + # List(StructField(field1,IntegerType,false)). + field_list_string = rest_part[rest_part.find("(")+1:-1] + fields = _parse_datatype_list(field_list_string) + return StructType(fields) class SQLContext: @@ -109,6 +553,40 @@ def inferSchema(self, rdd): srdd = self._ssql_ctx.inferSchema(jrdd.rdd()) return SchemaRDD(srdd, self) + def applySchema(self, rdd, schema): + """Applies the given schema to the given RDD of L{dict}s. + + >>> schema = StructType([StructField("field1", IntegerType(), False), + ... StructField("field2", StringType(), False)]) + >>> srdd = sqlCtx.applySchema(rdd, schema) + >>> sqlCtx.registerRDDAsTable(srdd, "table1") + >>> srdd2 = sqlCtx.sql("SELECT * from table1") + >>> srdd2.collect() == [{"field1" : 1, "field2" : "row1"}, {"field1" : 2, "field2": "row2"}, + ... {"field1" : 3, "field2": "row3"}] + True + >>> from datetime import datetime + >>> rdd = sc.parallelize([{"byte": 127, "short": -32768, "float": 1.0, + ... "time": datetime(2010, 1, 1, 1, 1, 1), "map": {"a": 1}, "struct": {"b": 2}, + ... "list": [1, 2, 3]}]) + >>> schema = StructType([ + ... StructField("byte", ByteType(), False), + ... StructField("short", ShortType(), False), + ... StructField("float", FloatType(), False), + ... StructField("time", TimestampType(), False), + ... StructField("map", MapType(StringType(), IntegerType(), False), False), + ... StructField("struct", StructType([StructField("b", ShortType(), False)]), False), + ... StructField("list", ArrayType(ByteType(), False), False), + ... StructField("null", DoubleType(), True)]) + >>> srdd = sqlCtx.applySchema(rdd, schema).map( + ... lambda x: ( + ... x.byte, x.short, x.float, x.time, x.map["a"], x.struct["b"], x.list, x.null)) + >>> srdd.collect()[0] + (127, -32768, 1.0, datetime.datetime(2010, 1, 1, 1, 1, 1), 1, 2, [1, 2, 3], None) + """ + jrdd = self._pythonToJavaMap(rdd._jrdd) + srdd = self._ssql_ctx.applySchemaToPythonRDD(jrdd.rdd(), schema.__repr__()) + return SchemaRDD(srdd, self) + def registerRDDAsTable(self, rdd, tableName): """Registers the given RDD as a temporary table in the catalog. @@ -139,10 +617,11 @@ def parquetFile(self, path): jschema_rdd = self._ssql_ctx.parquetFile(path) return SchemaRDD(jschema_rdd, self) - def jsonFile(self, path): - """Loads a text file storing one JSON object per line, - returning the result as a L{SchemaRDD}. - It goes through the entire dataset once to determine the schema. + def jsonFile(self, path, schema=None): + """Loads a text file storing one JSON object per line as a L{SchemaRDD}. + + If the schema is provided, applies the given schema to this JSON dataset. + Otherwise, it goes through the entire dataset once to determine the schema. >>> import tempfile, shutil >>> jsonFile = tempfile.mkdtemp() @@ -151,8 +630,8 @@ def jsonFile(self, path): >>> for json in jsonStrings: ... print>>ofn, json >>> ofn.close() - >>> srdd = sqlCtx.jsonFile(jsonFile) - >>> sqlCtx.registerRDDAsTable(srdd, "table1") + >>> srdd1 = sqlCtx.jsonFile(jsonFile) + >>> sqlCtx.registerRDDAsTable(srdd1, "table1") >>> srdd2 = sqlCtx.sql( ... "SELECT field1 AS f1, field2 as f2, field3 as f3, field6 as f4 from table1") >>> srdd2.collect() == [ @@ -160,16 +639,45 @@ def jsonFile(self, path): ... {"f1":2, "f2":None, "f3":{"field4":22, "field5": [10, 11]}, "f4":[{"field7": "row2"}]}, ... {"f1":None, "f2":"row3", "f3":{"field4":33, "field5": []}, "f4":None}] True + >>> srdd3 = sqlCtx.jsonFile(jsonFile, srdd1.schema()) + >>> sqlCtx.registerRDDAsTable(srdd3, "table2") + >>> srdd4 = sqlCtx.sql( + ... "SELECT field1 AS f1, field2 as f2, field3 as f3, field6 as f4 from table2") + >>> srdd4.collect() == [ + ... {"f1":1, "f2":"row1", "f3":{"field4":11, "field5": None}, "f4":None}, + ... {"f1":2, "f2":None, "f3":{"field4":22, "field5": [10, 11]}, "f4":[{"field7": "row2"}]}, + ... {"f1":None, "f2":"row3", "f3":{"field4":33, "field5": []}, "f4":None}] + True + >>> schema = StructType([ + ... StructField("field2", StringType(), True), + ... StructField("field3", + ... StructType([ + ... StructField("field5", ArrayType(IntegerType(), False), True)]), False)]) + >>> srdd5 = sqlCtx.jsonFile(jsonFile, schema) + >>> sqlCtx.registerRDDAsTable(srdd5, "table3") + >>> srdd6 = sqlCtx.sql( + ... "SELECT field2 AS f1, field3.field5 as f2, field3.field5[0] as f3 from table3") + >>> srdd6.collect() == [ + ... {"f1": "row1", "f2": None, "f3": None}, + ... {"f1": None, "f2": [10, 11], "f3": 10}, + ... {"f1": "row3", "f2": [], "f3": None}] + True """ - jschema_rdd = self._ssql_ctx.jsonFile(path) + if schema is None: + jschema_rdd = self._ssql_ctx.jsonFile(path) + else: + scala_datatype = self._ssql_ctx.parseDataType(schema.__repr__()) + jschema_rdd = self._ssql_ctx.jsonFile(path, scala_datatype) return SchemaRDD(jschema_rdd, self) - def jsonRDD(self, rdd): - """Loads an RDD storing one JSON object per string, returning the result as a L{SchemaRDD}. - It goes through the entire dataset once to determine the schema. + def jsonRDD(self, rdd, schema=None): + """Loads an RDD storing one JSON object per string as a L{SchemaRDD}. - >>> srdd = sqlCtx.jsonRDD(json) - >>> sqlCtx.registerRDDAsTable(srdd, "table1") + If the schema is provided, applies the given schema to this JSON dataset. + Otherwise, it goes through the entire dataset once to determine the schema. + + >>> srdd1 = sqlCtx.jsonRDD(json) + >>> sqlCtx.registerRDDAsTable(srdd1, "table1") >>> srdd2 = sqlCtx.sql( ... "SELECT field1 AS f1, field2 as f2, field3 as f3, field6 as f4 from table1") >>> srdd2.collect() == [ @@ -177,6 +685,29 @@ def jsonRDD(self, rdd): ... {"f1":2, "f2":None, "f3":{"field4":22, "field5": [10, 11]}, "f4":[{"field7": "row2"}]}, ... {"f1":None, "f2":"row3", "f3":{"field4":33, "field5": []}, "f4":None}] True + >>> srdd3 = sqlCtx.jsonRDD(json, srdd1.schema()) + >>> sqlCtx.registerRDDAsTable(srdd3, "table2") + >>> srdd4 = sqlCtx.sql( + ... "SELECT field1 AS f1, field2 as f2, field3 as f3, field6 as f4 from table2") + >>> srdd4.collect() == [ + ... {"f1":1, "f2":"row1", "f3":{"field4":11, "field5": None}, "f4":None}, + ... {"f1":2, "f2":None, "f3":{"field4":22, "field5": [10, 11]}, "f4":[{"field7": "row2"}]}, + ... {"f1":None, "f2":"row3", "f3":{"field4":33, "field5": []}, "f4":None}] + True + >>> schema = StructType([ + ... StructField("field2", StringType(), True), + ... StructField("field3", + ... StructType([ + ... StructField("field5", ArrayType(IntegerType(), False), True)]), False)]) + >>> srdd5 = sqlCtx.jsonRDD(json, schema) + >>> sqlCtx.registerRDDAsTable(srdd5, "table3") + >>> srdd6 = sqlCtx.sql( + ... "SELECT field2 AS f1, field3.field5 as f2, field3.field5[0] as f3 from table3") + >>> srdd6.collect() == [ + ... {"f1": "row1", "f2": None, "f3": None}, + ... {"f1": None, "f2": [10, 11], "f3": 10}, + ... {"f1": "row3", "f2": [], "f3": None}] + True """ def func(split, iterator): for x in iterator: @@ -186,7 +717,11 @@ def func(split, iterator): keyed = PipelinedRDD(rdd, func) keyed._bypass_serializer = True jrdd = keyed._jrdd.map(self._jvm.BytesToString()) - jschema_rdd = self._ssql_ctx.jsonRDD(jrdd.rdd()) + if schema is None: + jschema_rdd = self._ssql_ctx.jsonRDD(jrdd.rdd()) + else: + scala_datatype = self._ssql_ctx.parseDataType(schema.__repr__()) + jschema_rdd = self._ssql_ctx.jsonRDD(jrdd.rdd(), scala_datatype) return SchemaRDD(jschema_rdd, self) def sql(self, sqlQuery): @@ -389,6 +924,10 @@ def saveAsTable(self, tableName): """Creates a new table with the contents of this SchemaRDD.""" self._jschema_rdd.saveAsTable(tableName) + def schema(self): + """Returns the schema of this SchemaRDD (represented by a L{StructType}).""" + return _parse_datatype_string(self._jschema_rdd.schema().toString()) + def schemaString(self): """Returns the output schema in the tree format.""" return self._jschema_rdd.schemaString() diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 5a55be1e51558..0d26b52a84695 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -85,6 +85,26 @@ object ScalaReflection { case t if t <:< definitions.BooleanTpe => Schema(BooleanType, nullable = false) } + def typeOfObject: PartialFunction[Any, DataType] = { + // The data type can be determined without ambiguity. + case obj: BooleanType.JvmType => BooleanType + case obj: BinaryType.JvmType => BinaryType + case obj: StringType.JvmType => StringType + case obj: ByteType.JvmType => ByteType + case obj: ShortType.JvmType => ShortType + case obj: IntegerType.JvmType => IntegerType + case obj: LongType.JvmType => LongType + case obj: FloatType.JvmType => FloatType + case obj: DoubleType.JvmType => DoubleType + case obj: DecimalType.JvmType => DecimalType + case obj: TimestampType.JvmType => TimestampType + case null => NullType + // For other cases, there is no obvious mapping from the type of the given object to a + // Catalyst data type. A user should provide his/her specific rules + // (in a user-defined PartialFunction) to infer the Catalyst data type for other types of + // objects and then compose the user-defined PartialFunction with this one. + } + implicit class CaseClassRelation[A <: Product : TypeTag](data: Seq[A]) { /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala index a3ebec8082cbd..f38f99569f207 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala @@ -17,14 +17,11 @@ package org.apache.spark.sql.catalyst.expressions +import org.apache.spark.sql.catalyst.Logging import org.apache.spark.sql.catalyst.errors.attachTree -import org.apache.spark.sql.catalyst.plans.QueryPlan -import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.types._ import org.apache.spark.sql.catalyst.trees -import org.apache.spark.sql.Logging - /** * A bound reference points to a specific slot in the input tuple, allowing the actual value * to be retrieved more efficiently. However, since operations like column pruning can change diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala index 7470cb861b83b..c9a63e201ef60 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala @@ -32,6 +32,16 @@ object Row { * }}} */ def unapplySeq(row: Row): Some[Seq[Any]] = Some(row) + + /** + * This method can be used to construct a [[Row]] with the given values. + */ + def apply(values: Any*): Row = new GenericRow(values.toArray) + + /** + * This method can be used to construct a [[Row]] from a [[Seq]] of values. + */ + def fromSeq(values: Seq[Any]): Row = new GenericRow(values.toArray) } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/WrapDynamic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/WrapDynamic.scala index e787c59e75723..eb8898900d6a5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/WrapDynamic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/WrapDynamic.scala @@ -21,8 +21,16 @@ import scala.language.dynamics import org.apache.spark.sql.catalyst.types.DataType -case object DynamicType extends DataType +/** + * The data type representing [[DynamicRow]] values. + */ +case object DynamicType extends DataType { + def simpleString: String = "dynamic" +} +/** + * Wrap a [[Row]] as a [[DynamicRow]]. + */ case class WrapDynamic(children: Seq[Attribute]) extends Expression { type EvaluatedType = DynamicRow @@ -37,6 +45,11 @@ case class WrapDynamic(children: Seq[Attribute]) extends Expression { } } +/** + * DynamicRows use scala's Dynamic trait to emulate an ORM of in a dynamically typed language. + * Since the type of the column is not known at compile time, all attributes are converted to + * strings before being passed to the function. + */ class DynamicRow(val schema: Seq[Attribute], values: Array[Any]) extends GenericRow(values) with Dynamic { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala index 0acb29012f314..72add5e20e8b4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala @@ -31,8 +31,8 @@ case class GetItem(child: Expression, ordinal: Expression) extends Expression { override def foldable = child.foldable && ordinal.foldable override def references = children.flatMap(_.references).toSet def dataType = child.dataType match { - case ArrayType(dt) => dt - case MapType(_, vt) => vt + case ArrayType(dt, _) => dt + case MapType(_, vt, _) => vt } override lazy val resolved = childrenResolved && diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala index dd78614754e12..422839dab770d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala @@ -84,8 +84,8 @@ case class Explode(attributeNames: Seq[String], child: Expression) (child.dataType.isInstanceOf[ArrayType] || child.dataType.isInstanceOf[MapType]) private lazy val elementTypes = child.dataType match { - case ArrayType(et) => et :: Nil - case MapType(kt,vt) => kt :: vt :: Nil + case ArrayType(et, _) => et :: Nil + case MapType(kt,vt, _) => kt :: vt :: Nil } // TODO: Move this pattern into Generator. @@ -102,10 +102,10 @@ case class Explode(attributeNames: Seq[String], child: Expression) override def eval(input: Row): TraversableOnce[Row] = { child.dataType match { - case ArrayType(_) => + case ArrayType(_, _) => val inputArray = child.eval(input).asInstanceOf[Seq[Any]] if (inputArray == null) Nil else inputArray.map(v => new GenericRow(Array(v))) - case MapType(_, _) => + case MapType(_, _, _) => val inputMap = child.eval(input).asInstanceOf[Map[Any,Any]] if (inputMap == null) Nil else inputMap.map { case (k,v) => new GenericRow(Array(k,v)) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/package.scala index 3b3e206055cfc..ca9642954eb27 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/package.scala @@ -24,4 +24,6 @@ package object catalyst { * 2.10.* builds. See SI-6240 for more details. */ protected[catalyst] object ScalaReflectionLock + + protected[catalyst] type Logging = com.typesafe.scalalogging.slf4j.Logging } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/QueryPlanner.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/QueryPlanner.scala index 67833664b35ae..781ba489b44c6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/QueryPlanner.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/QueryPlanner.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.planning -import org.apache.spark.sql.Logging +import org.apache.spark.sql.catalyst.Logging import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.trees.TreeNode diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala index 418f8686bfe5c..bc763a4e06e67 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala @@ -19,9 +19,8 @@ package org.apache.spark.sql.catalyst.planning import scala.annotation.tailrec -import org.apache.spark.sql.Logging - import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.Logging import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index 7b82e19b2e714..0988b0c6d990c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -125,51 +125,10 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy }.toSeq } - protected def generateSchemaString(schema: Seq[Attribute]): String = { - val builder = new StringBuilder - builder.append("root\n") - val prefix = " |" - schema.foreach { attribute => - val name = attribute.name - val dataType = attribute.dataType - dataType match { - case fields: StructType => - builder.append(s"$prefix-- $name: $StructType\n") - generateSchemaString(fields, s"$prefix |", builder) - case ArrayType(fields: StructType) => - builder.append(s"$prefix-- $name: $ArrayType[$StructType]\n") - generateSchemaString(fields, s"$prefix |", builder) - case ArrayType(elementType: DataType) => - builder.append(s"$prefix-- $name: $ArrayType[$elementType]\n") - case _ => builder.append(s"$prefix-- $name: $dataType\n") - } - } - - builder.toString() - } - - protected def generateSchemaString( - schema: StructType, - prefix: String, - builder: StringBuilder): StringBuilder = { - schema.fields.foreach { - case StructField(name, fields: StructType, _) => - builder.append(s"$prefix-- $name: $StructType\n") - generateSchemaString(fields, s"$prefix |", builder) - case StructField(name, ArrayType(fields: StructType), _) => - builder.append(s"$prefix-- $name: $ArrayType[$StructType]\n") - generateSchemaString(fields, s"$prefix |", builder) - case StructField(name, ArrayType(elementType: DataType), _) => - builder.append(s"$prefix-- $name: $ArrayType[$elementType]\n") - case StructField(name, fieldType: DataType, _) => - builder.append(s"$prefix-- $name: $fieldType\n") - } - - builder - } + def schema: StructType = StructType.fromAttributes(output) /** Returns the output schema in the tree format. */ - def schemaString: String = generateSchemaString(output) + def schemaString: String = schema.treeString /** Prints out the schema in the tree format */ def printSchema(): Unit = println(schemaString) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index 1537de259c5b4..3cb407217c4c3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -177,7 +177,7 @@ case class LowerCaseSchema(child: LogicalPlan) extends UnaryNode { case StructType(fields) => StructType(fields.map(f => StructField(f.name.toLowerCase(), lowerCaseSchema(f.dataType), f.nullable))) - case ArrayType(elemType) => ArrayType(lowerCaseSchema(elemType)) + case ArrayType(elemType, containsNull) => ArrayType(lowerCaseSchema(elemType), containsNull) case otherType => otherType } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/Rule.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/Rule.scala index 1076537bc7602..f8960b3fe7a17 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/Rule.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/Rule.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.rules -import org.apache.spark.sql.Logging +import org.apache.spark.sql.catalyst.Logging import org.apache.spark.sql.catalyst.trees.TreeNode abstract class Rule[TreeType <: TreeNode[_]] extends Logging { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala index e300bdbececbd..6aa407c836aec 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala @@ -15,10 +15,9 @@ * limitations under the License. */ -package org.apache.spark.sql -package catalyst -package rules +package org.apache.spark.sql.catalyst.rules +import org.apache.spark.sql.catalyst.Logging import org.apache.spark.sql.catalyst.trees.TreeNode import org.apache.spark.sql.catalyst.util.sideBySide diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/package.scala index d159ecdd5d781..9a28d035a10a3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/package.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql.catalyst -import org.apache.spark.sql.Logger - /** * A library for easily manipulating trees of operators. Operators that extend TreeNode are * granted the following interface: @@ -35,5 +33,6 @@ import org.apache.spark.sql.Logger */ package object trees { // Since we want tree nodes to be lightweight, we create one logger for all treenode instances. - protected val logger = Logger("catalyst.trees") + protected val logger = + com.typesafe.scalalogging.slf4j.Logger(org.slf4j.LoggerFactory.getLogger("catalyst.trees")) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala index 71808f76d632b..b52ee6d3378a3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala @@ -45,11 +45,13 @@ object DataType extends RegexParsers { "TimestampType" ^^^ TimestampType protected lazy val arrayType: Parser[DataType] = - "ArrayType" ~> "(" ~> dataType <~ ")" ^^ ArrayType + "ArrayType" ~> "(" ~> dataType ~ "," ~ boolVal <~ ")" ^^ { + case tpe ~ _ ~ containsNull => ArrayType(tpe, containsNull) + } protected lazy val mapType: Parser[DataType] = - "MapType" ~> "(" ~> dataType ~ "," ~ dataType <~ ")" ^^ { - case t1 ~ _ ~ t2 => MapType(t1, t2) + "MapType" ~> "(" ~> dataType ~ "," ~ dataType ~ "," ~ boolVal <~ ")" ^^ { + case t1 ~ _ ~ t2 ~ _ ~ valueContainsNull => MapType(t1, t2, valueContainsNull) } protected lazy val structField: Parser[StructField] = @@ -82,6 +84,21 @@ object DataType extends RegexParsers { case Success(result, _) => result case failure: NoSuccess => sys.error(s"Unsupported dataType: $asString, $failure") } + + protected[types] def buildFormattedString( + dataType: DataType, + prefix: String, + builder: StringBuilder): Unit = { + dataType match { + case array: ArrayType => + array.buildFormattedString(prefix, builder) + case struct: StructType => + struct.buildFormattedString(prefix, builder) + case map: MapType => + map.buildFormattedString(prefix, builder) + case _ => + } + } } abstract class DataType { @@ -92,9 +109,13 @@ abstract class DataType { } def isPrimitive: Boolean = false + + def simpleString: String } -case object NullType extends DataType +case object NullType extends DataType { + def simpleString: String = "null" +} object NativeType { def all = Seq( @@ -108,40 +129,45 @@ trait PrimitiveType extends DataType { } abstract class NativeType extends DataType { - type JvmType - @transient val tag: TypeTag[JvmType] - val ordering: Ordering[JvmType] + private[sql] type JvmType + @transient private[sql] val tag: TypeTag[JvmType] + private[sql] val ordering: Ordering[JvmType] - @transient val classTag = ScalaReflectionLock.synchronized { + @transient private[sql] val classTag = ScalaReflectionLock.synchronized { val mirror = runtimeMirror(Utils.getSparkClassLoader) ClassTag[JvmType](mirror.runtimeClass(tag.tpe)) } } case object StringType extends NativeType with PrimitiveType { - type JvmType = String - @transient lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] } - val ordering = implicitly[Ordering[JvmType]] + private[sql] type JvmType = String + @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] } + private[sql] val ordering = implicitly[Ordering[JvmType]] + def simpleString: String = "string" } case object BinaryType extends DataType with PrimitiveType { - type JvmType = Array[Byte] + private[sql] type JvmType = Array[Byte] + def simpleString: String = "binary" } case object BooleanType extends NativeType with PrimitiveType { - type JvmType = Boolean - @transient lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] } - val ordering = implicitly[Ordering[JvmType]] + private[sql] type JvmType = Boolean + @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] } + private[sql] val ordering = implicitly[Ordering[JvmType]] + def simpleString: String = "boolean" } case object TimestampType extends NativeType { - type JvmType = Timestamp + private[sql] type JvmType = Timestamp - @transient lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] } + @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] } - val ordering = new Ordering[JvmType] { + private[sql] val ordering = new Ordering[JvmType] { def compare(x: Timestamp, y: Timestamp) = x.compareTo(y) } + + def simpleString: String = "timestamp" } abstract class NumericType extends NativeType with PrimitiveType { @@ -150,7 +176,7 @@ abstract class NumericType extends NativeType with PrimitiveType { // type parameter and and add a numeric annotation (i.e., [JvmType : Numeric]). This gets // desugared by the compiler into an argument to the objects constructor. This means there is no // longer an no argument constructor and thus the JVM cannot serialize the object anymore. - val numeric: Numeric[JvmType] + private[sql] val numeric: Numeric[JvmType] } object NumericType { @@ -166,39 +192,43 @@ object IntegralType { } abstract class IntegralType extends NumericType { - val integral: Integral[JvmType] + private[sql] val integral: Integral[JvmType] } case object LongType extends IntegralType { - type JvmType = Long - @transient lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] } - val numeric = implicitly[Numeric[Long]] - val integral = implicitly[Integral[Long]] - val ordering = implicitly[Ordering[JvmType]] + private[sql] type JvmType = Long + @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] } + private[sql] val numeric = implicitly[Numeric[Long]] + private[sql] val integral = implicitly[Integral[Long]] + private[sql] val ordering = implicitly[Ordering[JvmType]] + def simpleString: String = "long" } case object IntegerType extends IntegralType { - type JvmType = Int - @transient lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] } - val numeric = implicitly[Numeric[Int]] - val integral = implicitly[Integral[Int]] - val ordering = implicitly[Ordering[JvmType]] + private[sql] type JvmType = Int + @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] } + private[sql] val numeric = implicitly[Numeric[Int]] + private[sql] val integral = implicitly[Integral[Int]] + private[sql] val ordering = implicitly[Ordering[JvmType]] + def simpleString: String = "integer" } case object ShortType extends IntegralType { - type JvmType = Short - @transient lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] } - val numeric = implicitly[Numeric[Short]] - val integral = implicitly[Integral[Short]] - val ordering = implicitly[Ordering[JvmType]] + private[sql] type JvmType = Short + @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] } + private[sql] val numeric = implicitly[Numeric[Short]] + private[sql] val integral = implicitly[Integral[Short]] + private[sql] val ordering = implicitly[Ordering[JvmType]] + def simpleString: String = "short" } case object ByteType extends IntegralType { - type JvmType = Byte - @transient lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] } - val numeric = implicitly[Numeric[Byte]] - val integral = implicitly[Integral[Byte]] - val ordering = implicitly[Ordering[JvmType]] + private[sql] type JvmType = Byte + @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] } + private[sql] val numeric = implicitly[Numeric[Byte]] + private[sql] val integral = implicitly[Integral[Byte]] + private[sql] val ordering = implicitly[Ordering[JvmType]] + def simpleString: String = "byte" } /** Matcher for any expressions that evaluate to [[FractionalType]]s */ @@ -209,47 +239,159 @@ object FractionalType { } } abstract class FractionalType extends NumericType { - val fractional: Fractional[JvmType] + private[sql] val fractional: Fractional[JvmType] } case object DecimalType extends FractionalType { - type JvmType = BigDecimal - @transient lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] } - val numeric = implicitly[Numeric[BigDecimal]] - val fractional = implicitly[Fractional[BigDecimal]] - val ordering = implicitly[Ordering[JvmType]] + private[sql] type JvmType = BigDecimal + @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] } + private[sql] val numeric = implicitly[Numeric[BigDecimal]] + private[sql] val fractional = implicitly[Fractional[BigDecimal]] + private[sql] val ordering = implicitly[Ordering[JvmType]] + def simpleString: String = "decimal" } case object DoubleType extends FractionalType { - type JvmType = Double - @transient lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] } - val numeric = implicitly[Numeric[Double]] - val fractional = implicitly[Fractional[Double]] - val ordering = implicitly[Ordering[JvmType]] + private[sql] type JvmType = Double + @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] } + private[sql] val numeric = implicitly[Numeric[Double]] + private[sql] val fractional = implicitly[Fractional[Double]] + private[sql] val ordering = implicitly[Ordering[JvmType]] + def simpleString: String = "double" } case object FloatType extends FractionalType { - type JvmType = Float - @transient lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] } - val numeric = implicitly[Numeric[Float]] - val fractional = implicitly[Fractional[Float]] - val ordering = implicitly[Ordering[JvmType]] + private[sql] type JvmType = Float + @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] } + private[sql] val numeric = implicitly[Numeric[Float]] + private[sql] val fractional = implicitly[Fractional[Float]] + private[sql] val ordering = implicitly[Ordering[JvmType]] + def simpleString: String = "float" } -case class ArrayType(elementType: DataType) extends DataType +object ArrayType { + /** Construct a [[ArrayType]] object with the given element type. The `containsNull` is false. */ + def apply(elementType: DataType): ArrayType = ArrayType(elementType, false) +} -case class StructField(name: String, dataType: DataType, nullable: Boolean) +/** + * The data type for collections of multiple values. + * Internally these are represented as columns that contain a ``scala.collection.Seq``. + * + * @param elementType The data type of values. + * @param containsNull Indicates if values have `null` values + */ +case class ArrayType(elementType: DataType, containsNull: Boolean) extends DataType { + private[sql] def buildFormattedString(prefix: String, builder: StringBuilder): Unit = { + builder.append( + s"${prefix}-- element: ${elementType.simpleString} (containsNull = ${containsNull})\n") + DataType.buildFormattedString(elementType, s"$prefix |", builder) + } + + def simpleString: String = "array" +} + +/** + * A field inside a StructType. + * @param name The name of this field. + * @param dataType The data type of this field. + * @param nullable Indicates if values of this field can be `null` values. + */ +case class StructField(name: String, dataType: DataType, nullable: Boolean) { + + private[sql] def buildFormattedString(prefix: String, builder: StringBuilder): Unit = { + builder.append(s"${prefix}-- ${name}: ${dataType.simpleString} (nullable = ${nullable})\n") + DataType.buildFormattedString(dataType, s"$prefix |", builder) + } +} object StructType { - def fromAttributes(attributes: Seq[Attribute]): StructType = { + protected[sql] def fromAttributes(attributes: Seq[Attribute]): StructType = StructType(attributes.map(a => StructField(a.name, a.dataType, a.nullable))) - } - // def apply(fields: Seq[StructField]) = new StructType(fields.toIndexedSeq) + private def validateFields(fields: Seq[StructField]): Boolean = + fields.map(field => field.name).distinct.size == fields.size } case class StructType(fields: Seq[StructField]) extends DataType { - def toAttributes = fields.map(f => AttributeReference(f.name, f.dataType, f.nullable)()) + require(StructType.validateFields(fields), "Found fields with the same name.") + + /** + * Returns all field names in a [[Seq]]. + */ + lazy val fieldNames: Seq[String] = fields.map(_.name) + private lazy val fieldNamesSet: Set[String] = fieldNames.toSet + private lazy val nameToField: Map[String, StructField] = fields.map(f => f.name -> f).toMap + /** + * Extracts a [[StructField]] of the given name. If the [[StructType]] object does not + * have a name matching the given name, `null` will be returned. + */ + def apply(name: String): StructField = { + nameToField.get(name).getOrElse( + throw new IllegalArgumentException(s"Field ${name} does not exist.")) + } + + /** + * Returns a [[StructType]] containing [[StructField]]s of the given names. + * Those names which do not have matching fields will be ignored. + */ + def apply(names: Set[String]): StructType = { + val nonExistFields = names -- fieldNamesSet + if (!nonExistFields.isEmpty) { + throw new IllegalArgumentException( + s"Field ${nonExistFields.mkString(",")} does not exist.") + } + // Preserve the original order of fields. + StructType(fields.filter(f => names.contains(f.name))) + } + + protected[sql] def toAttributes = + fields.map(f => AttributeReference(f.name, f.dataType, f.nullable)()) + + def treeString: String = { + val builder = new StringBuilder + builder.append("root\n") + val prefix = " |" + fields.foreach(field => field.buildFormattedString(prefix, builder)) + + builder.toString() + } + + def printTreeString(): Unit = println(treeString) + + private[sql] def buildFormattedString(prefix: String, builder: StringBuilder): Unit = { + fields.foreach(field => field.buildFormattedString(prefix, builder)) + } + + def simpleString: String = "struct" +} + +object MapType { + /** + * Construct a [[MapType]] object with the given key type and value type. + * The `valueContainsNull` is true. + */ + def apply(keyType: DataType, valueType: DataType): MapType = + MapType(keyType: DataType, valueType: DataType, true) } -case class MapType(keyType: DataType, valueType: DataType) extends DataType +/** + * The data type for Maps. Keys in a map are not allowed to have `null` values. + * @param keyType The data type of map keys. + * @param valueType The data type of map values. + * @param valueContainsNull Indicates if map values have `null` values. + */ +case class MapType( + keyType: DataType, + valueType: DataType, + valueContainsNull: Boolean) extends DataType { + private[sql] def buildFormattedString(prefix: String, builder: StringBuilder): Unit = { + builder.append(s"${prefix}-- key: ${keyType.simpleString}\n") + builder.append(s"${prefix}-- value: ${valueType.simpleString} " + + s"(valueContainsNull = ${valueContainsNull})\n") + DataType.buildFormattedString(keyType, s"$prefix |", builder) + DataType.buildFormattedString(valueType, s"$prefix |", builder) + } + + def simpleString: String = "map" +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala index c0438dbe52a47..e030d6e13d472 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala @@ -17,11 +17,11 @@ package org.apache.spark.sql.catalyst +import java.math.BigInteger import java.sql.Timestamp import org.scalatest.FunSuite -import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.types._ case class PrimitiveData( @@ -148,4 +148,68 @@ class ScalaReflectionSuite extends FunSuite { StructField("_2", StringType, nullable = true))), nullable = true)) } + + test("get data type of a value") { + // BooleanType + assert(BooleanType === typeOfObject(true)) + assert(BooleanType === typeOfObject(false)) + + // BinaryType + assert(BinaryType === typeOfObject("string".getBytes)) + + // StringType + assert(StringType === typeOfObject("string")) + + // ByteType + assert(ByteType === typeOfObject(127.toByte)) + + // ShortType + assert(ShortType === typeOfObject(32767.toShort)) + + // IntegerType + assert(IntegerType === typeOfObject(2147483647)) + + // LongType + assert(LongType === typeOfObject(9223372036854775807L)) + + // FloatType + assert(FloatType === typeOfObject(3.4028235E38.toFloat)) + + // DoubleType + assert(DoubleType === typeOfObject(1.7976931348623157E308)) + + // DecimalType + assert(DecimalType === typeOfObject(BigDecimal("1.7976931348623157E318"))) + + // TimestampType + assert(TimestampType === typeOfObject(java.sql.Timestamp.valueOf("2014-7-25 10:26:00"))) + + // NullType + assert(NullType === typeOfObject(null)) + + def typeOfObject1: PartialFunction[Any, DataType] = typeOfObject orElse { + case value: java.math.BigInteger => DecimalType + case value: java.math.BigDecimal => DecimalType + case _ => StringType + } + + assert(DecimalType === typeOfObject1( + new BigInteger("92233720368547758070"))) + assert(DecimalType === typeOfObject1( + new java.math.BigDecimal("1.7976931348623157E318"))) + assert(StringType === typeOfObject1(BigInt("92233720368547758070"))) + + def typeOfObject2: PartialFunction[Any, DataType] = typeOfObject orElse { + case value: java.math.BigInteger => DecimalType + } + + intercept[MatchError](typeOfObject2(BigInt("92233720368547758070"))) + + def typeOfObject3: PartialFunction[Any, DataType] = typeOfObject orElse { + case c: Seq[_] => ArrayType(typeOfObject3(c.head)) + } + + assert(ArrayType(IntegerType) === typeOfObject3(Seq(1, 2, 3))) + assert(ArrayType(ArrayType(IntegerType)) === typeOfObject3(Seq(Seq(1,2,3)))) + } } diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/types/ArrayType.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/types/ArrayType.java new file mode 100644 index 0000000000000..17334ca31b2b7 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/types/ArrayType.java @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.api.java.types; + +/** + * The data type representing Lists. + * An ArrayType object comprises two fields, {@code DataType elementType} and + * {@code boolean containsNull}. The field of {@code elementType} is used to specify the type of + * array elements. The field of {@code containsNull} is used to specify if the array has + * {@code null} values. + * + * To create an {@link ArrayType}, + * {@link org.apache.spark.sql.api.java.types.DataType#createArrayType(DataType)} or + * {@link org.apache.spark.sql.api.java.types.DataType#createArrayType(DataType, boolean)} + * should be used. + */ +public class ArrayType extends DataType { + private DataType elementType; + private boolean containsNull; + + protected ArrayType(DataType elementType, boolean containsNull) { + this.elementType = elementType; + this.containsNull = containsNull; + } + + public DataType getElementType() { + return elementType; + } + + public boolean isContainsNull() { + return containsNull; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + + ArrayType arrayType = (ArrayType) o; + + if (containsNull != arrayType.containsNull) return false; + if (!elementType.equals(arrayType.elementType)) return false; + + return true; + } + + @Override + public int hashCode() { + int result = elementType.hashCode(); + result = 31 * result + (containsNull ? 1 : 0); + return result; + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/package.scala b/sql/core/src/main/java/org/apache/spark/sql/api/java/types/BinaryType.java similarity index 59% rename from sql/catalyst/src/main/scala/org/apache/spark/sql/package.scala rename to sql/core/src/main/java/org/apache/spark/sql/api/java/types/BinaryType.java index 4589129cd1c90..61703179850e9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/package.scala +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/types/BinaryType.java @@ -15,22 +15,13 @@ * limitations under the License. */ -package org.apache.spark +package org.apache.spark.sql.api.java.types; /** - * Allows the execution of relational queries, including those expressed in SQL using Spark. + * The data type representing byte[] values. * - * Note that this package is located in catalyst instead of in core so that all subprojects can - * inherit the settings from this package object. + * {@code BinaryType} is represented by the singleton object {@link DataType#BinaryType}. */ -package object sql { - - protected[sql] def Logger(name: String) = - com.typesafe.scalalogging.slf4j.Logger(org.slf4j.LoggerFactory.getLogger(name)) - - protected[sql] type Logging = com.typesafe.scalalogging.slf4j.Logging - - type Row = catalyst.expressions.Row - - val Row = catalyst.expressions.Row +public class BinaryType extends DataType { + protected BinaryType() {} } diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/types/BooleanType.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/types/BooleanType.java new file mode 100644 index 0000000000000..8fa24d85d1238 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/types/BooleanType.java @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.api.java.types; + +/** + * The data type representing boolean and Boolean values. + * + * {@code BooleanType} is represented by the singleton object {@link DataType#BooleanType}. + */ +public class BooleanType extends DataType { + protected BooleanType() {} +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/types/ByteType.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/types/ByteType.java new file mode 100644 index 0000000000000..2de32978e2705 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/types/ByteType.java @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.api.java.types; + +/** + * The data type representing byte and Byte values. + * + * {@code ByteType} is represented by the singleton object {@link DataType#ByteType}. + */ +public class ByteType extends DataType { + protected ByteType() {} +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/types/DataType.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/types/DataType.java new file mode 100644 index 0000000000000..f84e5a490a905 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/types/DataType.java @@ -0,0 +1,190 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.api.java.types; + +import java.util.HashSet; +import java.util.List; +import java.util.Set; + +/** + * The base type of all Spark SQL data types. + * + * To get/create specific data type, users should use singleton objects and factory methods + * provided by this class. + */ +public abstract class DataType { + + /** + * Gets the StringType object. + */ + public static final StringType StringType = new StringType(); + + /** + * Gets the BinaryType object. + */ + public static final BinaryType BinaryType = new BinaryType(); + + /** + * Gets the BooleanType object. + */ + public static final BooleanType BooleanType = new BooleanType(); + + /** + * Gets the TimestampType object. + */ + public static final TimestampType TimestampType = new TimestampType(); + + /** + * Gets the DecimalType object. + */ + public static final DecimalType DecimalType = new DecimalType(); + + /** + * Gets the DoubleType object. + */ + public static final DoubleType DoubleType = new DoubleType(); + + /** + * Gets the FloatType object. + */ + public static final FloatType FloatType = new FloatType(); + + /** + * Gets the ByteType object. + */ + public static final ByteType ByteType = new ByteType(); + + /** + * Gets the IntegerType object. + */ + public static final IntegerType IntegerType = new IntegerType(); + + /** + * Gets the LongType object. + */ + public static final LongType LongType = new LongType(); + + /** + * Gets the ShortType object. + */ + public static final ShortType ShortType = new ShortType(); + + /** + * Creates an ArrayType by specifying the data type of elements ({@code elementType}). + * The field of {@code containsNull} is set to {@code false}. + */ + public static ArrayType createArrayType(DataType elementType) { + if (elementType == null) { + throw new IllegalArgumentException("elementType should not be null."); + } + + return new ArrayType(elementType, false); + } + + /** + * Creates an ArrayType by specifying the data type of elements ({@code elementType}) and + * whether the array contains null values ({@code containsNull}). + */ + public static ArrayType createArrayType(DataType elementType, boolean containsNull) { + if (elementType == null) { + throw new IllegalArgumentException("elementType should not be null."); + } + + return new ArrayType(elementType, containsNull); + } + + /** + * Creates a MapType by specifying the data type of keys ({@code keyType}) and values + * ({@code keyType}). The field of {@code valueContainsNull} is set to {@code true}. + */ + public static MapType createMapType(DataType keyType, DataType valueType) { + if (keyType == null) { + throw new IllegalArgumentException("keyType should not be null."); + } + if (valueType == null) { + throw new IllegalArgumentException("valueType should not be null."); + } + + return new MapType(keyType, valueType, true); + } + + /** + * Creates a MapType by specifying the data type of keys ({@code keyType}), the data type of + * values ({@code keyType}), and whether values contain any null value + * ({@code valueContainsNull}). + */ + public static MapType createMapType( + DataType keyType, + DataType valueType, + boolean valueContainsNull) { + if (keyType == null) { + throw new IllegalArgumentException("keyType should not be null."); + } + if (valueType == null) { + throw new IllegalArgumentException("valueType should not be null."); + } + + return new MapType(keyType, valueType, valueContainsNull); + } + + /** + * Creates a StructField by specifying the name ({@code name}), data type ({@code dataType}) and + * whether values of this field can be null values ({@code nullable}). + */ + public static StructField createStructField(String name, DataType dataType, boolean nullable) { + if (name == null) { + throw new IllegalArgumentException("name should not be null."); + } + if (dataType == null) { + throw new IllegalArgumentException("dataType should not be null."); + } + + return new StructField(name, dataType, nullable); + } + + /** + * Creates a StructType with the given list of StructFields ({@code fields}). + */ + public static StructType createStructType(List fields) { + return createStructType(fields.toArray(new StructField[0])); + } + + /** + * Creates a StructType with the given StructField array ({@code fields}). + */ + public static StructType createStructType(StructField[] fields) { + if (fields == null) { + throw new IllegalArgumentException("fields should not be null."); + } + Set distinctNames = new HashSet(); + for (StructField field: fields) { + if (field == null) { + throw new IllegalArgumentException( + "fields should not contain any null."); + } + + distinctNames.add(field.getName()); + } + if (distinctNames.size() != fields.length) { + throw new IllegalArgumentException("fields should have distinct names."); + } + + return new StructType(fields); + } + +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/types/DecimalType.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/types/DecimalType.java new file mode 100644 index 0000000000000..9250491a2d2ca --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/types/DecimalType.java @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.api.java.types; + +/** + * The data type representing java.math.BigDecimal values. + * + * {@code DecimalType} is represented by the singleton object {@link DataType#DecimalType}. + */ +public class DecimalType extends DataType { + protected DecimalType() {} +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/types/DoubleType.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/types/DoubleType.java new file mode 100644 index 0000000000000..3e86917fddc4b --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/types/DoubleType.java @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.api.java.types; + +/** + * The data type representing double and Double values. + * + * {@code DoubleType} is represented by the singleton object {@link DataType#DoubleType}. + */ +public class DoubleType extends DataType { + protected DoubleType() {} +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/types/FloatType.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/types/FloatType.java new file mode 100644 index 0000000000000..fa860d40176ef --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/types/FloatType.java @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.api.java.types; + +/** + * The data type representing float and Float values. + * + * {@code FloatType} is represented by the singleton object {@link DataType#FloatType}. + */ +public class FloatType extends DataType { + protected FloatType() {} +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/types/IntegerType.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/types/IntegerType.java new file mode 100644 index 0000000000000..bd973eca2c3ce --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/types/IntegerType.java @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.api.java.types; + +/** + * The data type representing int and Integer values. + * + * {@code IntegerType} is represented by the singleton object {@link DataType#IntegerType}. + */ +public class IntegerType extends DataType { + protected IntegerType() {} +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/types/LongType.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/types/LongType.java new file mode 100644 index 0000000000000..e00233304cefa --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/types/LongType.java @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.api.java.types; + +/** + * The data type representing long and Long values. + * + * {@code LongType} is represented by the singleton object {@link DataType#LongType}. + */ +public class LongType extends DataType { + protected LongType() {} +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/types/MapType.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/types/MapType.java new file mode 100644 index 0000000000000..94936e2e4ee7a --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/types/MapType.java @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.api.java.types; + +/** + * The data type representing Maps. A MapType object comprises two fields, + * {@code DataType keyType}, {@code DataType valueType}, and {@code boolean valueContainsNull}. + * The field of {@code keyType} is used to specify the type of keys in the map. + * The field of {@code valueType} is used to specify the type of values in the map. + * The field of {@code valueContainsNull} is used to specify if map values have + * {@code null} values. + * For values of a MapType column, keys are not allowed to have {@code null} values. + * + * To create a {@link MapType}, + * {@link org.apache.spark.sql.api.java.types.DataType#createMapType(DataType, DataType)} or + * {@link org.apache.spark.sql.api.java.types.DataType#createMapType(DataType, DataType, boolean)} + * should be used. + */ +public class MapType extends DataType { + private DataType keyType; + private DataType valueType; + private boolean valueContainsNull; + + protected MapType(DataType keyType, DataType valueType, boolean valueContainsNull) { + this.keyType = keyType; + this.valueType = valueType; + this.valueContainsNull = valueContainsNull; + } + + public DataType getKeyType() { + return keyType; + } + + public DataType getValueType() { + return valueType; + } + + public boolean isValueContainsNull() { + return valueContainsNull; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + + MapType mapType = (MapType) o; + + if (valueContainsNull != mapType.valueContainsNull) return false; + if (!keyType.equals(mapType.keyType)) return false; + if (!valueType.equals(mapType.valueType)) return false; + + return true; + } + + @Override + public int hashCode() { + int result = keyType.hashCode(); + result = 31 * result + valueType.hashCode(); + result = 31 * result + (valueContainsNull ? 1 : 0); + return result; + } +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/types/ShortType.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/types/ShortType.java new file mode 100644 index 0000000000000..98f9507acf121 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/types/ShortType.java @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.api.java.types; + +/** + * The data type representing short and Short values. + * + * {@code ShortType} is represented by the singleton object {@link DataType#ShortType}. + */ +public class ShortType extends DataType { + protected ShortType() {} +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/types/StringType.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/types/StringType.java new file mode 100644 index 0000000000000..b8e7dbe646071 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/types/StringType.java @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.api.java.types; + +/** + * The data type representing String values. + * + * {@code StringType} is represented by the singleton object {@link DataType#StringType}. + */ +public class StringType extends DataType { + protected StringType() {} +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/types/StructField.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/types/StructField.java new file mode 100644 index 0000000000000..54e9c11ea415e --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/types/StructField.java @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.api.java.types; + +/** + * A StructField object represents a field in a StructType object. + * A StructField object comprises three fields, {@code String name}, {@code DataType dataType}, + * and {@code boolean nullable}. The field of {@code name} is the name of a StructField. + * The field of {@code dataType} specifies the data type of a StructField. + * The field of {@code nullable} specifies if values of a StructField can contain {@code null} + * values. + * + * To create a {@link StructField}, + * {@link org.apache.spark.sql.api.java.types.DataType#createStructField(String, DataType, boolean)} + * should be used. + */ +public class StructField { + private String name; + private DataType dataType; + private boolean nullable; + + protected StructField(String name, DataType dataType, boolean nullable) { + this.name = name; + this.dataType = dataType; + this.nullable = nullable; + } + + public String getName() { + return name; + } + + public DataType getDataType() { + return dataType; + } + + public boolean isNullable() { + return nullable; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + + StructField that = (StructField) o; + + if (nullable != that.nullable) return false; + if (!dataType.equals(that.dataType)) return false; + if (!name.equals(that.name)) return false; + + return true; + } + + @Override + public int hashCode() { + int result = name.hashCode(); + result = 31 * result + dataType.hashCode(); + result = 31 * result + (nullable ? 1 : 0); + return result; + } +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/types/StructType.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/types/StructType.java new file mode 100644 index 0000000000000..33a42f4b16265 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/types/StructType.java @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.api.java.types; + +import java.util.Arrays; +import java.util.List; + +/** + * The data type representing Rows. + * A StructType object comprises an array of StructFields. + * + * To create an {@link StructType}, + * {@link org.apache.spark.sql.api.java.types.DataType#createStructType(java.util.List)} or + * {@link org.apache.spark.sql.api.java.types.DataType#createStructType(StructField[])} + * should be used. + */ +public class StructType extends DataType { + private StructField[] fields; + + protected StructType(StructField[] fields) { + this.fields = fields; + } + + public StructField[] getFields() { + return fields; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + + StructType that = (StructType) o; + + if (!Arrays.equals(fields, that.fields)) return false; + + return true; + } + + @Override + public int hashCode() { + return Arrays.hashCode(fields); + } +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/types/TimestampType.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/types/TimestampType.java new file mode 100644 index 0000000000000..65295779f71ec --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/types/TimestampType.java @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.api.java.types; + +/** + * The data type representing java.sql.Timestamp values. + * + * {@code TimestampType} is represented by the singleton object {@link DataType#TimestampType}. + */ +public class TimestampType extends DataType { + protected TimestampType() {} +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/types/package-info.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/types/package-info.java new file mode 100644 index 0000000000000..f169ac65e226f --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/types/package-info.java @@ -0,0 +1,22 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +/** + * Allows users to get and create Spark SQL data types. + */ +package org.apache.spark.sql.api.java.types; diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index e4b6810180994..86338752a21c1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -31,7 +31,6 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.optimizer.Optimizer import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.rules.RuleExecutor -import org.apache.spark.sql.catalyst.types._ import org.apache.spark.sql.columnar.InMemoryRelation import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.SparkStrategies @@ -88,6 +87,44 @@ class SQLContext(@transient val sparkContext: SparkContext) implicit def createSchemaRDD[A <: Product: TypeTag](rdd: RDD[A]) = new SchemaRDD(this, SparkLogicalPlan(ExistingRdd.fromProductRdd(rdd))(self)) + /** + * :: DeveloperApi :: + * Creates a [[SchemaRDD]] from an [[RDD]] containing [[Row]]s by applying a schema to this RDD. + * It is important to make sure that the structure of every [[Row]] of the provided RDD matches + * the provided schema. Otherwise, there will be runtime exception. + * Example: + * {{{ + * import org.apache.spark.sql._ + * val sqlContext = new org.apache.spark.sql.SQLContext(sc) + * + * val schema = + * StructType( + * StructField("name", StringType, false) :: + * StructField("age", IntegerType, true) :: Nil) + * + * val people = + * sc.textFile("examples/src/main/resources/people.txt").map( + * _.split(",")).map(p => Row(p(0), p(1).trim.toInt)) + * val peopleSchemaRDD = sqlContext. applySchema(people, schema) + * peopleSchemaRDD.printSchema + * // root + * // |-- name: string (nullable = false) + * // |-- age: integer (nullable = true) + * + * peopleSchemaRDD.registerAsTable("people") + * sqlContext.sql("select name from people").collect.foreach(println) + * }}} + * + * @group userf + */ + @DeveloperApi + def applySchema(rowRDD: RDD[Row], schema: StructType): SchemaRDD = { + // TODO: use MutableProjection when rowRDD is another SchemaRDD and the applied + // schema differs from the existing schema on any field data type. + val logicalPlan = SparkLogicalPlan(ExistingRdd(schema.toAttributes, rowRDD))(self) + new SchemaRDD(this, logicalPlan) + } + /** * Loads a Parquet file, returning the result as a [[SchemaRDD]]. * @@ -104,6 +141,19 @@ class SQLContext(@transient val sparkContext: SparkContext) */ def jsonFile(path: String): SchemaRDD = jsonFile(path, 1.0) + /** + * :: Experimental :: + * Loads a JSON file (one object per line) and applies the given schema, + * returning the result as a [[SchemaRDD]]. + * + * @group userf + */ + @Experimental + def jsonFile(path: String, schema: StructType): SchemaRDD = { + val json = sparkContext.textFile(path) + jsonRDD(json, schema) + } + /** * :: Experimental :: */ @@ -122,12 +172,30 @@ class SQLContext(@transient val sparkContext: SparkContext) */ def jsonRDD(json: RDD[String]): SchemaRDD = jsonRDD(json, 1.0) + /** + * :: Experimental :: + * Loads an RDD[String] storing JSON objects (one object per record) and applies the given schema, + * returning the result as a [[SchemaRDD]]. + * + * @group userf + */ + @Experimental + def jsonRDD(json: RDD[String], schema: StructType): SchemaRDD = { + val appliedSchema = + Option(schema).getOrElse(JsonRDD.nullTypeToStringType(JsonRDD.inferSchema(json, 1.0))) + val rowRDD = JsonRDD.jsonStringToRow(json, appliedSchema) + applySchema(rowRDD, appliedSchema) + } + /** * :: Experimental :: */ @Experimental - def jsonRDD(json: RDD[String], samplingRatio: Double): SchemaRDD = - new SchemaRDD(this, JsonRDD.inferSchema(self, json, samplingRatio)) + def jsonRDD(json: RDD[String], samplingRatio: Double): SchemaRDD = { + val appliedSchema = JsonRDD.nullTypeToStringType(JsonRDD.inferSchema(json, samplingRatio)) + val rowRDD = JsonRDD.jsonStringToRow(json, appliedSchema) + applySchema(rowRDD, appliedSchema) + } /** * :: Experimental :: @@ -345,70 +413,138 @@ class SQLContext(@transient val sparkContext: SparkContext) /** * Peek at the first row of the RDD and infer its schema. - * TODO: consolidate this with the type system developed in SPARK-2060. + * It is only used by PySpark. */ private[sql] def inferSchema(rdd: RDD[Map[String, _]]): SchemaRDD = { import scala.collection.JavaConversions._ - def typeFor(obj: Any): DataType = obj match { - case c: java.lang.String => StringType - case c: java.lang.Integer => IntegerType - case c: java.lang.Long => LongType - case c: java.lang.Double => DoubleType - case c: java.lang.Boolean => BooleanType - case c: java.math.BigDecimal => DecimalType - case c: java.sql.Timestamp => TimestampType + + def typeOfComplexValue: PartialFunction[Any, DataType] = { case c: java.util.Calendar => TimestampType - case c: java.util.List[_] => ArrayType(typeFor(c.head)) + case c: java.util.List[_] => + ArrayType(typeOfObject(c.head)) case c: java.util.Map[_, _] => val (key, value) = c.head - MapType(typeFor(key), typeFor(value)) + MapType(typeOfObject(key), typeOfObject(value)) case c if c.getClass.isArray => val elem = c.asInstanceOf[Array[_]].head - ArrayType(typeFor(elem)) + ArrayType(typeOfObject(elem)) case c => throw new Exception(s"Object of type $c cannot be used") } + def typeOfObject = ScalaReflection.typeOfObject orElse typeOfComplexValue + val firstRow = rdd.first() - val schema = firstRow.map { case (fieldName, obj) => - AttributeReference(fieldName, typeFor(obj), true)() + val fields = firstRow.map { + case (fieldName, obj) => StructField(fieldName, typeOfObject(obj), true) }.toSeq - def needTransform(obj: Any): Boolean = obj match { - case c: java.util.List[_] => true - case c: java.util.Map[_, _] => true - case c if c.getClass.isArray => true - case c: java.util.Calendar => true - case c => false + applySchemaToPythonRDD(rdd, StructType(fields)) + } + + /** + * Parses the data type in our internal string representation. The data type string should + * have the same format as the one generated by `toString` in scala. + * It is only used by PySpark. + */ + private[sql] def parseDataType(dataTypeString: String): DataType = { + val parser = org.apache.spark.sql.catalyst.types.DataType + parser(dataTypeString) + } + + /** + * Apply a schema defined by the schemaString to an RDD. It is only used by PySpark. + */ + private[sql] def applySchemaToPythonRDD( + rdd: RDD[Map[String, _]], + schemaString: String): SchemaRDD = { + val schema = parseDataType(schemaString).asInstanceOf[StructType] + applySchemaToPythonRDD(rdd, schema) + } + + /** + * Apply a schema defined by the schema to an RDD. It is only used by PySpark. + */ + private[sql] def applySchemaToPythonRDD( + rdd: RDD[Map[String, _]], + schema: StructType): SchemaRDD = { + // TODO: We should have a better implementation once we do not turn a Python side record + // to a Map. + import scala.collection.JavaConversions._ + import scala.collection.convert.Wrappers.{JListWrapper, JMapWrapper} + + def needsConversion(dataType: DataType): Boolean = dataType match { + case ByteType => true + case ShortType => true + case FloatType => true + case TimestampType => true + case ArrayType(_, _) => true + case MapType(_, _, _) => true + case StructType(_) => true + case other => false } - // convert JList, JArray into Seq, convert JMap into Map - // convert Calendar into Timestamp - def transform(obj: Any): Any = obj match { - case c: java.util.List[_] => c.map(transform).toSeq - case c: java.util.Map[_, _] => c.map { - case (key, value) => (key, transform(value)) - }.toMap - case c if c.getClass.isArray => - c.asInstanceOf[Array[_]].map(transform).toSeq - case c: java.util.Calendar => - new java.sql.Timestamp(c.getTime().getTime()) - case c => c + // Converts value to the type specified by the data type. + // Because Python does not have data types for TimestampType, FloatType, ShortType, and + // ByteType, we need to explicitly convert values in columns of these data types to the desired + // JVM data types. + def convert(obj: Any, dataType: DataType): Any = (obj, dataType) match { + // TODO: We should check nullable + case (null, _) => null + + case (c: java.util.List[_], ArrayType(elementType, _)) => + val converted = c.map { e => convert(e, elementType)} + JListWrapper(converted) + + case (c: java.util.Map[_, _], struct: StructType) => + val row = new GenericMutableRow(struct.fields.length) + struct.fields.zipWithIndex.foreach { + case (field, i) => + val value = convert(c.get(field.name), field.dataType) + row.update(i, value) + } + row + + case (c: java.util.Map[_, _], MapType(keyType, valueType, _)) => + val converted = c.map { + case (key, value) => + (convert(key, keyType), convert(value, valueType)) + } + JMapWrapper(converted) + + case (c, ArrayType(elementType, _)) if c.getClass.isArray => + val converted = c.asInstanceOf[Array[_]].map(e => convert(e, elementType)) + converted: Seq[Any] + + case (c: java.util.Calendar, TimestampType) => new java.sql.Timestamp(c.getTime().getTime()) + case (c: Int, ByteType) => c.toByte + case (c: Int, ShortType) => c.toShort + case (c: Double, FloatType) => c.toFloat + + case (c, _) => c + } + + val convertedRdd = if (schema.fields.exists(f => needsConversion(f.dataType))) { + rdd.map(m => m.map { case (key, value) => (key, convert(value, schema(key).dataType)) }) + } else { + rdd } - val need = firstRow.exists {case (key, value) => needTransform(value)} - val transformed = if (need) { - rdd.mapPartitions { iter => - iter.map { - m => m.map {case (key, value) => (key, transform(value))} + val rowRdd = convertedRdd.mapPartitions { iter => + val row = new GenericMutableRow(schema.fields.length) + val fieldsWithIndex = schema.fields.zipWithIndex + iter.map { m => + // We cannot use m.values because the order of values returned by m.values may not + // match fields order. + fieldsWithIndex.foreach { + case (field, i) => + val value = + m.get(field.name).flatMap(v => Option(v)).map(v => convert(v, field.dataType)).orNull + row.update(i, value) } - } - } else rdd - val rowRdd = transformed.mapPartitions { iter => - iter.map { map => - new GenericRow(map.values.toArray.asInstanceOf[Array[Any]]): Row + row: Row } } - new SchemaRDD(this, SparkLogicalPlan(ExistingRdd(schema, rowRdd))(self)) - } + new SchemaRDD(this, SparkLogicalPlan(ExistingRdd(schema.toAttributes, rowRdd))(self)) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala index 172b6e0e7f26b..420f21fb9c1ae 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql -import java.util.{Map => JMap, List => JList, Set => JSet} +import java.util.{Map => JMap, List => JList} import scala.collection.JavaConversions._ import scala.collection.JavaConverters._ @@ -32,7 +32,6 @@ import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.{Inner, JoinType} -import org.apache.spark.sql.catalyst.types.{DataType, ArrayType, BooleanType, StructType, MapType} import org.apache.spark.sql.execution.{ExistingRdd, SparkLogicalPlan} import org.apache.spark.api.java.JavaRDD @@ -120,6 +119,11 @@ class SchemaRDD( override protected def getDependencies: Seq[Dependency[_]] = List(new OneToOneDependency(queryExecution.toRdd)) + /** Returns the schema of this SchemaRDD (represented by a [[StructType]]). + * + * @group schema + */ + def schema: StructType = queryExecution.analyzed.schema // ======================================================================= // Query DSL @@ -376,6 +380,8 @@ class SchemaRDD( * Converts a JavaRDD to a PythonRDD. It is used by pyspark. */ private[sql] def javaToPython: JavaRDD[Array[Byte]] = { + import scala.collection.Map + def toJava(obj: Any, dataType: DataType): Any = dataType match { case struct: StructType => rowToMap(obj.asInstanceOf[Row], struct) case array: ArrayType => obj match { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDDLike.scala b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDDLike.scala index fd751031b26e5..6a20def475822 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDDLike.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDDLike.scala @@ -123,9 +123,15 @@ private[sql] trait SchemaRDDLike { def saveAsTable(tableName: String): Unit = sqlContext.executePlan(InsertIntoCreatedTable(None, tableName, logicalPlan)).toRdd - /** Returns the output schema in the tree format. */ - def schemaString: String = queryExecution.analyzed.schemaString + /** Returns the schema as a string in the tree format. + * + * @group schema + */ + def schemaString: String = baseSchemaRDD.schema.treeString - /** Prints out the schema in the tree format. */ + /** Prints out the schema. + * + * @group schema + */ def printSchema(): Unit = println(schemaString) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala index 85726bae54911..c1c18a0cd0ed6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala @@ -21,14 +21,16 @@ import java.beans.Introspector import org.apache.hadoop.conf.Configuration -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{DeveloperApi, Experimental} import org.apache.spark.api.java.{JavaRDD, JavaSparkContext} +import org.apache.spark.sql.api.java.types.{StructType => JStructType} import org.apache.spark.sql.json.JsonRDD -import org.apache.spark.sql.SQLContext +import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.expressions.{AttributeReference, GenericRow, Row => ScalaRow} -import org.apache.spark.sql.catalyst.types._ import org.apache.spark.sql.parquet.ParquetRelation import org.apache.spark.sql.execution.{ExistingRdd, SparkLogicalPlan} +import org.apache.spark.sql.types.util.DataTypeConversions +import DataTypeConversions.asScalaDataType; import org.apache.spark.util.Utils /** @@ -95,6 +97,21 @@ class JavaSQLContext(val sqlContext: SQLContext) { new JavaSchemaRDD(sqlContext, SparkLogicalPlan(ExistingRdd(schema, rowRdd))(sqlContext)) } + /** + * :: DeveloperApi :: + * Creates a JavaSchemaRDD from an RDD containing Rows by applying a schema to this RDD. + * It is important to make sure that the structure of every Row of the provided RDD matches the + * provided schema. Otherwise, there will be runtime exception. + */ + @DeveloperApi + def applySchema(rowRDD: JavaRDD[Row], schema: JStructType): JavaSchemaRDD = { + val scalaRowRDD = rowRDD.rdd.map(r => r.row) + val scalaSchema = asScalaDataType(schema).asInstanceOf[StructType] + val logicalPlan = + SparkLogicalPlan(ExistingRdd(scalaSchema.toAttributes, scalaRowRDD))(sqlContext) + new JavaSchemaRDD(sqlContext, logicalPlan) + } + /** * Loads a parquet file, returning the result as a [[JavaSchemaRDD]]. */ @@ -104,23 +121,49 @@ class JavaSQLContext(val sqlContext: SQLContext) { ParquetRelation(path, Some(sqlContext.sparkContext.hadoopConfiguration), sqlContext)) /** - * Loads a JSON file (one object per line), returning the result as a [[JavaSchemaRDD]]. + * Loads a JSON file (one object per line), returning the result as a JavaSchemaRDD. * It goes through the entire dataset once to determine the schema. - * - * @group userf */ def jsonFile(path: String): JavaSchemaRDD = jsonRDD(sqlContext.sparkContext.textFile(path)) + /** + * :: Experimental :: + * Loads a JSON file (one object per line) and applies the given schema, + * returning the result as a JavaSchemaRDD. + */ + @Experimental + def jsonFile(path: String, schema: JStructType): JavaSchemaRDD = + jsonRDD(sqlContext.sparkContext.textFile(path), schema) + /** * Loads an RDD[String] storing JSON objects (one object per record), returning the result as a - * [[JavaSchemaRDD]]. + * JavaSchemaRDD. * It goes through the entire dataset once to determine the schema. - * - * @group userf */ - def jsonRDD(json: JavaRDD[String]): JavaSchemaRDD = - new JavaSchemaRDD(sqlContext, JsonRDD.inferSchema(sqlContext, json, 1.0)) + def jsonRDD(json: JavaRDD[String]): JavaSchemaRDD = { + val appliedScalaSchema = JsonRDD.nullTypeToStringType(JsonRDD.inferSchema(json.rdd, 1.0)) + val scalaRowRDD = JsonRDD.jsonStringToRow(json.rdd, appliedScalaSchema) + val logicalPlan = + SparkLogicalPlan(ExistingRdd(appliedScalaSchema.toAttributes, scalaRowRDD))(sqlContext) + new JavaSchemaRDD(sqlContext, logicalPlan) + } + + /** + * :: Experimental :: + * Loads an RDD[String] storing JSON objects (one object per record) and applies the given schema, + * returning the result as a JavaSchemaRDD. + */ + @Experimental + def jsonRDD(json: JavaRDD[String], schema: JStructType): JavaSchemaRDD = { + val appliedScalaSchema = + Option(asScalaDataType(schema)).getOrElse( + JsonRDD.nullTypeToStringType(JsonRDD.inferSchema(json.rdd, 1.0))).asInstanceOf[StructType] + val scalaRowRDD = JsonRDD.jsonStringToRow(json.rdd, appliedScalaSchema) + val logicalPlan = + SparkLogicalPlan(ExistingRdd(appliedScalaSchema.toAttributes, scalaRowRDD))(sqlContext) + new JavaSchemaRDD(sqlContext, logicalPlan) + } /** * Registers the given RDD as a temporary table in the catalog. Temporary tables exist only diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSchemaRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSchemaRDD.scala index 8fbf13b8b0150..824574149858c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSchemaRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSchemaRDD.scala @@ -22,8 +22,11 @@ import java.util.{List => JList} import org.apache.spark.Partitioner import org.apache.spark.api.java.{JavaRDDLike, JavaRDD} import org.apache.spark.api.java.function.{Function => JFunction} +import org.apache.spark.sql.api.java.types.StructType +import org.apache.spark.sql.types.util.DataTypeConversions import org.apache.spark.sql.{SQLContext, SchemaRDD, SchemaRDDLike} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import DataTypeConversions._ import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel @@ -53,6 +56,10 @@ class JavaSchemaRDD( override def toString: String = baseSchemaRDD.toString + /** Returns the schema of this JavaSchemaRDD (represented by a StructType). */ + def schema: StructType = + asJavaDataType(baseSchemaRDD.schema).asInstanceOf[StructType] + // ======================================================================= // Base RDD functions that do NOT change schema // ======================================================================= diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/java/Row.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/java/Row.scala index 9b0dd2176149b..6c67934bda5b8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/java/Row.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/java/Row.scala @@ -17,6 +17,11 @@ package org.apache.spark.sql.api.java +import scala.annotation.varargs +import scala.collection.convert.Wrappers.{JListWrapper, JMapWrapper} +import scala.collection.JavaConversions +import scala.math.BigDecimal + import org.apache.spark.sql.catalyst.expressions.{Row => ScalaRow} /** @@ -29,7 +34,7 @@ class Row(private[spark] val row: ScalaRow) extends Serializable { /** Returns the value of column `i`. */ def get(i: Int): Any = - row(i) + Row.toJavaValue(row(i)) /** Returns true if value at column `i` is NULL. */ def isNullAt(i: Int) = get(i) == null @@ -89,5 +94,57 @@ class Row(private[spark] val row: ScalaRow) extends Serializable { */ def getString(i: Int): String = row.getString(i) + + def canEqual(other: Any): Boolean = other.isInstanceOf[Row] + + override def equals(other: Any): Boolean = other match { + case that: Row => + (that canEqual this) && + row == that.row + case _ => false + } + + override def hashCode(): Int = row.hashCode() } +object Row { + + private def toJavaValue(value: Any): Any = value match { + // For values of this ScalaRow, we will do the conversion when + // they are actually accessed. + case row: ScalaRow => new Row(row) + case map: scala.collection.Map[_, _] => + JavaConversions.mapAsJavaMap( + map.map { + case (key, value) => (toJavaValue(key), toJavaValue(value)) + } + ) + case seq: scala.collection.Seq[_] => + JavaConversions.seqAsJavaList(seq.map(toJavaValue)) + case decimal: BigDecimal => decimal.underlying() + case other => other + } + + // TODO: Consolidate the toScalaValue at here with the scalafy in JsonRDD? + private def toScalaValue(value: Any): Any = value match { + // Values of this row have been converted to Scala values. + case row: Row => row.row + case map: java.util.Map[_, _] => + JMapWrapper(map).map { + case (key, value) => (toScalaValue(key), toScalaValue(value)) + } + case list: java.util.List[_] => + JListWrapper(list).map(toScalaValue) + case decimal: java.math.BigDecimal => BigDecimal(decimal) + case other => other + } + + /** + * Creates a Row with the given values. + */ + @varargs def create(values: Any*): Row = { + // Right now, we cannot use @varargs to annotate the constructor of + // org.apache.spark.sql.api.java.Row. See https://issues.scala-lang.org/browse/SI-8383. + new Row(ScalaRow(values.map(toScalaValue):_*)) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala index 6c2b553bb908e..bd29ee421bbc4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala @@ -25,33 +25,25 @@ import com.fasterxml.jackson.databind.ObjectMapper import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.analysis.HiveTypeCoercion import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.types._ -import org.apache.spark.sql.execution.{ExistingRdd, SparkLogicalPlan} -import org.apache.spark.sql.{SQLContext, Logging} +import org.apache.spark.sql.catalyst.ScalaReflection +import org.apache.spark.sql.Logging private[sql] object JsonRDD extends Logging { + private[sql] def jsonStringToRow( + json: RDD[String], + schema: StructType): RDD[Row] = { + parseJson(json).map(parsed => asRow(parsed, schema)) + } + private[sql] def inferSchema( - sqlContext: SQLContext, json: RDD[String], - samplingRatio: Double = 1.0): LogicalPlan = { + samplingRatio: Double = 1.0): StructType = { require(samplingRatio > 0, s"samplingRatio ($samplingRatio) should be greater than 0") val schemaData = if (samplingRatio > 0.99) json else json.sample(false, samplingRatio, 1) val allKeys = parseJson(schemaData).map(allKeysWithValueTypes).reduce(_ ++ _) - val baseSchema = createSchema(allKeys) - - createLogicalPlan(json, baseSchema, sqlContext) - } - - private def createLogicalPlan( - json: RDD[String], - baseSchema: StructType, - sqlContext: SQLContext): LogicalPlan = { - val schema = nullTypeToStringType(baseSchema) - - SparkLogicalPlan( - ExistingRdd(asAttributes(schema), parseJson(json).map(asRow(_, schema))))(sqlContext) + createSchema(allKeys) } private def createSchema(allKeys: Set[(String, DataType)]): StructType = { @@ -75,8 +67,8 @@ private[sql] object JsonRDD extends Logging { val (topLevel, structLike) = values.partition(_.size == 1) val topLevelFields = topLevel.filter { name => resolved.get(prefix ++ name).get match { - case ArrayType(StructType(Nil)) => false - case ArrayType(_) => true + case ArrayType(StructType(Nil), _) => false + case ArrayType(_, _) => true case struct: StructType => false case _ => true } @@ -90,7 +82,8 @@ private[sql] object JsonRDD extends Logging { val structType = makeStruct(nestedFields, prefix :+ name) val dataType = resolved.get(prefix :+ name).get dataType match { - case array: ArrayType => Some(StructField(name, ArrayType(structType), nullable = true)) + case array: ArrayType => + Some(StructField(name, ArrayType(structType, array.containsNull), nullable = true)) case struct: StructType => Some(StructField(name, structType, nullable = true)) // dataType is StringType means that we have resolved type conflicts involving // primitive types and complex types. So, the type of name has been relaxed to @@ -109,6 +102,22 @@ private[sql] object JsonRDD extends Logging { makeStruct(resolved.keySet.toSeq, Nil) } + private[sql] def nullTypeToStringType(struct: StructType): StructType = { + val fields = struct.fields.map { + case StructField(fieldName, dataType, nullable) => { + val newType = dataType match { + case NullType => StringType + case ArrayType(NullType, containsNull) => ArrayType(StringType, containsNull) + case struct: StructType => nullTypeToStringType(struct) + case other: DataType => other + } + StructField(fieldName, newType, nullable) + } + } + + StructType(fields) + } + /** * Returns the most general data type for two given data types. */ @@ -139,8 +148,8 @@ private[sql] object JsonRDD extends Logging { case StructField(name, _, _) => name }) } - case (ArrayType(elementType1), ArrayType(elementType2)) => - ArrayType(compatibleType(elementType1, elementType2)) + case (ArrayType(elementType1, containsNull1), ArrayType(elementType2, containsNull2)) => + ArrayType(compatibleType(elementType1, elementType2), containsNull1 || containsNull2) // TODO: We should use JsonObjectStringType to mark that values of field will be // strings and every string is a Json object. case (_, _) => StringType @@ -148,18 +157,13 @@ private[sql] object JsonRDD extends Logging { } } - private def typeOfPrimitiveValue(value: Any): DataType = { - value match { - case value: java.lang.String => StringType - case value: java.lang.Integer => IntegerType - case value: java.lang.Long => LongType + private def typeOfPrimitiveValue: PartialFunction[Any, DataType] = { + ScalaReflection.typeOfObject orElse { // Since we do not have a data type backed by BigInteger, // when we see a Java BigInteger, we use DecimalType. case value: java.math.BigInteger => DecimalType - case value: java.lang.Double => DoubleType + // DecimalType's JVMType is scala BigDecimal. case value: java.math.BigDecimal => DecimalType - case value: java.lang.Boolean => BooleanType - case null => NullType // Unexpected data type. case _ => StringType } @@ -172,12 +176,13 @@ private[sql] object JsonRDD extends Logging { * treat the element as String. */ private def typeOfArray(l: Seq[Any]): ArrayType = { + val containsNull = l.exists(v => v == null) val elements = l.flatMap(v => Option(v)) if (elements.isEmpty) { // If this JSON array is empty, we use NullType as a placeholder. // If this array is not empty in other JSON objects, we can resolve // the type after we have passed through all JSON objects. - ArrayType(NullType) + ArrayType(NullType, containsNull) } else { val elementType = elements.map { e => e match { @@ -189,7 +194,7 @@ private[sql] object JsonRDD extends Logging { } }.reduce((type1: DataType, type2: DataType) => compatibleType(type1, type2)) - ArrayType(elementType) + ArrayType(elementType, containsNull) } } @@ -216,15 +221,16 @@ private[sql] object JsonRDD extends Logging { case (key: String, array: Seq[_]) => { // The value associated with the key is an array. typeOfArray(array) match { - case ArrayType(StructType(Nil)) => { + case ArrayType(StructType(Nil), containsNull) => { // The elements of this arrays are structs. array.asInstanceOf[Seq[Map[String, Any]]].flatMap { element => allKeysWithValueTypes(element) }.map { case (k, dataType) => (s"$key.$k", dataType) - } :+ (key, ArrayType(StructType(Nil))) + } :+ (key, ArrayType(StructType(Nil), containsNull)) } - case ArrayType(elementType) => (key, ArrayType(elementType)) :: Nil + case ArrayType(elementType, containsNull) => + (key, ArrayType(elementType, containsNull)) :: Nil } } case (key: String, value) => (key, typeOfPrimitiveValue(value)) :: Nil @@ -262,8 +268,11 @@ private[sql] object JsonRDD extends Logging { // the ObjectMapper will take the last value associated with this duplicate key. // For example: for {"key": 1, "key":2}, we will get "key"->2. val mapper = new ObjectMapper() - iter.map(record => mapper.readValue(record, classOf[java.util.Map[String, Any]])) - }).map(scalafy).map(_.asInstanceOf[Map[String, Any]]) + iter.map { record => + val parsed = scalafy(mapper.readValue(record, classOf[java.util.Map[String, Any]])) + parsed.asInstanceOf[Map[String, Any]] + } + }) } private def toLong(value: Any): Long = { @@ -334,7 +343,7 @@ private[sql] object JsonRDD extends Logging { null } else { desiredType match { - case ArrayType(elementType) => + case ArrayType(elementType, _) => value.asInstanceOf[Seq[Any]].map(enforceCorrectType(_, elementType)) case StringType => toString(value) case IntegerType => value.asInstanceOf[IntegerType.JvmType] @@ -348,6 +357,7 @@ private[sql] object JsonRDD extends Logging { } private def asRow(json: Map[String,Any], schema: StructType): Row = { + // TODO: Reuse the row instead of creating a new one for every record. val row = new GenericMutableRow(schema.fields.length) schema.fields.zipWithIndex.foreach { // StructType @@ -356,7 +366,7 @@ private[sql] object JsonRDD extends Logging { v => asRow(v.asInstanceOf[Map[String, Any]], fields)).orNull) // ArrayType(StructType) - case (StructField(name, ArrayType(structType: StructType), _), i) => + case (StructField(name, ArrayType(structType: StructType, _), _), i) => row.update(i, json.get(name).flatMap(v => Option(v)).map( v => v.asInstanceOf[Seq[Any]].map( @@ -370,32 +380,4 @@ private[sql] object JsonRDD extends Logging { row } - - private def nullTypeToStringType(struct: StructType): StructType = { - val fields = struct.fields.map { - case StructField(fieldName, dataType, nullable) => { - val newType = dataType match { - case NullType => StringType - case ArrayType(NullType) => ArrayType(StringType) - case struct: StructType => nullTypeToStringType(struct) - case other: DataType => other - } - StructField(fieldName, newType, nullable) - } - } - - StructType(fields) - } - - private def asAttributes(struct: StructType): Seq[AttributeReference] = { - struct.fields.map(f => AttributeReference(f.name, f.dataType, nullable = true)()) - } - - private def asStruct(attributes: Seq[AttributeReference]): StructType = { - val fields = attributes.map { - case AttributeReference(name, dataType, nullable) => StructField(name, dataType, nullable) - } - - StructType(fields) - } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/package-info.java b/sql/core/src/main/scala/org/apache/spark/sql/package-info.java similarity index 100% rename from sql/catalyst/src/main/scala/org/apache/spark/sql/package-info.java rename to sql/core/src/main/scala/org/apache/spark/sql/package-info.java diff --git a/sql/core/src/main/scala/org/apache/spark/sql/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/package.scala new file mode 100644 index 0000000000000..0995a4eb6299f --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/package.scala @@ -0,0 +1,409 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +import org.apache.spark.annotation.DeveloperApi + +/** + * Allows the execution of relational queries, including those expressed in SQL using Spark. + * + * @groupname dataType Data types + * @groupdesc Spark SQL data types. + * @groupprio dataType -3 + * @groupname field Field + * @groupprio field -2 + * @groupname row Row + * @groupprio row -1 + */ +package object sql { + + protected[sql] type Logging = com.typesafe.scalalogging.slf4j.Logging + + /** + * :: DeveloperApi :: + * + * Represents one row of output from a relational operator. + * @group row + */ + @DeveloperApi + type Row = catalyst.expressions.Row + + /** + * :: DeveloperApi :: + * + * A [[Row]] object can be constructed by providing field values. Example: + * {{{ + * import org.apache.spark.sql._ + * + * // Create a Row from values. + * Row(value1, value2, value3, ...) + * // Create a Row from a Seq of values. + * Row.fromSeq(Seq(value1, value2, ...)) + * }}} + * + * A value of a row can be accessed through both generic access by ordinal, + * which will incur boxing overhead for primitives, as well as native primitive access. + * An example of generic access by ordinal: + * {{{ + * import org.apache.spark.sql._ + * + * val row = Row(1, true, "a string", null) + * // row: Row = [1,true,a string,null] + * val firstValue = row(0) + * // firstValue: Any = 1 + * val fourthValue = row(3) + * // fourthValue: Any = null + * }}} + * + * For native primitive access, it is invalid to use the native primitive interface to retrieve + * a value that is null, instead a user must check `isNullAt` before attempting to retrieve a + * value that might be null. + * An example of native primitive access: + * {{{ + * // using the row from the previous example. + * val firstValue = row.getInt(0) + * // firstValue: Int = 1 + * val isNull = row.isNullAt(3) + * // isNull: Boolean = true + * }}} + * + * Interfaces related to native primitive access are: + * + * `isNullAt(i: Int): Boolean` + * + * `getInt(i: Int): Int` + * + * `getLong(i: Int): Long` + * + * `getDouble(i: Int): Double` + * + * `getFloat(i: Int): Float` + * + * `getBoolean(i: Int): Boolean` + * + * `getShort(i: Int): Short` + * + * `getByte(i: Int): Byte` + * + * `getString(i: Int): String` + * + * Fields in a [[Row]] object can be extracted in a pattern match. Example: + * {{{ + * import org.apache.spark.sql._ + * + * val pairs = sql("SELECT key, value FROM src").rdd.map { + * case Row(key: Int, value: String) => + * key -> value + * } + * }}} + * + * @group row + */ + @DeveloperApi + val Row = catalyst.expressions.Row + + /** + * :: DeveloperApi :: + * + * The base type of all Spark SQL data types. + * + * @group dataType + */ + @DeveloperApi + type DataType = catalyst.types.DataType + + /** + * :: DeveloperApi :: + * + * The data type representing `String` values + * + * @group dataType + */ + @DeveloperApi + val StringType = catalyst.types.StringType + + /** + * :: DeveloperApi :: + * + * The data type representing `Array[Byte]` values. + * + * @group dataType + */ + @DeveloperApi + val BinaryType = catalyst.types.BinaryType + + /** + * :: DeveloperApi :: + * + * The data type representing `Boolean` values. + * + *@group dataType + */ + @DeveloperApi + val BooleanType = catalyst.types.BooleanType + + /** + * :: DeveloperApi :: + * + * The data type representing `java.sql.Timestamp` values. + * + * @group dataType + */ + @DeveloperApi + val TimestampType = catalyst.types.TimestampType + + /** + * :: DeveloperApi :: + * + * The data type representing `scala.math.BigDecimal` values. + * + * @group dataType + */ + @DeveloperApi + val DecimalType = catalyst.types.DecimalType + + /** + * :: DeveloperApi :: + * + * The data type representing `Double` values. + * + * @group dataType + */ + @DeveloperApi + val DoubleType = catalyst.types.DoubleType + + /** + * :: DeveloperApi :: + * + * The data type representing `Float` values. + * + * @group dataType + */ + @DeveloperApi + val FloatType = catalyst.types.FloatType + + /** + * :: DeveloperApi :: + * + * The data type representing `Byte` values. + * + * @group dataType + */ + @DeveloperApi + val ByteType = catalyst.types.ByteType + + /** + * :: DeveloperApi :: + * + * The data type representing `Int` values. + * + * @group dataType + */ + @DeveloperApi + val IntegerType = catalyst.types.IntegerType + + /** + * :: DeveloperApi :: + * + * The data type representing `Long` values. + * + * @group dataType + */ + @DeveloperApi + val LongType = catalyst.types.LongType + + /** + * :: DeveloperApi :: + * + * The data type representing `Short` values. + * + * @group dataType + */ + @DeveloperApi + val ShortType = catalyst.types.ShortType + + /** + * :: DeveloperApi :: + * + * The data type for collections of multiple values. + * Internally these are represented as columns that contain a ``scala.collection.Seq``. + * + * An [[ArrayType]] object comprises two fields, `elementType: [[DataType]]` and + * `containsNull: Boolean`. The field of `elementType` is used to specify the type of + * array elements. The field of `containsNull` is used to specify if the array has `null` values. + * + * @group dataType + */ + @DeveloperApi + type ArrayType = catalyst.types.ArrayType + + /** + * :: DeveloperApi :: + * + * An [[ArrayType]] object can be constructed with two ways, + * {{{ + * ArrayType(elementType: DataType, containsNull: Boolean) + * }}} and + * {{{ + * ArrayType(elementType: DataType) + * }}} + * For `ArrayType(elementType)`, the field of `containsNull` is set to `false`. + * + * @group dataType + */ + @DeveloperApi + val ArrayType = catalyst.types.ArrayType + + /** + * :: DeveloperApi :: + * + * The data type representing `Map`s. A [[MapType]] object comprises three fields, + * `keyType: [[DataType]]`, `valueType: [[DataType]]` and `valueContainsNull: Boolean`. + * The field of `keyType` is used to specify the type of keys in the map. + * The field of `valueType` is used to specify the type of values in the map. + * The field of `valueContainsNull` is used to specify if values of this map has `null` values. + * For values of a MapType column, keys are not allowed to have `null` values. + * + * @group dataType + */ + @DeveloperApi + type MapType = catalyst.types.MapType + + /** + * :: DeveloperApi :: + * + * A [[MapType]] object can be constructed with two ways, + * {{{ + * MapType(keyType: DataType, valueType: DataType, valueContainsNull: Boolean) + * }}} and + * {{{ + * MapType(keyType: DataType, valueType: DataType) + * }}} + * For `MapType(keyType: DataType, valueType: DataType)`, + * the field of `valueContainsNull` is set to `true`. + * + * @group dataType + */ + @DeveloperApi + val MapType = catalyst.types.MapType + + /** + * :: DeveloperApi :: + * + * The data type representing [[Row]]s. + * A [[StructType]] object comprises a [[Seq]] of [[StructField]]s. + * + * @group dataType + */ + @DeveloperApi + type StructType = catalyst.types.StructType + + /** + * :: DeveloperApi :: + * + * A [[StructType]] object can be constructed by + * {{{ + * StructType(fields: Seq[StructField]) + * }}} + * For a [[StructType]] object, one or multiple [[StructField]]s can be extracted by names. + * If multiple [[StructField]]s are extracted, a [[StructType]] object will be returned. + * If a provided name does not have a matching field, it will be ignored. For the case + * of extracting a single StructField, a `null` will be returned. + * Example: + * {{{ + * import org.apache.spark.sql._ + * + * val struct = + * StructType( + * StructField("a", IntegerType, true) :: + * StructField("b", LongType, false) :: + * StructField("c", BooleanType, false) :: Nil) + * + * // Extract a single StructField. + * val singleField = struct("b") + * // singleField: StructField = StructField(b,LongType,false) + * + * // This struct does not have a field called "d". null will be returned. + * val nonExisting = struct("d") + * // nonExisting: StructField = null + * + * // Extract multiple StructFields. Field names are provided in a set. + * // A StructType object will be returned. + * val twoFields = struct(Set("b", "c")) + * // twoFields: StructType = + * // StructType(List(StructField(b,LongType,false), StructField(c,BooleanType,false))) + * + * // Those names do not have matching fields will be ignored. + * // For the case shown below, "d" will be ignored and + * // it is treated as struct(Set("b", "c")). + * val ignoreNonExisting = struct(Set("b", "c", "d")) + * // ignoreNonExisting: StructType = + * // StructType(List(StructField(b,LongType,false), StructField(c,BooleanType,false))) + * }}} + * + * A [[Row]] object is used as a value of the StructType. + * Example: + * {{{ + * import org.apache.spark.sql._ + * + * val innerStruct = + * StructType( + * StructField("f1", IntegerType, true) :: + * StructField("f2", LongType, false) :: + * StructField("f3", BooleanType, false) :: Nil) + * + * val struct = StructType( + * StructField("a", innerStruct, true) :: Nil) + * + * // Create a Row with the schema defined by struct + * val row = Row(Row(1, 2, true)) + * // row: Row = [[1,2,true]] + * }}} + * + * @group dataType + */ + @DeveloperApi + val StructType = catalyst.types.StructType + + /** + * :: DeveloperApi :: + * + * A [[StructField]] object represents a field in a [[StructType]] object. + * A [[StructField]] object comprises three fields, `name: [[String]]`, `dataType: [[DataType]]`, + * and `nullable: Boolean`. The field of `name` is the name of a `StructField`. The field of + * `dataType` specifies the data type of a `StructField`. + * The field of `nullable` specifies if values of a `StructField` can contain `null` values. + * + * @group field + */ + @DeveloperApi + type StructField = catalyst.types.StructField + + /** + * :: DeveloperApi :: + * + * A [[StructField]] object can be constructed by + * {{{ + * StructField(name: String, dataType: DataType, nullable: Boolean) + * }}} + * + * @group dataType + */ + @DeveloperApi + val StructField = catalyst.types.StructField +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala index de8fe2dae38f6..0a3b59cbc233a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala @@ -75,21 +75,21 @@ private[sql] object CatalystConverter { val fieldType: DataType = field.dataType fieldType match { // For native JVM types we use a converter with native arrays - case ArrayType(elementType: NativeType) => { + case ArrayType(elementType: NativeType, false) => { new CatalystNativeArrayConverter(elementType, fieldIndex, parent) } // This is for other types of arrays, including those with nested fields - case ArrayType(elementType: DataType) => { + case ArrayType(elementType: DataType, false) => { new CatalystArrayConverter(elementType, fieldIndex, parent) } case StructType(fields: Seq[StructField]) => { new CatalystStructConverter(fields.toArray, fieldIndex, parent) } - case MapType(keyType: DataType, valueType: DataType) => { + case MapType(keyType: DataType, valueType: DataType, valueContainsNull: Boolean) => { new CatalystMapConverter( Array( new FieldType(MAP_KEY_SCHEMA_NAME, keyType, false), - new FieldType(MAP_VALUE_SCHEMA_NAME, valueType, true)), + new FieldType(MAP_VALUE_SCHEMA_NAME, valueType, valueContainsNull)), fieldIndex, parent) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala index 39294a3f4bf5a..6d4ce32ac5bfa 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala @@ -172,10 +172,10 @@ private[parquet] class RowWriteSupport extends WriteSupport[Row] with Logging { private[parquet] def writeValue(schema: DataType, value: Any): Unit = { if (value != null) { schema match { - case t @ ArrayType(_) => writeArray( + case t @ ArrayType(_, false) => writeArray( t, value.asInstanceOf[CatalystConverter.ArrayScalaType[_]]) - case t @ MapType(_, _) => writeMap( + case t @ MapType(_, _, _) => writeMap( t, value.asInstanceOf[CatalystConverter.MapScalaType[_, _]]) case t @ StructType(_) => writeStruct( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala index 58370b955a5ec..aaef1a1d474fe 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala @@ -116,7 +116,7 @@ private[parquet] object ParquetTypesConverter extends Logging { case ParquetOriginalType.LIST => { // TODO: check enums! assert(groupType.getFieldCount == 1) val field = groupType.getFields.apply(0) - new ArrayType(toDataType(field)) + ArrayType(toDataType(field), containsNull = false) } case ParquetOriginalType.MAP => { assert( @@ -130,7 +130,9 @@ private[parquet] object ParquetTypesConverter extends Logging { assert(keyValueGroup.getFields.apply(0).getRepetition == Repetition.REQUIRED) val valueType = toDataType(keyValueGroup.getFields.apply(1)) assert(keyValueGroup.getFields.apply(1).getRepetition == Repetition.REQUIRED) - new MapType(keyType, valueType) + // TODO: set valueContainsNull explicitly instead of assuming valueContainsNull is true + // at here. + MapType(keyType, valueType) } case _ => { // Note: the order of these checks is important! @@ -140,10 +142,12 @@ private[parquet] object ParquetTypesConverter extends Logging { assert(keyValueGroup.getFields.apply(0).getRepetition == Repetition.REQUIRED) val valueType = toDataType(keyValueGroup.getFields.apply(1)) assert(keyValueGroup.getFields.apply(1).getRepetition == Repetition.REQUIRED) - new MapType(keyType, valueType) + // TODO: set valueContainsNull explicitly instead of assuming valueContainsNull is true + // at here. + MapType(keyType, valueType) } else if (correspondsToArray(groupType)) { // ArrayType val elementType = toDataType(groupType.getFields.apply(0)) - new ArrayType(elementType) + ArrayType(elementType, containsNull = false) } else { // everything else: StructType val fields = groupType .getFields @@ -151,7 +155,7 @@ private[parquet] object ParquetTypesConverter extends Logging { ptype.getName, toDataType(ptype), ptype.getRepetition != Repetition.REQUIRED)) - new StructType(fields) + StructType(fields) } } } @@ -234,7 +238,7 @@ private[parquet] object ParquetTypesConverter extends Logging { new ParquetPrimitiveType(repetition, primitiveType, name, originalType.orNull) }.getOrElse { ctype match { - case ArrayType(elementType) => { + case ArrayType(elementType, false) => { val parquetElementType = fromDataType( elementType, CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME, @@ -248,7 +252,7 @@ private[parquet] object ParquetTypesConverter extends Logging { } new ParquetGroupType(repetition, name, fields) } - case MapType(keyType, valueType) => { + case MapType(keyType, valueType, _) => { val parquetKeyType = fromDataType( keyType, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/types/util/DataTypeConversions.scala b/sql/core/src/main/scala/org/apache/spark/sql/types/util/DataTypeConversions.scala new file mode 100644 index 0000000000000..d1aa3c8d53757 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/types/util/DataTypeConversions.scala @@ -0,0 +1,110 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.types.util + +import org.apache.spark.sql._ +import org.apache.spark.sql.api.java.types.{DataType => JDataType, StructField => JStructField} + +import scala.collection.JavaConverters._ + +protected[sql] object DataTypeConversions { + + /** + * Returns the equivalent StructField in Scala for the given StructField in Java. + */ + def asJavaStructField(scalaStructField: StructField): JStructField = { + JDataType.createStructField( + scalaStructField.name, + asJavaDataType(scalaStructField.dataType), + scalaStructField.nullable) + } + + /** + * Returns the equivalent DataType in Java for the given DataType in Scala. + */ + def asJavaDataType(scalaDataType: DataType): JDataType = scalaDataType match { + case StringType => JDataType.StringType + case BinaryType => JDataType.BinaryType + case BooleanType => JDataType.BooleanType + case TimestampType => JDataType.TimestampType + case DecimalType => JDataType.DecimalType + case DoubleType => JDataType.DoubleType + case FloatType => JDataType.FloatType + case ByteType => JDataType.ByteType + case IntegerType => JDataType.IntegerType + case LongType => JDataType.LongType + case ShortType => JDataType.ShortType + + case arrayType: ArrayType => JDataType.createArrayType( + asJavaDataType(arrayType.elementType), arrayType.containsNull) + case mapType: MapType => JDataType.createMapType( + asJavaDataType(mapType.keyType), + asJavaDataType(mapType.valueType), + mapType.valueContainsNull) + case structType: StructType => JDataType.createStructType( + structType.fields.map(asJavaStructField).asJava) + } + + /** + * Returns the equivalent StructField in Scala for the given StructField in Java. + */ + def asScalaStructField(javaStructField: JStructField): StructField = { + StructField( + javaStructField.getName, + asScalaDataType(javaStructField.getDataType), + javaStructField.isNullable) + } + + /** + * Returns the equivalent DataType in Scala for the given DataType in Java. + */ + def asScalaDataType(javaDataType: JDataType): DataType = javaDataType match { + case stringType: org.apache.spark.sql.api.java.types.StringType => + StringType + case binaryType: org.apache.spark.sql.api.java.types.BinaryType => + BinaryType + case booleanType: org.apache.spark.sql.api.java.types.BooleanType => + BooleanType + case timestampType: org.apache.spark.sql.api.java.types.TimestampType => + TimestampType + case decimalType: org.apache.spark.sql.api.java.types.DecimalType => + DecimalType + case doubleType: org.apache.spark.sql.api.java.types.DoubleType => + DoubleType + case floatType: org.apache.spark.sql.api.java.types.FloatType => + FloatType + case byteType: org.apache.spark.sql.api.java.types.ByteType => + ByteType + case integerType: org.apache.spark.sql.api.java.types.IntegerType => + IntegerType + case longType: org.apache.spark.sql.api.java.types.LongType => + LongType + case shortType: org.apache.spark.sql.api.java.types.ShortType => + ShortType + + case arrayType: org.apache.spark.sql.api.java.types.ArrayType => + ArrayType(asScalaDataType(arrayType.getElementType), arrayType.isContainsNull) + case mapType: org.apache.spark.sql.api.java.types.MapType => + MapType( + asScalaDataType(mapType.getKeyType), + asScalaDataType(mapType.getValueType), + mapType.isValueContainsNull) + case structType: org.apache.spark.sql.api.java.types.StructType => + StructType(structType.getFields.map(asScalaStructField)) + } +} diff --git a/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaApplySchemaSuite.java b/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaApplySchemaSuite.java new file mode 100644 index 0000000000000..8ee4591105010 --- /dev/null +++ b/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaApplySchemaSuite.java @@ -0,0 +1,166 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.api.java; + +import java.io.Serializable; +import java.math.BigDecimal; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.sql.api.java.types.DataType; +import org.apache.spark.sql.api.java.types.StructField; +import org.apache.spark.sql.api.java.types.StructType; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.api.java.function.Function; + +// The test suite itself is Serializable so that anonymous Function implementations can be +// serialized, as an alternative to converting these anonymous classes to static inner classes; +// see http://stackoverflow.com/questions/758570/. +public class JavaApplySchemaSuite implements Serializable { + private transient JavaSparkContext javaCtx; + private transient JavaSQLContext javaSqlCtx; + + @Before + public void setUp() { + javaCtx = new JavaSparkContext("local", "JavaApplySchemaSuite"); + javaSqlCtx = new JavaSQLContext(javaCtx); + } + + @After + public void tearDown() { + javaCtx.stop(); + javaCtx = null; + javaSqlCtx = null; + } + + public static class Person implements Serializable { + private String name; + private int age; + + public String getName() { + return name; + } + + public void setName(String name) { + this.name = name; + } + + public int getAge() { + return age; + } + + public void setAge(int age) { + this.age = age; + } + } + + @Test + public void applySchema() { + List personList = new ArrayList(2); + Person person1 = new Person(); + person1.setName("Michael"); + person1.setAge(29); + personList.add(person1); + Person person2 = new Person(); + person2.setName("Yin"); + person2.setAge(28); + personList.add(person2); + + JavaRDD rowRDD = javaCtx.parallelize(personList).map( + new Function() { + public Row call(Person person) throws Exception { + return Row.create(person.getName(), person.getAge()); + } + }); + + List fields = new ArrayList(2); + fields.add(DataType.createStructField("name", DataType.StringType, false)); + fields.add(DataType.createStructField("age", DataType.IntegerType, false)); + StructType schema = DataType.createStructType(fields); + + JavaSchemaRDD schemaRDD = javaSqlCtx.applySchema(rowRDD, schema); + schemaRDD.registerAsTable("people"); + List actual = javaSqlCtx.sql("SELECT * FROM people").collect(); + + List expected = new ArrayList(2); + expected.add(Row.create("Michael", 29)); + expected.add(Row.create("Yin", 28)); + + Assert.assertEquals(expected, actual); + } + + @Test + public void applySchemaToJSON() { + JavaRDD jsonRDD = javaCtx.parallelize(Arrays.asList( + "{\"string\":\"this is a simple string.\", \"integer\":10, \"long\":21474836470, " + + "\"bigInteger\":92233720368547758070, \"double\":1.7976931348623157E308, " + + "\"boolean\":true, \"null\":null}", + "{\"string\":\"this is another simple string.\", \"integer\":11, \"long\":21474836469, " + + "\"bigInteger\":92233720368547758069, \"double\":1.7976931348623157E305, " + + "\"boolean\":false, \"null\":null}")); + List fields = new ArrayList(7); + fields.add(DataType.createStructField("bigInteger", DataType.DecimalType, true)); + fields.add(DataType.createStructField("boolean", DataType.BooleanType, true)); + fields.add(DataType.createStructField("double", DataType.DoubleType, true)); + fields.add(DataType.createStructField("integer", DataType.IntegerType, true)); + fields.add(DataType.createStructField("long", DataType.LongType, true)); + fields.add(DataType.createStructField("null", DataType.StringType, true)); + fields.add(DataType.createStructField("string", DataType.StringType, true)); + StructType expectedSchema = DataType.createStructType(fields); + List expectedResult = new ArrayList(2); + expectedResult.add( + Row.create( + new BigDecimal("92233720368547758070"), + true, + 1.7976931348623157E308, + 10, + 21474836470L, + null, + "this is a simple string.")); + expectedResult.add( + Row.create( + new BigDecimal("92233720368547758069"), + false, + 1.7976931348623157E305, + 11, + 21474836469L, + null, + "this is another simple string.")); + + JavaSchemaRDD schemaRDD1 = javaSqlCtx.jsonRDD(jsonRDD); + StructType actualSchema1 = schemaRDD1.schema(); + Assert.assertEquals(expectedSchema, actualSchema1); + schemaRDD1.registerAsTable("jsonTable1"); + List actual1 = javaSqlCtx.sql("select * from jsonTable1").collect(); + Assert.assertEquals(expectedResult, actual1); + + JavaSchemaRDD schemaRDD2 = javaSqlCtx.jsonRDD(jsonRDD, expectedSchema); + StructType actualSchema2 = schemaRDD2.schema(); + Assert.assertEquals(expectedSchema, actualSchema2); + schemaRDD1.registerAsTable("jsonTable2"); + List actual2 = javaSqlCtx.sql("select * from jsonTable2").collect(); + Assert.assertEquals(expectedResult, actual2); + } +} diff --git a/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaRowSuite.java b/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaRowSuite.java new file mode 100644 index 0000000000000..52d07b5425cc3 --- /dev/null +++ b/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaRowSuite.java @@ -0,0 +1,170 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.api.java; + +import java.math.BigDecimal; +import java.sql.Timestamp; +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +public class JavaRowSuite { + private byte byteValue; + private short shortValue; + private int intValue; + private long longValue; + private float floatValue; + private double doubleValue; + private BigDecimal decimalValue; + private boolean booleanValue; + private String stringValue; + private byte[] binaryValue; + private Timestamp timestampValue; + + @Before + public void setUp() { + byteValue = (byte)127; + shortValue = (short)32767; + intValue = 2147483647; + longValue = 9223372036854775807L; + floatValue = (float)3.4028235E38; + doubleValue = 1.7976931348623157E308; + decimalValue = new BigDecimal("1.7976931348623157E328"); + booleanValue = true; + stringValue = "this is a string"; + binaryValue = stringValue.getBytes(); + timestampValue = Timestamp.valueOf("2014-06-30 09:20:00.0"); + } + + @Test + public void constructSimpleRow() { + Row simpleRow = Row.create( + byteValue, // ByteType + new Byte(byteValue), + shortValue, // ShortType + new Short(shortValue), + intValue, // IntegerType + new Integer(intValue), + longValue, // LongType + new Long(longValue), + floatValue, // FloatType + new Float(floatValue), + doubleValue, // DoubleType + new Double(doubleValue), + decimalValue, // DecimalType + booleanValue, // BooleanType + new Boolean(booleanValue), + stringValue, // StringType + binaryValue, // BinaryType + timestampValue, // TimestampType + null // null + ); + + Assert.assertEquals(byteValue, simpleRow.getByte(0)); + Assert.assertEquals(byteValue, simpleRow.get(0)); + Assert.assertEquals(byteValue, simpleRow.getByte(1)); + Assert.assertEquals(byteValue, simpleRow.get(1)); + Assert.assertEquals(shortValue, simpleRow.getShort(2)); + Assert.assertEquals(shortValue, simpleRow.get(2)); + Assert.assertEquals(shortValue, simpleRow.getShort(3)); + Assert.assertEquals(shortValue, simpleRow.get(3)); + Assert.assertEquals(intValue, simpleRow.getInt(4)); + Assert.assertEquals(intValue, simpleRow.get(4)); + Assert.assertEquals(intValue, simpleRow.getInt(5)); + Assert.assertEquals(intValue, simpleRow.get(5)); + Assert.assertEquals(longValue, simpleRow.getLong(6)); + Assert.assertEquals(longValue, simpleRow.get(6)); + Assert.assertEquals(longValue, simpleRow.getLong(7)); + Assert.assertEquals(longValue, simpleRow.get(7)); + // When we create the row, we do not do any conversion + // for a float/double value, so we just set the delta to 0. + Assert.assertEquals(floatValue, simpleRow.getFloat(8), 0); + Assert.assertEquals(floatValue, simpleRow.get(8)); + Assert.assertEquals(floatValue, simpleRow.getFloat(9), 0); + Assert.assertEquals(floatValue, simpleRow.get(9)); + Assert.assertEquals(doubleValue, simpleRow.getDouble(10), 0); + Assert.assertEquals(doubleValue, simpleRow.get(10)); + Assert.assertEquals(doubleValue, simpleRow.getDouble(11), 0); + Assert.assertEquals(doubleValue, simpleRow.get(11)); + Assert.assertEquals(decimalValue, simpleRow.get(12)); + Assert.assertEquals(booleanValue, simpleRow.getBoolean(13)); + Assert.assertEquals(booleanValue, simpleRow.get(13)); + Assert.assertEquals(booleanValue, simpleRow.getBoolean(14)); + Assert.assertEquals(booleanValue, simpleRow.get(14)); + Assert.assertEquals(stringValue, simpleRow.getString(15)); + Assert.assertEquals(stringValue, simpleRow.get(15)); + Assert.assertEquals(binaryValue, simpleRow.get(16)); + Assert.assertEquals(timestampValue, simpleRow.get(17)); + Assert.assertEquals(true, simpleRow.isNullAt(18)); + Assert.assertEquals(null, simpleRow.get(18)); + } + + @Test + public void constructComplexRow() { + // Simple array + List simpleStringArray = Arrays.asList( + stringValue + " (1)", stringValue + " (2)", stringValue + "(3)"); + + // Simple map + Map simpleMap = new HashMap(); + simpleMap.put(stringValue + " (1)", longValue); + simpleMap.put(stringValue + " (2)", longValue - 1); + simpleMap.put(stringValue + " (3)", longValue - 2); + + // Simple struct + Row simpleStruct = Row.create( + doubleValue, stringValue, timestampValue, null); + + // Complex array + List> arrayOfMaps = Arrays.asList(simpleMap); + List arrayOfRows = Arrays.asList(simpleStruct); + + // Complex map + Map, Row> complexMap = new HashMap, Row>(); + complexMap.put(arrayOfRows, simpleStruct); + + // Complex struct + Row complexStruct = Row.create( + simpleStringArray, + simpleMap, + simpleStruct, + arrayOfMaps, + arrayOfRows, + complexMap, + null); + Assert.assertEquals(simpleStringArray, complexStruct.get(0)); + Assert.assertEquals(simpleMap, complexStruct.get(1)); + Assert.assertEquals(simpleStruct, complexStruct.get(2)); + Assert.assertEquals(arrayOfMaps, complexStruct.get(3)); + Assert.assertEquals(arrayOfRows, complexStruct.get(4)); + Assert.assertEquals(complexMap, complexStruct.get(5)); + Assert.assertEquals(null, complexStruct.get(6)); + + // A very complex row + Row complexRow = Row.create(arrayOfMaps, arrayOfRows, complexMap, complexStruct); + Assert.assertEquals(arrayOfMaps, complexRow.get(0)); + Assert.assertEquals(arrayOfRows, complexRow.get(1)); + Assert.assertEquals(complexMap, complexRow.get(2)); + Assert.assertEquals(complexStruct, complexRow.get(3)); + } +} diff --git a/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaSideDataTypeConversionSuite.java b/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaSideDataTypeConversionSuite.java new file mode 100644 index 0000000000000..96a503962f7d1 --- /dev/null +++ b/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaSideDataTypeConversionSuite.java @@ -0,0 +1,150 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.api.java; + +import java.util.List; +import java.util.ArrayList; + +import org.junit.Assert; +import org.junit.Test; + +import org.apache.spark.sql.types.util.DataTypeConversions; +import org.apache.spark.sql.api.java.types.DataType; +import org.apache.spark.sql.api.java.types.StructField; + +public class JavaSideDataTypeConversionSuite { + public void checkDataType(DataType javaDataType) { + org.apache.spark.sql.catalyst.types.DataType scalaDataType = + DataTypeConversions.asScalaDataType(javaDataType); + DataType actual = DataTypeConversions.asJavaDataType(scalaDataType); + Assert.assertEquals(javaDataType, actual); + } + + @Test + public void createDataTypes() { + // Simple DataTypes. + checkDataType(DataType.StringType); + checkDataType(DataType.BinaryType); + checkDataType(DataType.BooleanType); + checkDataType(DataType.TimestampType); + checkDataType(DataType.DecimalType); + checkDataType(DataType.DoubleType); + checkDataType(DataType.FloatType); + checkDataType(DataType.ByteType); + checkDataType(DataType.IntegerType); + checkDataType(DataType.LongType); + checkDataType(DataType.ShortType); + + // Simple ArrayType. + DataType simpleJavaArrayType = DataType.createArrayType(DataType.StringType, true); + checkDataType(simpleJavaArrayType); + + // Simple MapType. + DataType simpleJavaMapType = DataType.createMapType(DataType.StringType, DataType.LongType); + checkDataType(simpleJavaMapType); + + // Simple StructType. + List simpleFields = new ArrayList(); + simpleFields.add(DataType.createStructField("a", DataType.DecimalType, false)); + simpleFields.add(DataType.createStructField("b", DataType.BooleanType, true)); + simpleFields.add(DataType.createStructField("c", DataType.LongType, true)); + simpleFields.add(DataType.createStructField("d", DataType.BinaryType, false)); + DataType simpleJavaStructType = DataType.createStructType(simpleFields); + checkDataType(simpleJavaStructType); + + // Complex StructType. + List complexFields = new ArrayList(); + complexFields.add(DataType.createStructField("simpleArray", simpleJavaArrayType, true)); + complexFields.add(DataType.createStructField("simpleMap", simpleJavaMapType, true)); + complexFields.add(DataType.createStructField("simpleStruct", simpleJavaStructType, true)); + complexFields.add(DataType.createStructField("boolean", DataType.BooleanType, false)); + DataType complexJavaStructType = DataType.createStructType(complexFields); + checkDataType(complexJavaStructType); + + // Complex ArrayType. + DataType complexJavaArrayType = DataType.createArrayType(complexJavaStructType, true); + checkDataType(complexJavaArrayType); + + // Complex MapType. + DataType complexJavaMapType = + DataType.createMapType(complexJavaStructType, complexJavaArrayType, false); + checkDataType(complexJavaMapType); + } + + @Test + public void illegalArgument() { + // ArrayType + try { + DataType.createArrayType(null, true); + Assert.fail(); + } catch (IllegalArgumentException expectedException) { + } + + // MapType + try { + DataType.createMapType(null, DataType.StringType); + Assert.fail(); + } catch (IllegalArgumentException expectedException) { + } + try { + DataType.createMapType(DataType.StringType, null); + Assert.fail(); + } catch (IllegalArgumentException expectedException) { + } + try { + DataType.createMapType(null, null); + Assert.fail(); + } catch (IllegalArgumentException expectedException) { + } + + // StructField + try { + DataType.createStructField(null, DataType.StringType, true); + } catch (IllegalArgumentException expectedException) { + } + try { + DataType.createStructField("name", null, true); + } catch (IllegalArgumentException expectedException) { + } + try { + DataType.createStructField(null, null, true); + } catch (IllegalArgumentException expectedException) { + } + + // StructType + try { + List simpleFields = new ArrayList(); + simpleFields.add(DataType.createStructField("a", DataType.DecimalType, false)); + simpleFields.add(DataType.createStructField("b", DataType.BooleanType, true)); + simpleFields.add(DataType.createStructField("c", DataType.LongType, true)); + simpleFields.add(null); + DataType.createStructType(simpleFields); + Assert.fail(); + } catch (IllegalArgumentException expectedException) { + } + try { + List simpleFields = new ArrayList(); + simpleFields.add(DataType.createStructField("a", DataType.DecimalType, false)); + simpleFields.add(DataType.createStructField("a", DataType.BooleanType, true)); + simpleFields.add(DataType.createStructField("c", DataType.LongType, true)); + DataType.createStructType(simpleFields); + Assert.fail(); + } catch (IllegalArgumentException expectedException) { + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataTypeSuite.scala new file mode 100644 index 0000000000000..cf7d79f42db1d --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataTypeSuite.scala @@ -0,0 +1,58 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql + +import org.scalatest.FunSuite + +class DataTypeSuite extends FunSuite { + + test("construct an ArrayType") { + val array = ArrayType(StringType) + + assert(ArrayType(StringType, false) === array) + } + + test("construct an MapType") { + val map = MapType(StringType, IntegerType) + + assert(MapType(StringType, IntegerType, true) === map) + } + + test("extract fields from a StructType") { + val struct = StructType( + StructField("a", IntegerType, true) :: + StructField("b", LongType, false) :: + StructField("c", StringType, true) :: + StructField("d", FloatType, true) :: Nil) + + assert(StructField("b", LongType, false) === struct("b")) + + intercept[IllegalArgumentException] { + struct("e") + } + + val expectedStruct = StructType( + StructField("b", LongType, false) :: + StructField("d", FloatType, true) :: Nil) + + assert(expectedStruct === struct(Set("b", "d"))) + intercept[IllegalArgumentException] { + struct(Set("b", "d", "e", "f")) + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala new file mode 100644 index 0000000000000..651cb735ab7d9 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala @@ -0,0 +1,46 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql + +import org.scalatest.FunSuite + +import org.apache.spark.sql.catalyst.expressions.GenericMutableRow + +class RowSuite extends FunSuite { + + test("create row") { + val expected = new GenericMutableRow(4) + expected.update(0, 2147483647) + expected.update(1, "this is a string") + expected.update(2, false) + expected.update(3, null) + val actual1 = Row(2147483647, "this is a string", false, null) + assert(expected.size === actual1.size) + assert(expected.getInt(0) === actual1.getInt(0)) + assert(expected.getString(1) === actual1.getString(1)) + assert(expected.getBoolean(2) === actual1.getBoolean(2)) + assert(expected(3) === actual1(3)) + + val actual2 = Row.fromSeq(Seq(2147483647, "this is a string", false, null)) + assert(expected.size === actual2.size) + assert(expected.getInt(0) === actual2.getInt(0)) + assert(expected.getString(1) === actual2.getString(1)) + assert(expected.getBoolean(2) === actual2.getBoolean(2)) + assert(expected(3) === actual2(3)) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index de9e8aa4f62ed..bebb490645420 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -17,9 +17,7 @@ package org.apache.spark.sql -import org.apache.spark.sql.catalyst.analysis.EliminateAnalysisOperators import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.columnar.{InMemoryColumnarTableScan, InMemoryRelation} import org.apache.spark.sql.test._ /* Implicits */ @@ -446,4 +444,66 @@ class SQLQuerySuite extends QueryTest { ) clear() } + + test("apply schema") { + val schema1 = StructType( + StructField("f1", IntegerType, false) :: + StructField("f2", StringType, false) :: + StructField("f3", BooleanType, false) :: + StructField("f4", IntegerType, true) :: Nil) + + val rowRDD1 = unparsedStrings.map { r => + val values = r.split(",").map(_.trim) + val v4 = try values(3).toInt catch { + case _: NumberFormatException => null + } + Row(values(0).toInt, values(1), values(2).toBoolean, v4) + } + + val schemaRDD1 = applySchema(rowRDD1, schema1) + schemaRDD1.registerAsTable("applySchema1") + checkAnswer( + sql("SELECT * FROM applySchema1"), + (1, "A1", true, null) :: + (2, "B2", false, null) :: + (3, "C3", true, null) :: + (4, "D4", true, 2147483644) :: Nil) + + checkAnswer( + sql("SELECT f1, f4 FROM applySchema1"), + (1, null) :: + (2, null) :: + (3, null) :: + (4, 2147483644) :: Nil) + + val schema2 = StructType( + StructField("f1", StructType( + StructField("f11", IntegerType, false) :: + StructField("f12", BooleanType, false) :: Nil), false) :: + StructField("f2", MapType(StringType, IntegerType, true), false) :: Nil) + + val rowRDD2 = unparsedStrings.map { r => + val values = r.split(",").map(_.trim) + val v4 = try values(3).toInt catch { + case _: NumberFormatException => null + } + Row(Row(values(0).toInt, values(2).toBoolean), Map(values(1) -> v4)) + } + + val schemaRDD2 = applySchema(rowRDD2, schema2) + schemaRDD2.registerAsTable("applySchema2") + checkAnswer( + sql("SELECT * FROM applySchema2"), + (Seq(1, true), Map("A1" -> null)) :: + (Seq(2, false), Map("B2" -> null)) :: + (Seq(3, true), Map("C3" -> null)) :: + (Seq(4, true), Map("D4" -> 2147483644)) :: Nil) + + checkAnswer( + sql("SELECT f1.f11, f2['D4'] FROM applySchema2"), + (1, null) :: + (2, null) :: + (3, null) :: + (4, 2147483644) :: Nil) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala index 330b20b315d63..213190e812026 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala @@ -128,4 +128,11 @@ object TestData { case class TableName(tableName: String) TestSQLContext.sparkContext.parallelize(TableName("test") :: Nil).registerAsTable("tableName") + + val unparsedStrings = + TestSQLContext.sparkContext.parallelize( + "1, A1, true, null" :: + "2, B2, false, null" :: + "3, C3, true, null" :: + "4, D4, true, 2147483644" :: Nil) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/api/java/ScalaSideDataTypeConversionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/api/java/ScalaSideDataTypeConversionSuite.scala new file mode 100644 index 0000000000000..46de6fe239228 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/api/java/ScalaSideDataTypeConversionSuite.scala @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.api.java + +import org.apache.spark.sql.types.util.DataTypeConversions +import org.scalatest.FunSuite + +import org.apache.spark.sql._ +import DataTypeConversions._ + +class ScalaSideDataTypeConversionSuite extends FunSuite { + + def checkDataType(scalaDataType: DataType) { + val javaDataType = asJavaDataType(scalaDataType) + val actual = asScalaDataType(javaDataType) + assert(scalaDataType === actual, s"Converted data type ${actual} " + + s"does not equal the expected data type ${scalaDataType}") + } + + test("convert data types") { + // Simple DataTypes. + checkDataType(StringType) + checkDataType(BinaryType) + checkDataType(BooleanType) + checkDataType(TimestampType) + checkDataType(DecimalType) + checkDataType(DoubleType) + checkDataType(FloatType) + checkDataType(ByteType) + checkDataType(IntegerType) + checkDataType(LongType) + checkDataType(ShortType) + + // Simple ArrayType. + val simpleScalaArrayType = ArrayType(StringType, true) + checkDataType(simpleScalaArrayType) + + // Simple MapType. + val simpleScalaMapType = MapType(StringType, LongType) + checkDataType(simpleScalaMapType) + + // Simple StructType. + val simpleScalaStructType = StructType( + StructField("a", DecimalType, false) :: + StructField("b", BooleanType, true) :: + StructField("c", LongType, true) :: + StructField("d", BinaryType, false) :: Nil) + checkDataType(simpleScalaStructType) + + // Complex StructType. + val complexScalaStructType = StructType( + StructField("simpleArray", simpleScalaArrayType, true) :: + StructField("simpleMap", simpleScalaMapType, true) :: + StructField("simpleStruct", simpleScalaStructType, true) :: + StructField("boolean", BooleanType, false) :: Nil) + checkDataType(complexScalaStructType) + + // Complex ArrayType. + val complexScalaArrayType = ArrayType(complexScalaStructType, true) + checkDataType(complexScalaArrayType) + + // Complex MapType. + val complexScalaMapType = MapType(complexScalaStructType, complexScalaArrayType, false) + checkDataType(complexScalaMapType) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala index e765cfc83a397..9d9cfdd7c92e3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala @@ -17,16 +17,12 @@ package org.apache.spark.sql.json -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} -import org.apache.spark.sql.catalyst.plans.logical.LeafNode import org.apache.spark.sql.catalyst.types._ import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.json.JsonRDD.{enforceCorrectType, compatibleType} import org.apache.spark.sql.QueryTest import org.apache.spark.sql.test.TestSQLContext._ -protected case class Schema(output: Seq[Attribute]) extends LeafNode - class JsonSuite extends QueryTest { import TestJsonData._ TestJsonData @@ -127,6 +123,18 @@ class JsonSuite extends QueryTest { checkDataType(ArrayType(IntegerType), ArrayType(LongType), ArrayType(LongType)) checkDataType(ArrayType(IntegerType), ArrayType(StringType), ArrayType(StringType)) checkDataType(ArrayType(IntegerType), StructType(Nil), StringType) + checkDataType( + ArrayType(IntegerType, true), ArrayType(IntegerType), ArrayType(IntegerType, true)) + checkDataType( + ArrayType(IntegerType, true), ArrayType(IntegerType, false), ArrayType(IntegerType, true)) + checkDataType( + ArrayType(IntegerType, true), ArrayType(IntegerType, true), ArrayType(IntegerType, true)) + checkDataType( + ArrayType(IntegerType, false), ArrayType(IntegerType), ArrayType(IntegerType, false)) + checkDataType( + ArrayType(IntegerType, false), ArrayType(IntegerType, false), ArrayType(IntegerType, false)) + checkDataType( + ArrayType(IntegerType, false), ArrayType(IntegerType, false), ArrayType(IntegerType)) // StructType checkDataType(StructType(Nil), StructType(Nil), StructType(Nil)) @@ -164,16 +172,16 @@ class JsonSuite extends QueryTest { test("Primitive field and type inferring") { val jsonSchemaRDD = jsonRDD(primitiveFieldAndType) - val expectedSchema = - AttributeReference("bigInteger", DecimalType, true)() :: - AttributeReference("boolean", BooleanType, true)() :: - AttributeReference("double", DoubleType, true)() :: - AttributeReference("integer", IntegerType, true)() :: - AttributeReference("long", LongType, true)() :: - AttributeReference("null", StringType, true)() :: - AttributeReference("string", StringType, true)() :: Nil + val expectedSchema = StructType( + StructField("bigInteger", DecimalType, true) :: + StructField("boolean", BooleanType, true) :: + StructField("double", DoubleType, true) :: + StructField("integer", IntegerType, true) :: + StructField("long", LongType, true) :: + StructField("null", StringType, true) :: + StructField("string", StringType, true) :: Nil) - comparePlans(Schema(expectedSchema), Schema(jsonSchemaRDD.logicalPlan.output)) + assert(expectedSchema === jsonSchemaRDD.schema) jsonSchemaRDD.registerAsTable("jsonTable") @@ -192,27 +200,28 @@ class JsonSuite extends QueryTest { test("Complex field and type inferring") { val jsonSchemaRDD = jsonRDD(complexFieldAndType) - val expectedSchema = - AttributeReference("arrayOfArray1", ArrayType(ArrayType(StringType)), true)() :: - AttributeReference("arrayOfArray2", ArrayType(ArrayType(DoubleType)), true)() :: - AttributeReference("arrayOfBigInteger", ArrayType(DecimalType), true)() :: - AttributeReference("arrayOfBoolean", ArrayType(BooleanType), true)() :: - AttributeReference("arrayOfDouble", ArrayType(DoubleType), true)() :: - AttributeReference("arrayOfInteger", ArrayType(IntegerType), true)() :: - AttributeReference("arrayOfLong", ArrayType(LongType), true)() :: - AttributeReference("arrayOfNull", ArrayType(StringType), true)() :: - AttributeReference("arrayOfString", ArrayType(StringType), true)() :: - AttributeReference("arrayOfStruct", ArrayType( - StructType(StructField("field1", BooleanType, true) :: - StructField("field2", StringType, true) :: Nil)), true)() :: - AttributeReference("struct", StructType( - StructField("field1", BooleanType, true) :: - StructField("field2", DecimalType, true) :: Nil), true)() :: - AttributeReference("structWithArrayFields", StructType( + val expectedSchema = StructType( + StructField("arrayOfArray1", ArrayType(ArrayType(StringType)), true) :: + StructField("arrayOfArray2", ArrayType(ArrayType(DoubleType)), true) :: + StructField("arrayOfBigInteger", ArrayType(DecimalType), true) :: + StructField("arrayOfBoolean", ArrayType(BooleanType), true) :: + StructField("arrayOfDouble", ArrayType(DoubleType), true) :: + StructField("arrayOfInteger", ArrayType(IntegerType), true) :: + StructField("arrayOfLong", ArrayType(LongType), true) :: + StructField("arrayOfNull", ArrayType(StringType, true), true) :: + StructField("arrayOfString", ArrayType(StringType), true) :: + StructField("arrayOfStruct", ArrayType( + StructType( + StructField("field1", BooleanType, true) :: + StructField("field2", StringType, true) :: Nil)), true) :: + StructField("struct", StructType( + StructField("field1", BooleanType, true) :: + StructField("field2", DecimalType, true) :: Nil), true) :: + StructField("structWithArrayFields", StructType( StructField("field1", ArrayType(IntegerType), true) :: - StructField("field2", ArrayType(StringType), true) :: Nil), true)() :: Nil + StructField("field2", ArrayType(StringType), true) :: Nil), true) :: Nil) - comparePlans(Schema(expectedSchema), Schema(jsonSchemaRDD.logicalPlan.output)) + assert(expectedSchema === jsonSchemaRDD.schema) jsonSchemaRDD.registerAsTable("jsonTable") @@ -301,15 +310,15 @@ class JsonSuite extends QueryTest { test("Type conflict in primitive field values") { val jsonSchemaRDD = jsonRDD(primitiveFieldValueTypeConflict) - val expectedSchema = - AttributeReference("num_bool", StringType, true)() :: - AttributeReference("num_num_1", LongType, true)() :: - AttributeReference("num_num_2", DecimalType, true)() :: - AttributeReference("num_num_3", DoubleType, true)() :: - AttributeReference("num_str", StringType, true)() :: - AttributeReference("str_bool", StringType, true)() :: Nil + val expectedSchema = StructType( + StructField("num_bool", StringType, true) :: + StructField("num_num_1", LongType, true) :: + StructField("num_num_2", DecimalType, true) :: + StructField("num_num_3", DoubleType, true) :: + StructField("num_str", StringType, true) :: + StructField("str_bool", StringType, true) :: Nil) - comparePlans(Schema(expectedSchema), Schema(jsonSchemaRDD.logicalPlan.output)) + assert(expectedSchema === jsonSchemaRDD.schema) jsonSchemaRDD.registerAsTable("jsonTable") @@ -426,15 +435,15 @@ class JsonSuite extends QueryTest { test("Type conflict in complex field values") { val jsonSchemaRDD = jsonRDD(complexFieldValueTypeConflict) - val expectedSchema = - AttributeReference("array", ArrayType(IntegerType), true)() :: - AttributeReference("num_struct", StringType, true)() :: - AttributeReference("str_array", StringType, true)() :: - AttributeReference("struct", StructType( - StructField("field", StringType, true) :: Nil), true)() :: - AttributeReference("struct_array", StringType, true)() :: Nil + val expectedSchema = StructType( + StructField("array", ArrayType(IntegerType), true) :: + StructField("num_struct", StringType, true) :: + StructField("str_array", StringType, true) :: + StructField("struct", StructType( + StructField("field", StringType, true) :: Nil), true) :: + StructField("struct_array", StringType, true) :: Nil) - comparePlans(Schema(expectedSchema), Schema(jsonSchemaRDD.logicalPlan.output)) + assert(expectedSchema === jsonSchemaRDD.schema) jsonSchemaRDD.registerAsTable("jsonTable") @@ -450,12 +459,12 @@ class JsonSuite extends QueryTest { test("Type conflict in array elements") { val jsonSchemaRDD = jsonRDD(arrayElementTypeConflict) - val expectedSchema = - AttributeReference("array1", ArrayType(StringType), true)() :: - AttributeReference("array2", ArrayType(StructType( - StructField("field", LongType, true) :: Nil)), true)() :: Nil + val expectedSchema = StructType( + StructField("array1", ArrayType(StringType, true), true) :: + StructField("array2", ArrayType(StructType( + StructField("field", LongType, true) :: Nil)), true) :: Nil) - comparePlans(Schema(expectedSchema), Schema(jsonSchemaRDD.logicalPlan.output)) + assert(expectedSchema === jsonSchemaRDD.schema) jsonSchemaRDD.registerAsTable("jsonTable") @@ -475,15 +484,15 @@ class JsonSuite extends QueryTest { test("Handling missing fields") { val jsonSchemaRDD = jsonRDD(missingFields) - val expectedSchema = - AttributeReference("a", BooleanType, true)() :: - AttributeReference("b", LongType, true)() :: - AttributeReference("c", ArrayType(IntegerType), true)() :: - AttributeReference("d", StructType( - StructField("field", BooleanType, true) :: Nil), true)() :: - AttributeReference("e", StringType, true)() :: Nil + val expectedSchema = StructType( + StructField("a", BooleanType, true) :: + StructField("b", LongType, true) :: + StructField("c", ArrayType(IntegerType), true) :: + StructField("d", StructType( + StructField("field", BooleanType, true) :: Nil), true) :: + StructField("e", StringType, true) :: Nil) - comparePlans(Schema(expectedSchema), Schema(jsonSchemaRDD.logicalPlan.output)) + assert(expectedSchema === jsonSchemaRDD.schema) jsonSchemaRDD.registerAsTable("jsonTable") } @@ -494,16 +503,16 @@ class JsonSuite extends QueryTest { primitiveFieldAndType.map(record => record.replaceAll("\n", " ")).saveAsTextFile(path) val jsonSchemaRDD = jsonFile(path) - val expectedSchema = - AttributeReference("bigInteger", DecimalType, true)() :: - AttributeReference("boolean", BooleanType, true)() :: - AttributeReference("double", DoubleType, true)() :: - AttributeReference("integer", IntegerType, true)() :: - AttributeReference("long", LongType, true)() :: - AttributeReference("null", StringType, true)() :: - AttributeReference("string", StringType, true)() :: Nil + val expectedSchema = StructType( + StructField("bigInteger", DecimalType, true) :: + StructField("boolean", BooleanType, true) :: + StructField("double", DoubleType, true) :: + StructField("integer", IntegerType, true) :: + StructField("long", LongType, true) :: + StructField("null", StringType, true) :: + StructField("string", StringType, true) :: Nil) - comparePlans(Schema(expectedSchema), Schema(jsonSchemaRDD.logicalPlan.output)) + assert(expectedSchema === jsonSchemaRDD.schema) jsonSchemaRDD.registerAsTable("jsonTable") @@ -518,4 +527,53 @@ class JsonSuite extends QueryTest { "this is a simple string.") :: Nil ) } + + test("Applying schemas") { + val file = getTempFilePath("json") + val path = file.toString + primitiveFieldAndType.map(record => record.replaceAll("\n", " ")).saveAsTextFile(path) + + val schema = StructType( + StructField("bigInteger", DecimalType, true) :: + StructField("boolean", BooleanType, true) :: + StructField("double", DoubleType, true) :: + StructField("integer", IntegerType, true) :: + StructField("long", LongType, true) :: + StructField("null", StringType, true) :: + StructField("string", StringType, true) :: Nil) + + val jsonSchemaRDD1 = jsonFile(path, schema) + + assert(schema === jsonSchemaRDD1.schema) + + jsonSchemaRDD1.registerAsTable("jsonTable1") + + checkAnswer( + sql("select * from jsonTable1"), + (BigDecimal("92233720368547758070"), + true, + 1.7976931348623157E308, + 10, + 21474836470L, + null, + "this is a simple string.") :: Nil + ) + + val jsonSchemaRDD2 = jsonRDD(primitiveFieldAndType, schema) + + assert(schema === jsonSchemaRDD2.schema) + + jsonSchemaRDD2.registerAsTable("jsonTable2") + + checkAnswer( + sql("select * from jsonTable2"), + (BigDecimal("92233720368547758070"), + true, + 1.7976931348623157E308, + 10, + 21474836470L, + null, + "this is a simple string.") :: Nil + ) + } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index f0a61270daf05..b413373345eea 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -37,7 +37,6 @@ import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.analysis.{Analyzer, OverrideCatalog} import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.catalyst.types._ import org.apache.spark.sql.execution.QueryExecutionException import org.apache.spark.sql.execution.{Command => PhysicalCommand} import org.apache.spark.sql.hive.execution.DescribeHiveTableCommand @@ -260,9 +259,9 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { struct.zip(fields).map { case (v, t) => s""""${t.name}":${toHiveStructString(v, t.dataType)}""" }.mkString("{", ",", "}") - case (seq: Seq[_], ArrayType(typ)) => + case (seq: Seq[_], ArrayType(typ, _)) => seq.map(v => (v, typ)).map(toHiveStructString).mkString("[", ",", "]") - case (map: Map[_,_], MapType(kType, vType)) => + case (map: Map[_,_], MapType(kType, vType, _)) => map.map { case (key, value) => toHiveStructString((key, kType)) + ":" + toHiveStructString((value, vType)) @@ -279,9 +278,9 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { struct.zip(fields).map { case (v, t) => s""""${t.name}":${toHiveStructString(v, t.dataType)}""" }.mkString("{", ",", "}") - case (seq: Seq[_], ArrayType(typ)) => + case (seq: Seq[_], ArrayType(typ, _)) => seq.map(v => (v, typ)).map(toHiveStructString).mkString("[", ",", "]") - case (map: Map[_,_], MapType(kType, vType)) => + case (map: Map[_,_], MapType(kType, vType, _)) => map.map { case (key, value) => toHiveStructString((key, kType)) + ":" + toHiveStructString((value, vType)) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala index ad7dc0ecdb1bf..354fcd53f303b 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala @@ -152,8 +152,9 @@ private[hive] trait HiveInspectors { } def toInspector(dataType: DataType): ObjectInspector = dataType match { - case ArrayType(tpe) => ObjectInspectorFactory.getStandardListObjectInspector(toInspector(tpe)) - case MapType(keyType, valueType) => + case ArrayType(tpe, _) => + ObjectInspectorFactory.getStandardListObjectInspector(toInspector(tpe)) + case MapType(keyType, valueType, _) => ObjectInspectorFactory.getStandardMapObjectInspector( toInspector(keyType), toInspector(valueType)) case StringType => PrimitiveObjectInspectorFactory.javaStringObjectInspector diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index dff1d6a4b93bb..fa4e78439c26c 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -200,7 +200,9 @@ object HiveMetastoreTypes extends RegexParsers { "varchar\\((\\d+)\\)".r ^^^ StringType protected lazy val arrayType: Parser[DataType] = - "array" ~> "<" ~> dataType <~ ">" ^^ ArrayType + "array" ~> "<" ~> dataType <~ ">" ^^ { + case tpe => ArrayType(tpe) + } protected lazy val mapType: Parser[DataType] = "map" ~> "<" ~> dataType ~ "," ~ dataType <~ ">" ^^ { @@ -229,10 +231,10 @@ object HiveMetastoreTypes extends RegexParsers { } def toMetastoreType(dt: DataType): String = dt match { - case ArrayType(elementType) => s"array<${toMetastoreType(elementType)}>" + case ArrayType(elementType, _) => s"array<${toMetastoreType(elementType)}>" case StructType(fields) => s"struct<${fields.map(f => s"${f.name}:${toMetastoreType(f.dataType)}").mkString(",")}>" - case MapType(keyType, valueType) => + case MapType(keyType, valueType, _) => s"map<${toMetastoreType(keyType)},${toMetastoreType(valueType)}>" case StringType => "string" case FloatType => "float" From 7c5fc28af42daaa6725af083d78c2372f3d0a338 Mon Sep 17 00:00:00 2001 From: Koert Kuipers Date: Wed, 30 Jul 2014 00:18:59 -0700 Subject: [PATCH 026/170] SPARK-2543: Allow user to set maximum Kryo buffer size Author: Koert Kuipers Closes #735 from koertkuipers/feat-kryo-max-buffersize and squashes the following commits: 15f6d81 [Koert Kuipers] change default for spark.kryoserializer.buffer.max.mb to 64mb and add some documentation 1bcc22c [Koert Kuipers] Merge branch 'master' into feat-kryo-max-buffersize 0c9f8eb [Koert Kuipers] make default for kryo max buffer size 16MB 143ec4d [Koert Kuipers] test resizable buffer in kryo Output 0732445 [Koert Kuipers] support setting maxCapacity to something different than capacity in kryo Output --- .../spark/serializer/KryoSerializer.scala | 3 +- .../serializer/KryoSerializerSuite.scala | 30 +++++++++++++++++++ docs/configuration.md | 16 +++++++--- 3 files changed, 44 insertions(+), 5 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala index fa79b25759153..e60b802a86a14 100644 --- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala @@ -48,11 +48,12 @@ class KryoSerializer(conf: SparkConf) with Serializable { private val bufferSize = conf.getInt("spark.kryoserializer.buffer.mb", 2) * 1024 * 1024 + private val maxBufferSize = conf.getInt("spark.kryoserializer.buffer.max.mb", 64) * 1024 * 1024 private val referenceTracking = conf.getBoolean("spark.kryo.referenceTracking", true) private val registrationRequired = conf.getBoolean("spark.kryo.registrationRequired", false) private val registrator = conf.getOption("spark.kryo.registrator") - def newKryoOutput() = new KryoOutput(bufferSize) + def newKryoOutput() = new KryoOutput(bufferSize, math.max(bufferSize, maxBufferSize)) def newKryo(): Kryo = { val instantiator = new EmptyScalaKryoInstantiator diff --git a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala index 79280d1a06653..789b773bae316 100644 --- a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala +++ b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala @@ -209,6 +209,36 @@ class KryoSerializerSuite extends FunSuite with SharedSparkContext { } } +class KryoSerializerResizableOutputSuite extends FunSuite { + import org.apache.spark.SparkConf + import org.apache.spark.SparkContext + import org.apache.spark.LocalSparkContext + import org.apache.spark.SparkException + + // trial and error showed this will not serialize with 1mb buffer + val x = (1 to 400000).toArray + + test("kryo without resizable output buffer should fail on large array") { + val conf = new SparkConf(false) + conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") + conf.set("spark.kryoserializer.buffer.mb", "1") + conf.set("spark.kryoserializer.buffer.max.mb", "1") + val sc = new SparkContext("local", "test", conf) + intercept[SparkException](sc.parallelize(x).collect) + LocalSparkContext.stop(sc) + } + + test("kryo with resizable output buffer should succeed on large array") { + val conf = new SparkConf(false) + conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") + conf.set("spark.kryoserializer.buffer.mb", "1") + conf.set("spark.kryoserializer.buffer.max.mb", "2") + val sc = new SparkContext("local", "test", conf) + assert(sc.parallelize(x).collect === x) + LocalSparkContext.stop(sc) + } +} + object KryoTest { case class CaseClass(i: Int, s: String) {} diff --git a/docs/configuration.md b/docs/configuration.md index 2e6c85cc2bcca..ea69057b5be10 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -414,10 +414,18 @@ Apart from these, the following properties are also available, and may be useful spark.kryoserializer.buffer.mb 2 - Maximum object size to allow within Kryo (the library needs to create a buffer at least as - large as the largest single object you'll serialize). Increase this if you get a "buffer limit - exceeded" exception inside Kryo. Note that there will be one buffer per core on each - worker. + Initial size of Kryo's serialization buffer, in megabytes. Note that there will be one buffer + per core on each worker. This buffer will grow up to + spark.kryoserializer.buffer.max.mb if needed. + + + + spark.kryoserializer.buffer.max.mb + 64 + + Maximum allowable size of Kryo serialization buffer, in megabytes. This must be larger than any + object you attempt to serialize. Increase this if you get a "buffer limit exceeded" exception + inside Kryo. From ee07541e99f0d262bf662b669b6542cf302ff39c Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Wed, 30 Jul 2014 08:55:15 -0700 Subject: [PATCH 027/170] SPARK-2748 [MLLIB] [GRAPHX] Loss of precision for small arguments to Math.exp, Math.log In a few places in MLlib, an expression of the form `log(1.0 + p)` is evaluated. When p is so small that `1.0 + p == 1.0`, the result is 0.0. However the correct answer is very near `p`. This is why `Math.log1p` exists. Similarly for one instance of `exp(m) - 1` in GraphX; there's a special `Math.expm1` method. While the errors occur only for very small arguments, given their use in machine learning algorithms, this is entirely possible. Also note the related PR for Python: https://github.com/apache/spark/pull/1652 Author: Sean Owen Closes #1659 from srowen/SPARK-2748 and squashes the following commits: c5926d4 [Sean Owen] Use log1p, expm1 for better precision for tiny arguments --- .../org/apache/spark/graphx/util/GraphGenerators.scala | 6 ++++-- .../org/apache/spark/mllib/optimization/Gradient.scala | 8 ++++---- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/graphx/src/main/scala/org/apache/spark/graphx/util/GraphGenerators.scala b/graphx/src/main/scala/org/apache/spark/graphx/util/GraphGenerators.scala index 635514f09ece0..60149548ab852 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/util/GraphGenerators.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/util/GraphGenerators.scala @@ -100,8 +100,10 @@ object GraphGenerators { */ private def sampleLogNormal(mu: Double, sigma: Double, maxVal: Int): Int = { val rand = new Random() - val m = math.exp(mu + (sigma * sigma) / 2.0) - val s = math.sqrt((math.exp(sigma*sigma) - 1) * math.exp(2*mu + sigma*sigma)) + val sigmaSq = sigma * sigma + val m = math.exp(mu + sigmaSq / 2.0) + // expm1 is exp(m)-1 with better accuracy for tiny m + val s = math.sqrt(math.expm1(sigmaSq) * math.exp(2*mu + sigmaSq)) // Z ~ N(0, 1) var X: Double = maxVal diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/Gradient.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/Gradient.scala index 679842f831c2a..9d82f011e674a 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/Gradient.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/Gradient.scala @@ -68,9 +68,9 @@ class LogisticGradient extends Gradient { val gradient = brzData * gradientMultiplier val loss = if (label > 0) { - math.log(1 + math.exp(margin)) + math.log1p(math.exp(margin)) // log1p is log(1+p) but more accurate for small p } else { - math.log(1 + math.exp(margin)) - margin + math.log1p(math.exp(margin)) - margin } (Vectors.fromBreeze(gradient), loss) @@ -89,9 +89,9 @@ class LogisticGradient extends Gradient { brzAxpy(gradientMultiplier, brzData, cumGradient.toBreeze) if (label > 0) { - math.log(1 + math.exp(margin)) + math.log1p(math.exp(margin)) } else { - math.log(1 + math.exp(margin)) - margin + math.log1p(math.exp(margin)) - margin } } } From 774142f5556ac37fddf03cfa46eb23ca1bde2492 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 30 Jul 2014 09:27:43 -0700 Subject: [PATCH 028/170] [SPARK-2521] Broadcast RDD object (instead of sending it along with every task) This is a resubmission of #1452. It was reverted because it broke the build. Currently (as of Spark 1.0.1), Spark sends RDD object (which contains closures) using Akka along with the task itself to the executors. This is inefficient because all tasks in the same stage use the same RDD object, but we have to send RDD object multiple times to the executors. This is especially bad when a closure references some variable that is very large. The current design led to users having to explicitly broadcast large variables. The patch uses broadcast to send RDD objects and the closures to executors, and use Akka to only send a reference to the broadcast RDD/closure along with the partition specific information for the task. For those of you who know more about the internals, Spark already relies on broadcast to send the Hadoop JobConf every time it uses the Hadoop input, because the JobConf is large. The user-facing impact of the change include: 1. Users won't need to decide what to broadcast anymore, unless they would want to use a large object multiple times in different operations 2. Task size will get smaller, resulting in faster scheduling and higher task dispatch throughput. In addition, the change will simplify some internals of Spark, eliminating the need to maintain task caches and the complex logic to broadcast JobConf (which also led to a deadlock recently). A simple way to test this: ```scala val a = new Array[Byte](1000*1000); scala.util.Random.nextBytes(a); sc.parallelize(1 to 1000, 1000).map { x => a; x }.groupBy { x => a; x }.count ``` Numbers on 3 r3.8xlarge instances on EC2 ``` master branch: 5.648436068 s, 4.715361895 s, 5.360161877 s with this change: 3.416348793 s, 1.477846558 s, 1.553432156 s ``` Author: Reynold Xin Closes #1498 from rxin/broadcast-task and squashes the following commits: f7364db [Reynold Xin] Code review feedback. f8535dc [Reynold Xin] Fixed the style violation. 252238d [Reynold Xin] Serialize the final task closure as well as ShuffleDependency in taskBinary. 111007d [Reynold Xin] Fix broadcast tests. 797c247 [Reynold Xin] Properly send SparkListenerStageSubmitted and SparkListenerStageCompleted. bab1d8b [Reynold Xin] Check for NotSerializableException in submitMissingTasks. cf38450 [Reynold Xin] Use TorrentBroadcastFactory. 991c002 [Reynold Xin] Use HttpBroadcast. de779f8 [Reynold Xin] Fix TaskContextSuite. cc152fc [Reynold Xin] Don't cache the RDD broadcast variable. d256b45 [Reynold Xin] Fixed unit test failures. One more to go. cae0af3 [Reynold Xin] [SPARK-2521] Broadcast RDD object (instead of sending it along with every task). --- .../scala/org/apache/spark/Dependency.scala | 28 ++-- .../scala/org/apache/spark/SparkContext.scala | 2 - .../main/scala/org/apache/spark/rdd/RDD.scala | 11 +- .../apache/spark/rdd/RDDCheckpointData.scala | 9 +- .../apache/spark/scheduler/DAGScheduler.scala | 87 ++++++++---- .../apache/spark/scheduler/ResultTask.scala | 118 +++------------- .../spark/scheduler/ShuffleMapTask.scala | 129 ++++-------------- .../scala/org/apache/spark/util/Utils.scala | 2 +- .../apache/spark/ContextCleanerSuite.scala | 71 ++++++---- .../scala/org/apache/spark/rdd/RDDSuite.scala | 8 +- .../spark/scheduler/TaskContextSuite.scala | 24 ++-- .../ui/jobs/JobProgressListenerSuite.scala | 11 +- 12 files changed, 198 insertions(+), 302 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/Dependency.scala b/core/src/main/scala/org/apache/spark/Dependency.scala index 09a60571238ea..3935c8772252e 100644 --- a/core/src/main/scala/org/apache/spark/Dependency.scala +++ b/core/src/main/scala/org/apache/spark/Dependency.scala @@ -27,7 +27,9 @@ import org.apache.spark.shuffle.ShuffleHandle * Base class for dependencies. */ @DeveloperApi -abstract class Dependency[T](val rdd: RDD[T]) extends Serializable +abstract class Dependency[T] extends Serializable { + def rdd: RDD[T] +} /** @@ -36,20 +38,24 @@ abstract class Dependency[T](val rdd: RDD[T]) extends Serializable * partition of the child RDD. Narrow dependencies allow for pipelined execution. */ @DeveloperApi -abstract class NarrowDependency[T](rdd: RDD[T]) extends Dependency(rdd) { +abstract class NarrowDependency[T](_rdd: RDD[T]) extends Dependency[T] { /** * Get the parent partitions for a child partition. * @param partitionId a partition of the child RDD * @return the partitions of the parent RDD that the child partition depends upon */ def getParents(partitionId: Int): Seq[Int] + + override def rdd: RDD[T] = _rdd } /** * :: DeveloperApi :: - * Represents a dependency on the output of a shuffle stage. - * @param rdd the parent RDD + * Represents a dependency on the output of a shuffle stage. Note that in the case of shuffle, + * the RDD is transient since we don't need it on the executor side. + * + * @param _rdd the parent RDD * @param partitioner partitioner used to partition the shuffle output * @param serializer [[org.apache.spark.serializer.Serializer Serializer]] to use. If set to None, * the default serializer, as specified by `spark.serializer` config option, will @@ -57,20 +63,22 @@ abstract class NarrowDependency[T](rdd: RDD[T]) extends Dependency(rdd) { */ @DeveloperApi class ShuffleDependency[K, V, C]( - @transient rdd: RDD[_ <: Product2[K, V]], + @transient _rdd: RDD[_ <: Product2[K, V]], val partitioner: Partitioner, val serializer: Option[Serializer] = None, val keyOrdering: Option[Ordering[K]] = None, val aggregator: Option[Aggregator[K, V, C]] = None, val mapSideCombine: Boolean = false) - extends Dependency(rdd.asInstanceOf[RDD[Product2[K, V]]]) { + extends Dependency[Product2[K, V]] { + + override def rdd = _rdd.asInstanceOf[RDD[Product2[K, V]]] - val shuffleId: Int = rdd.context.newShuffleId() + val shuffleId: Int = _rdd.context.newShuffleId() - val shuffleHandle: ShuffleHandle = rdd.context.env.shuffleManager.registerShuffle( - shuffleId, rdd.partitions.size, this) + val shuffleHandle: ShuffleHandle = _rdd.context.env.shuffleManager.registerShuffle( + shuffleId, _rdd.partitions.size, this) - rdd.sparkContext.cleaner.foreach(_.registerShuffleForCleanup(this)) + _rdd.sparkContext.cleaner.foreach(_.registerShuffleForCleanup(this)) } diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 3e6addeaf04a8..fb4c86716bb8d 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -997,8 +997,6 @@ class SparkContext(config: SparkConf) extends Logging { // TODO: Cache.stop()? env.stop() SparkEnv.set(null) - ShuffleMapTask.clearCache() - ResultTask.clearCache() listenerBus.stop() eventLogger.foreach(_.stop()) logInfo("Successfully stopped SparkContext") diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index a6abc49c5359e..726b3f2bbeea7 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -35,12 +35,13 @@ import org.apache.spark.Partitioner._ import org.apache.spark.SparkContext._ import org.apache.spark.annotation.{DeveloperApi, Experimental} import org.apache.spark.api.java.JavaRDD +import org.apache.spark.broadcast.Broadcast import org.apache.spark.partial.BoundedDouble import org.apache.spark.partial.CountEvaluator import org.apache.spark.partial.GroupedCountEvaluator import org.apache.spark.partial.PartialResult import org.apache.spark.storage.StorageLevel -import org.apache.spark.util.{BoundedPriorityQueue, CallSite, Utils} +import org.apache.spark.util.{BoundedPriorityQueue, Utils} import org.apache.spark.util.collection.OpenHashMap import org.apache.spark.util.random.{BernoulliSampler, PoissonSampler, SamplingUtils} @@ -1206,16 +1207,12 @@ abstract class RDD[T: ClassTag]( /** * Return whether this RDD has been checkpointed or not */ - def isCheckpointed: Boolean = { - checkpointData.map(_.isCheckpointed).getOrElse(false) - } + def isCheckpointed: Boolean = checkpointData.exists(_.isCheckpointed) /** * Gets the name of the file to which this RDD was checkpointed */ - def getCheckpointFile: Option[String] = { - checkpointData.flatMap(_.getCheckpointFile) - } + def getCheckpointFile: Option[String] = checkpointData.flatMap(_.getCheckpointFile) // ======================================================================= // Other internal methods and fields diff --git a/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala b/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala index c3b2a33fb54d0..f67e5f1857979 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala @@ -106,7 +106,6 @@ private[spark] class RDDCheckpointData[T: ClassTag](@transient rdd: RDD[T]) cpRDD = Some(newRDD) rdd.markCheckpointed(newRDD) // Update the RDD's dependencies and partitions cpState = Checkpointed - RDDCheckpointData.clearTaskCaches() } logInfo("Done checkpointing RDD " + rdd.id + " to " + path + ", new parent is RDD " + newRDD.id) } @@ -131,9 +130,5 @@ private[spark] class RDDCheckpointData[T: ClassTag](@transient rdd: RDD[T]) } } -private[spark] object RDDCheckpointData { - def clearTaskCaches() { - ShuffleMapTask.clearCache() - ResultTask.clearCache() - } -} +// Used for synchronization +private[spark] object RDDCheckpointData diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index dc6142ab79d03..50186d097a632 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -17,7 +17,7 @@ package org.apache.spark.scheduler -import java.io.{NotSerializableException, PrintWriter, StringWriter} +import java.io.NotSerializableException import java.util.Properties import java.util.concurrent.atomic.AtomicInteger @@ -35,6 +35,7 @@ import akka.pattern.ask import akka.util.Timeout import org.apache.spark._ +import org.apache.spark.broadcast.Broadcast import org.apache.spark.executor.TaskMetrics import org.apache.spark.partial.{ApproximateActionListener, ApproximateEvaluator, PartialResult} import org.apache.spark.rdd.RDD @@ -114,6 +115,10 @@ class DAGScheduler( private val dagSchedulerActorSupervisor = env.actorSystem.actorOf(Props(new DAGSchedulerActorSupervisor(this))) + // A closure serializer that we reuse. + // This is only safe because DAGScheduler runs in a single thread. + private val closureSerializer = SparkEnv.get.closureSerializer.newInstance() + private[scheduler] var eventProcessActor: ActorRef = _ private def initializeEventProcessActor() { @@ -361,9 +366,6 @@ class DAGScheduler( // data structures based on StageId stageIdToStage -= stageId - ShuffleMapTask.removeStage(stageId) - ResultTask.removeStage(stageId) - logDebug("After removal of stage %d, remaining stages = %d" .format(stageId, stageIdToStage.size)) } @@ -691,49 +693,83 @@ class DAGScheduler( } } - /** Called when stage's parents are available and we can now do its task. */ private def submitMissingTasks(stage: Stage, jobId: Int) { logDebug("submitMissingTasks(" + stage + ")") // Get our pending tasks and remember them in our pendingTasks entry stage.pendingTasks.clear() var tasks = ArrayBuffer[Task[_]]() + + val properties = if (jobIdToActiveJob.contains(jobId)) { + jobIdToActiveJob(stage.jobId).properties + } else { + // this stage will be assigned to "default" pool + null + } + + runningStages += stage + // SparkListenerStageSubmitted should be posted before testing whether tasks are + // serializable. If tasks are not serializable, a SparkListenerStageCompleted event + // will be posted, which should always come after a corresponding SparkListenerStageSubmitted + // event. + listenerBus.post(SparkListenerStageSubmitted(stage.info, properties)) + + // TODO: Maybe we can keep the taskBinary in Stage to avoid serializing it multiple times. + // Broadcasted binary for the task, used to dispatch tasks to executors. Note that we broadcast + // the serialized copy of the RDD and for each task we will deserialize it, which means each + // task gets a different copy of the RDD. This provides stronger isolation between tasks that + // might modify state of objects referenced in their closures. This is necessary in Hadoop + // where the JobConf/Configuration object is not thread-safe. + var taskBinary: Broadcast[Array[Byte]] = null + try { + // For ShuffleMapTask, serialize and broadcast (rdd, shuffleDep). + // For ResultTask, serialize and broadcast (rdd, func). + val taskBinaryBytes: Array[Byte] = + if (stage.isShuffleMap) { + closureSerializer.serialize((stage.rdd, stage.shuffleDep.get) : AnyRef).array() + } else { + closureSerializer.serialize((stage.rdd, stage.resultOfJob.get.func) : AnyRef).array() + } + taskBinary = sc.broadcast(taskBinaryBytes) + } catch { + // In the case of a failure during serialization, abort the stage. + case e: NotSerializableException => + abortStage(stage, "Task not serializable: " + e.toString) + runningStages -= stage + return + case NonFatal(e) => + abortStage(stage, s"Task serialization failed: $e\n${e.getStackTraceString}") + runningStages -= stage + return + } + if (stage.isShuffleMap) { for (p <- 0 until stage.numPartitions if stage.outputLocs(p) == Nil) { val locs = getPreferredLocs(stage.rdd, p) - tasks += new ShuffleMapTask(stage.id, stage.rdd, stage.shuffleDep.get, p, locs) + val part = stage.rdd.partitions(p) + tasks += new ShuffleMapTask(stage.id, taskBinary, part, locs) } } else { // This is a final stage; figure out its job's missing partitions val job = stage.resultOfJob.get for (id <- 0 until job.numPartitions if !job.finished(id)) { - val partition = job.partitions(id) - val locs = getPreferredLocs(stage.rdd, partition) - tasks += new ResultTask(stage.id, stage.rdd, job.func, partition, locs, id) + val p: Int = job.partitions(id) + val part = stage.rdd.partitions(p) + val locs = getPreferredLocs(stage.rdd, p) + tasks += new ResultTask(stage.id, taskBinary, part, locs, id) } } - val properties = if (jobIdToActiveJob.contains(jobId)) { - jobIdToActiveJob(stage.jobId).properties - } else { - // this stage will be assigned to "default" pool - null - } - if (tasks.size > 0) { - runningStages += stage - // SparkListenerStageSubmitted should be posted before testing whether tasks are - // serializable. If tasks are not serializable, a SparkListenerStageCompleted event - // will be posted, which should always come after a corresponding SparkListenerStageSubmitted - // event. - listenerBus.post(SparkListenerStageSubmitted(stage.info, properties)) - // Preemptively serialize a task to make sure it can be serialized. We are catching this // exception here because it would be fairly hard to catch the non-serializable exception // down the road, where we have several different implementations for local scheduler and // cluster schedulers. + // + // We've already serialized RDDs and closures in taskBinary, but here we check for all other + // objects such as Partition. try { - SparkEnv.get.closureSerializer.newInstance().serialize(tasks.head) + closureSerializer.serialize(tasks.head) } catch { case e: NotSerializableException => abortStage(stage, "Task not serializable: " + e.toString) @@ -752,6 +788,9 @@ class DAGScheduler( new TaskSet(tasks.toArray, stage.id, stage.newAttemptId(), stage.jobId, properties)) stage.info.submissionTime = Some(clock.getTime()) } else { + // Because we posted SparkListenerStageSubmitted earlier, we should post + // SparkListenerStageCompleted here in case there are no tasks to run. + listenerBus.post(SparkListenerStageCompleted(stage.info)) logDebug("Stage " + stage + " is actually done; %b %d %d".format( stage.isAvailable, stage.numAvailableOutputs, stage.numPartitions)) runningStages -= stage diff --git a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala index bbf9f7388b074..d09fd7aa57642 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala @@ -17,134 +17,56 @@ package org.apache.spark.scheduler -import scala.language.existentials +import java.nio.ByteBuffer import java.io._ -import java.util.zip.{GZIPInputStream, GZIPOutputStream} - -import scala.collection.mutable.HashMap import org.apache.spark._ -import org.apache.spark.rdd.{RDD, RDDCheckpointData} - -private[spark] object ResultTask { - - // A simple map between the stage id to the serialized byte array of a task. - // Served as a cache for task serialization because serialization can be - // expensive on the master node if it needs to launch thousands of tasks. - private val serializedInfoCache = new HashMap[Int, Array[Byte]] - - def serializeInfo(stageId: Int, rdd: RDD[_], func: (TaskContext, Iterator[_]) => _): Array[Byte] = - { - synchronized { - val old = serializedInfoCache.get(stageId).orNull - if (old != null) { - old - } else { - val out = new ByteArrayOutputStream - val ser = SparkEnv.get.closureSerializer.newInstance() - val objOut = ser.serializeStream(new GZIPOutputStream(out)) - objOut.writeObject(rdd) - objOut.writeObject(func) - objOut.close() - val bytes = out.toByteArray - serializedInfoCache.put(stageId, bytes) - bytes - } - } - } - - def deserializeInfo(stageId: Int, bytes: Array[Byte]): (RDD[_], (TaskContext, Iterator[_]) => _) = - { - val in = new GZIPInputStream(new ByteArrayInputStream(bytes)) - val ser = SparkEnv.get.closureSerializer.newInstance() - val objIn = ser.deserializeStream(in) - val rdd = objIn.readObject().asInstanceOf[RDD[_]] - val func = objIn.readObject().asInstanceOf[(TaskContext, Iterator[_]) => _] - (rdd, func) - } - - def removeStage(stageId: Int) { - serializedInfoCache.remove(stageId) - } - - def clearCache() { - synchronized { - serializedInfoCache.clear() - } - } -} - +import org.apache.spark.broadcast.Broadcast +import org.apache.spark.rdd.RDD /** * A task that sends back the output to the driver application. * - * See [[org.apache.spark.scheduler.Task]] for more information. + * See [[Task]] for more information. * * @param stageId id of the stage this task belongs to - * @param rdd input to func - * @param func a function to apply on a partition of the RDD - * @param _partitionId index of the number in the RDD + * @param taskBinary broadcasted version of the serialized RDD and the function to apply on each + * partition of the given RDD. Once deserialized, the type should be + * (RDD[T], (TaskContext, Iterator[T]) => U). + * @param partition partition of the RDD this task is associated with * @param locs preferred task execution locations for locality scheduling * @param outputId index of the task in this job (a job can launch tasks on only a subset of the * input RDD's partitions). */ private[spark] class ResultTask[T, U]( stageId: Int, - var rdd: RDD[T], - var func: (TaskContext, Iterator[T]) => U, - _partitionId: Int, + taskBinary: Broadcast[Array[Byte]], + partition: Partition, @transient locs: Seq[TaskLocation], - var outputId: Int) - extends Task[U](stageId, _partitionId) with Externalizable { + val outputId: Int) + extends Task[U](stageId, partition.index) with Serializable { - def this() = this(0, null, null, 0, null, 0) - - var split = if (rdd == null) null else rdd.partitions(partitionId) - - @transient private val preferredLocs: Seq[TaskLocation] = { + @transient private[this] val preferredLocs: Seq[TaskLocation] = { if (locs == null) Nil else locs.toSet.toSeq } override def runTask(context: TaskContext): U = { + // Deserialize the RDD and the func using the broadcast variables. + val ser = SparkEnv.get.closureSerializer.newInstance() + val (rdd, func) = ser.deserialize[(RDD[T], (TaskContext, Iterator[T]) => U)]( + ByteBuffer.wrap(taskBinary.value), Thread.currentThread.getContextClassLoader) + metrics = Some(context.taskMetrics) try { - func(context, rdd.iterator(split, context)) + func(context, rdd.iterator(partition, context)) } finally { context.executeOnCompleteCallbacks() } } + // This is only callable on the driver side. override def preferredLocations: Seq[TaskLocation] = preferredLocs override def toString = "ResultTask(" + stageId + ", " + partitionId + ")" - - override def writeExternal(out: ObjectOutput) { - RDDCheckpointData.synchronized { - split = rdd.partitions(partitionId) - out.writeInt(stageId) - val bytes = ResultTask.serializeInfo( - stageId, rdd, func.asInstanceOf[(TaskContext, Iterator[_]) => _]) - out.writeInt(bytes.length) - out.write(bytes) - out.writeInt(partitionId) - out.writeInt(outputId) - out.writeLong(epoch) - out.writeObject(split) - } - } - - override def readExternal(in: ObjectInput) { - val stageId = in.readInt() - val numBytes = in.readInt() - val bytes = new Array[Byte](numBytes) - in.readFully(bytes) - val (rdd_, func_) = ResultTask.deserializeInfo(stageId, bytes) - rdd = rdd_.asInstanceOf[RDD[T]] - func = func_.asInstanceOf[(TaskContext, Iterator[T]) => U] - partitionId = in.readInt() - outputId = in.readInt() - epoch = in.readLong() - split = in.readObject().asInstanceOf[Partition] - } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala index fdaf1de83f051..11255c07469d4 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala @@ -17,134 +17,55 @@ package org.apache.spark.scheduler -import scala.language.existentials - -import java.io._ -import java.util.zip.{GZIPInputStream, GZIPOutputStream} +import java.nio.ByteBuffer -import scala.collection.mutable.HashMap +import scala.language.existentials import org.apache.spark._ -import org.apache.spark.rdd.{RDD, RDDCheckpointData} +import org.apache.spark.broadcast.Broadcast +import org.apache.spark.rdd.RDD import org.apache.spark.shuffle.ShuffleWriter -private[spark] object ShuffleMapTask { - - // A simple map between the stage id to the serialized byte array of a task. - // Served as a cache for task serialization because serialization can be - // expensive on the master node if it needs to launch thousands of tasks. - private val serializedInfoCache = new HashMap[Int, Array[Byte]] - - def serializeInfo(stageId: Int, rdd: RDD[_], dep: ShuffleDependency[_, _, _]): Array[Byte] = { - synchronized { - val old = serializedInfoCache.get(stageId).orNull - if (old != null) { - return old - } else { - val out = new ByteArrayOutputStream - val ser = SparkEnv.get.closureSerializer.newInstance() - val objOut = ser.serializeStream(new GZIPOutputStream(out)) - objOut.writeObject(rdd) - objOut.writeObject(dep) - objOut.close() - val bytes = out.toByteArray - serializedInfoCache.put(stageId, bytes) - bytes - } - } - } - - def deserializeInfo(stageId: Int, bytes: Array[Byte]): (RDD[_], ShuffleDependency[_, _, _]) = { - val in = new GZIPInputStream(new ByteArrayInputStream(bytes)) - val ser = SparkEnv.get.closureSerializer.newInstance() - val objIn = ser.deserializeStream(in) - val rdd = objIn.readObject().asInstanceOf[RDD[_]] - val dep = objIn.readObject().asInstanceOf[ShuffleDependency[_, _, _]] - (rdd, dep) - } - - // Since both the JarSet and FileSet have the same format this is used for both. - def deserializeFileSet(bytes: Array[Byte]): HashMap[String, Long] = { - val in = new GZIPInputStream(new ByteArrayInputStream(bytes)) - val objIn = new ObjectInputStream(in) - val set = objIn.readObject().asInstanceOf[Array[(String, Long)]].toMap - HashMap(set.toSeq: _*) - } - - def removeStage(stageId: Int) { - serializedInfoCache.remove(stageId) - } - - def clearCache() { - synchronized { - serializedInfoCache.clear() - } - } -} - /** - * A ShuffleMapTask divides the elements of an RDD into multiple buckets (based on a partitioner - * specified in the ShuffleDependency). - * - * See [[org.apache.spark.scheduler.Task]] for more information. - * +* A ShuffleMapTask divides the elements of an RDD into multiple buckets (based on a partitioner +* specified in the ShuffleDependency). +* +* See [[org.apache.spark.scheduler.Task]] for more information. +* * @param stageId id of the stage this task belongs to - * @param rdd the final RDD in this stage - * @param dep the ShuffleDependency - * @param _partitionId index of the number in the RDD + * @param taskBinary broadcast version of of the RDD and the ShuffleDependency. Once deserialized, + * the type should be (RDD[_], ShuffleDependency[_, _, _]). + * @param partition partition of the RDD this task is associated with * @param locs preferred task execution locations for locality scheduling */ private[spark] class ShuffleMapTask( stageId: Int, - var rdd: RDD[_], - var dep: ShuffleDependency[_, _, _], - _partitionId: Int, + taskBinary: Broadcast[Array[Byte]], + partition: Partition, @transient private var locs: Seq[TaskLocation]) - extends Task[MapStatus](stageId, _partitionId) - with Externalizable - with Logging { + extends Task[MapStatus](stageId, partition.index) with Logging { - protected def this() = this(0, null, null, 0, null) + /** A constructor used only in test suites. This does not require passing in an RDD. */ + def this(partitionId: Int) { + this(0, null, new Partition { override def index = 0 }, null) + } @transient private val preferredLocs: Seq[TaskLocation] = { if (locs == null) Nil else locs.toSet.toSeq } - var split = if (rdd == null) null else rdd.partitions(partitionId) - - override def writeExternal(out: ObjectOutput) { - RDDCheckpointData.synchronized { - split = rdd.partitions(partitionId) - out.writeInt(stageId) - val bytes = ShuffleMapTask.serializeInfo(stageId, rdd, dep) - out.writeInt(bytes.length) - out.write(bytes) - out.writeInt(partitionId) - out.writeLong(epoch) - out.writeObject(split) - } - } - - override def readExternal(in: ObjectInput) { - val stageId = in.readInt() - val numBytes = in.readInt() - val bytes = new Array[Byte](numBytes) - in.readFully(bytes) - val (rdd_, dep_) = ShuffleMapTask.deserializeInfo(stageId, bytes) - rdd = rdd_ - dep = dep_ - partitionId = in.readInt() - epoch = in.readLong() - split = in.readObject().asInstanceOf[Partition] - } - override def runTask(context: TaskContext): MapStatus = { + // Deserialize the RDD using the broadcast variable. + val ser = SparkEnv.get.closureSerializer.newInstance() + val (rdd, dep) = ser.deserialize[(RDD[_], ShuffleDependency[_, _, _])]( + ByteBuffer.wrap(taskBinary.value), Thread.currentThread.getContextClassLoader) + metrics = Some(context.taskMetrics) var writer: ShuffleWriter[Any, Any] = null try { val manager = SparkEnv.get.shuffleManager writer = manager.getWriter[Any, Any](dep.shuffleHandle, partitionId, context) - writer.write(rdd.iterator(split, context).asInstanceOf[Iterator[_ <: Product2[Any, Any]]]) + writer.write(rdd.iterator(partition, context).asInstanceOf[Iterator[_ <: Product2[Any, Any]]]) return writer.stop(success = true).get } catch { case e: Exception => diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 69f65b4bdccb1..f8fbb3ad6d4a1 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -38,7 +38,7 @@ import org.apache.hadoop.fs.{FileSystem, FileUtil, Path} import org.json4s._ import tachyon.client.{TachyonFile,TachyonFS} -import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkException} +import org.apache.spark._ import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.executor.ExecutorUncaughtExceptionHandler import org.apache.spark.serializer.{DeserializationStream, SerializationStream, SerializerInstance} diff --git a/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala index 13b415cccb647..ad20f9b937ac1 100644 --- a/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala @@ -19,6 +19,9 @@ package org.apache.spark import java.lang.ref.WeakReference +import org.apache.spark.broadcast.Broadcast + +import scala.collection.mutable import scala.collection.mutable.{HashSet, SynchronizedSet} import scala.language.existentials import scala.language.postfixOps @@ -52,9 +55,8 @@ class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkCo } } - test("cleanup RDD") { - val rdd = newRDD.persist() + val rdd = newRDD().persist() val collected = rdd.collect().toList val tester = new CleanerTester(sc, rddIds = Seq(rdd.id)) @@ -67,7 +69,7 @@ class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkCo } test("cleanup shuffle") { - val (rdd, shuffleDeps) = newRDDWithShuffleDependencies + val (rdd, shuffleDeps) = newRDDWithShuffleDependencies() val collected = rdd.collect().toList val tester = new CleanerTester(sc, shuffleIds = shuffleDeps.map(_.shuffleId)) @@ -80,7 +82,7 @@ class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkCo } test("cleanup broadcast") { - val broadcast = newBroadcast + val broadcast = newBroadcast() val tester = new CleanerTester(sc, broadcastIds = Seq(broadcast.id)) // Explicit cleanup @@ -89,7 +91,7 @@ class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkCo } test("automatically cleanup RDD") { - var rdd = newRDD.persist() + var rdd = newRDD().persist() rdd.count() // Test that GC does not cause RDD cleanup due to a strong reference @@ -107,7 +109,7 @@ class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkCo } test("automatically cleanup shuffle") { - var rdd = newShuffleRDD + var rdd = newShuffleRDD() rdd.count() // Test that GC does not cause shuffle cleanup due to a strong reference @@ -125,7 +127,7 @@ class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkCo } test("automatically cleanup broadcast") { - var broadcast = newBroadcast + var broadcast = newBroadcast() // Test that GC does not cause broadcast cleanup due to a strong reference val preGCTester = new CleanerTester(sc, broadcastIds = Seq(broadcast.id)) @@ -144,11 +146,11 @@ class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkCo test("automatically cleanup RDD + shuffle + broadcast") { val numRdds = 100 val numBroadcasts = 4 // Broadcasts are more costly - val rddBuffer = (1 to numRdds).map(i => randomRdd).toBuffer - val broadcastBuffer = (1 to numBroadcasts).map(i => randomBroadcast).toBuffer + val rddBuffer = (1 to numRdds).map(i => randomRdd()).toBuffer + val broadcastBuffer = (1 to numBroadcasts).map(i => randomBroadcast()).toBuffer val rddIds = sc.persistentRdds.keys.toSeq val shuffleIds = 0 until sc.newShuffleId - val broadcastIds = 0L until numBroadcasts + val broadcastIds = broadcastBuffer.map(_.id) val preGCTester = new CleanerTester(sc, rddIds, shuffleIds, broadcastIds) runGC() @@ -162,6 +164,13 @@ class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkCo rddBuffer.clear() runGC() postGCTester.assertCleanup() + + // Make sure the broadcasted task closure no longer exists after GC. + val taskClosureBroadcastId = broadcastIds.max + 1 + assert(sc.env.blockManager.master.getMatchingBlockIds({ + case BroadcastBlockId(`taskClosureBroadcastId`, _) => true + case _ => false + }, askSlaves = true).isEmpty) } test("automatically cleanup RDD + shuffle + broadcast in distributed mode") { @@ -175,11 +184,11 @@ class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkCo val numRdds = 10 val numBroadcasts = 4 // Broadcasts are more costly - val rddBuffer = (1 to numRdds).map(i => randomRdd).toBuffer - val broadcastBuffer = (1 to numBroadcasts).map(i => randomBroadcast).toBuffer + val rddBuffer = (1 to numRdds).map(i => randomRdd()).toBuffer + val broadcastBuffer = (1 to numBroadcasts).map(i => randomBroadcast()).toBuffer val rddIds = sc.persistentRdds.keys.toSeq val shuffleIds = 0 until sc.newShuffleId - val broadcastIds = 0L until numBroadcasts + val broadcastIds = broadcastBuffer.map(_.id) val preGCTester = new CleanerTester(sc, rddIds, shuffleIds, broadcastIds) runGC() @@ -193,21 +202,29 @@ class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkCo rddBuffer.clear() runGC() postGCTester.assertCleanup() + + // Make sure the broadcasted task closure no longer exists after GC. + val taskClosureBroadcastId = broadcastIds.max + 1 + assert(sc.env.blockManager.master.getMatchingBlockIds({ + case BroadcastBlockId(`taskClosureBroadcastId`, _) => true + case _ => false + }, askSlaves = true).isEmpty) } //------ Helper functions ------ - def newRDD = sc.makeRDD(1 to 10) - def newPairRDD = newRDD.map(_ -> 1) - def newShuffleRDD = newPairRDD.reduceByKey(_ + _) - def newBroadcast = sc.broadcast(1 to 100) - def newRDDWithShuffleDependencies: (RDD[_], Seq[ShuffleDependency[_, _, _]]) = { + private def newRDD() = sc.makeRDD(1 to 10) + private def newPairRDD() = newRDD().map(_ -> 1) + private def newShuffleRDD() = newPairRDD().reduceByKey(_ + _) + private def newBroadcast() = sc.broadcast(1 to 100) + + private def newRDDWithShuffleDependencies(): (RDD[_], Seq[ShuffleDependency[_, _, _]]) = { def getAllDependencies(rdd: RDD[_]): Seq[Dependency[_]] = { rdd.dependencies ++ rdd.dependencies.flatMap { dep => getAllDependencies(dep.rdd) } } - val rdd = newShuffleRDD + val rdd = newShuffleRDD() // Get all the shuffle dependencies val shuffleDeps = getAllDependencies(rdd) @@ -216,34 +233,34 @@ class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkCo (rdd, shuffleDeps) } - def randomRdd = { + private def randomRdd() = { val rdd: RDD[_] = Random.nextInt(3) match { - case 0 => newRDD - case 1 => newShuffleRDD - case 2 => newPairRDD.join(newPairRDD) + case 0 => newRDD() + case 1 => newShuffleRDD() + case 2 => newPairRDD.join(newPairRDD()) } if (Random.nextBoolean()) rdd.persist() rdd.count() rdd } - def randomBroadcast = { + private def randomBroadcast() = { sc.broadcast(Random.nextInt(Int.MaxValue)) } /** Run GC and make sure it actually has run */ - def runGC() { + private def runGC() { val weakRef = new WeakReference(new Object()) val startTime = System.currentTimeMillis System.gc() // Make a best effort to run the garbage collection. It *usually* runs GC. // Wait until a weak reference object has been GCed - while(System.currentTimeMillis - startTime < 10000 && weakRef.get != null) { + while (System.currentTimeMillis - startTime < 10000 && weakRef.get != null) { System.gc() Thread.sleep(200) } } - def cleaner = sc.cleaner.get + private def cleaner = sc.cleaner.get } diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala index fdc83bc0a5f8e..4953d565ae83a 100644 --- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala @@ -155,19 +155,13 @@ class RDDSuite extends FunSuite with SharedSparkContext { override def getPartitions: Array[Partition] = Array(onlySplit) override val getDependencies = List[Dependency[_]]() override def compute(split: Partition, context: TaskContext): Iterator[Int] = { - if (shouldFail) { - throw new Exception("injected failure") - } else { - Array(1, 2, 3, 4).iterator - } + throw new Exception("injected failure") } }.cache() val thrown = intercept[Exception]{ rdd.collect() } assert(thrown.getMessage.contains("injected failure")) - shouldFail = false - assert(rdd.collect().toList === List(1, 2, 3, 4)) } test("empty RDD") { diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala index 8bb5317cd2875..270f7e661045a 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala @@ -20,31 +20,35 @@ package org.apache.spark.scheduler import org.scalatest.FunSuite import org.scalatest.BeforeAndAfter -import org.apache.spark.LocalSparkContext -import org.apache.spark.Partition -import org.apache.spark.SparkContext -import org.apache.spark.TaskContext +import org.apache.spark._ import org.apache.spark.rdd.RDD +import org.apache.spark.util.Utils class TaskContextSuite extends FunSuite with BeforeAndAfter with LocalSparkContext { test("Calls executeOnCompleteCallbacks after failure") { - var completed = false + TaskContextSuite.completed = false sc = new SparkContext("local", "test") val rdd = new RDD[String](sc, List()) { override def getPartitions = Array[Partition](StubPartition(0)) override def compute(split: Partition, context: TaskContext) = { - context.addOnCompleteCallback(() => completed = true) + context.addOnCompleteCallback(() => TaskContextSuite.completed = true) sys.error("failed") } } - val func = (c: TaskContext, i: Iterator[String]) => i.next - val task = new ResultTask[String, String](0, rdd, func, 0, Seq(), 0) + val closureSerializer = SparkEnv.get.closureSerializer.newInstance() + val func = (c: TaskContext, i: Iterator[String]) => i.next() + val task = new ResultTask[String, String]( + 0, sc.broadcast(closureSerializer.serialize((rdd, func)).array), rdd.partitions(0), Seq(), 0) intercept[RuntimeException] { task.run(0) } - assert(completed === true) + assert(TaskContextSuite.completed === true) } +} - case class StubPartition(val index: Int) extends Partition +private object TaskContextSuite { + @volatile var completed = false } + +private case class StubPartition(index: Int) extends Partition diff --git a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala index b52f81877d557..86a271eb67000 100644 --- a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala @@ -26,6 +26,7 @@ import org.apache.spark.scheduler._ import org.apache.spark.util.Utils class JobProgressListenerSuite extends FunSuite with LocalSparkContext with Matchers { + test("test LRU eviction of stages") { val conf = new SparkConf() conf.set("spark.ui.retainedStages", 5.toString) @@ -66,7 +67,7 @@ class JobProgressListenerSuite extends FunSuite with LocalSparkContext with Matc taskMetrics.updateShuffleReadMetrics(shuffleReadMetrics) var taskInfo = new TaskInfo(1234L, 0, 1, 0L, "exe-1", "host1", TaskLocality.NODE_LOCAL, false) taskInfo.finishTime = 1 - var task = new ShuffleMapTask(0, null, null, 0, null) + var task = new ShuffleMapTask(0) val taskType = Utils.getFormattedClassName(task) listener.onTaskEnd(SparkListenerTaskEnd(task.stageId, taskType, Success, taskInfo, taskMetrics)) assert(listener.stageIdToData.getOrElse(0, fail()).executorSummary.getOrElse("exe-1", fail()) @@ -76,14 +77,14 @@ class JobProgressListenerSuite extends FunSuite with LocalSparkContext with Matc taskInfo = new TaskInfo(1234L, 0, 1, 1000L, "exe-unknown", "host1", TaskLocality.NODE_LOCAL, true) taskInfo.finishTime = 1 - task = new ShuffleMapTask(0, null, null, 0, null) + task = new ShuffleMapTask(0) listener.onTaskEnd(SparkListenerTaskEnd(task.stageId, taskType, Success, taskInfo, taskMetrics)) assert(listener.stageIdToData.size === 1) // finish this task, should get updated duration taskInfo = new TaskInfo(1235L, 0, 1, 0L, "exe-1", "host1", TaskLocality.NODE_LOCAL, false) taskInfo.finishTime = 1 - task = new ShuffleMapTask(0, null, null, 0, null) + task = new ShuffleMapTask(0) listener.onTaskEnd(SparkListenerTaskEnd(task.stageId, taskType, Success, taskInfo, taskMetrics)) assert(listener.stageIdToData.getOrElse(0, fail()).executorSummary.getOrElse("exe-1", fail()) .shuffleRead === 2000) @@ -91,7 +92,7 @@ class JobProgressListenerSuite extends FunSuite with LocalSparkContext with Matc // finish this task, should get updated duration taskInfo = new TaskInfo(1236L, 0, 2, 0L, "exe-2", "host1", TaskLocality.NODE_LOCAL, false) taskInfo.finishTime = 1 - task = new ShuffleMapTask(0, null, null, 0, null) + task = new ShuffleMapTask(0) listener.onTaskEnd(SparkListenerTaskEnd(task.stageId, taskType, Success, taskInfo, taskMetrics)) assert(listener.stageIdToData.getOrElse(0, fail()).executorSummary.getOrElse("exe-2", fail()) .shuffleRead === 1000) @@ -103,7 +104,7 @@ class JobProgressListenerSuite extends FunSuite with LocalSparkContext with Matc val metrics = new TaskMetrics() val taskInfo = new TaskInfo(1234L, 0, 3, 0L, "exe-1", "host1", TaskLocality.NODE_LOCAL, false) taskInfo.finishTime = 1 - val task = new ShuffleMapTask(0, null, null, 0, null) + val task = new ShuffleMapTask(0) val taskType = Utils.getFormattedClassName(task) // Go through all the failure cases to make sure we are counting them as failures. From 3bc3f1801e3347e02cbecdd8e941003430155da2 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 30 Jul 2014 09:28:53 -0700 Subject: [PATCH 029/170] [SPARK-2747] git diff --dirstat can miss sql changes and not run Hive tests dev/run-tests use "git diff --dirstat master" to check whether sql is changed. However, --dirstat won't show sql if sql's change is negligible (e.g. 1k loc change in core, and only 1 loc change in hive). We should use "git diff --name-only master" instead. Author: Reynold Xin Closes #1656 from rxin/hiveTest and squashes the following commits: f5eab9f [Reynold Xin] [SPARK-2747] git diff --dirstat can miss sql changes and not run Hive tests. --- dev/run-tests | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dev/run-tests b/dev/run-tests index 98ec969dc1b37..795d16a4d983d 100755 --- a/dev/run-tests +++ b/dev/run-tests @@ -37,7 +37,7 @@ JAVA_VERSION=$($java_cmd -version 2>&1 | sed 's/java version "\(.*\)\.\(.*\)\..* # Partial solution for SPARK-1455. Only run Hive tests if there are sql changes. if [ -n "$AMPLAB_JENKINS" ]; then git fetch origin master:master - diffs=`git diff --dirstat master | awk '{ print $2; }' | grep "^sql/"` + diffs=`git diff --name-only master | grep "^sql/"` if [ -n "$diffs" ]; then echo "Detected changes in SQL. Will run Hive test suite." export _RUN_SQL_TESTS=true # exported for PySpark tests From e3d85b7e40073b05e2588583e9d8db11366c2f7b Mon Sep 17 00:00:00 2001 From: Naftali Harris Date: Wed, 30 Jul 2014 09:56:59 -0700 Subject: [PATCH 030/170] Avoid numerical instability This avoids basically doing 1 - 1, for example: ```python >>> from math import exp >>> margin = -40 >>> 1 - 1 / (1 + exp(margin)) 0.0 >>> exp(margin) / (1 + exp(margin)) 4.248354255291589e-18 >>> ``` Author: Naftali Harris Closes #1652 from naftaliharris/patch-2 and squashes the following commits: 0d55a9f [Naftali Harris] Avoid numerical instability --- python/pyspark/mllib/classification.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/pyspark/mllib/classification.py b/python/pyspark/mllib/classification.py index 9e28dfbb9145d..2bbb9c3fca315 100644 --- a/python/pyspark/mllib/classification.py +++ b/python/pyspark/mllib/classification.py @@ -66,7 +66,8 @@ def predict(self, x): if margin > 0: prob = 1 / (1 + exp(-margin)) else: - prob = 1 - 1 / (1 + exp(margin)) + exp_margin = exp(margin) + prob = exp_margin / (1 + exp_margin) return 1 if prob > 0.5 else 0 From fc47bb6967e0df40870413e09d37aa9b90248f43 Mon Sep 17 00:00:00 2001 From: GuoQiang Li Date: Wed, 30 Jul 2014 11:00:11 -0700 Subject: [PATCH 031/170] [SPARK-2544][MLLIB] Improve ALS algorithm resource usage Author: GuoQiang Li Author: witgo Closes #929 from witgo/improve_als and squashes the following commits: ea25033 [GuoQiang Li] checkpoint products 3,6,9 ... 154dccf [GuoQiang Li] checkpoint products only c5779ff [witgo] Improve ALS algorithm resource usage --- .../scala/org/apache/spark/mllib/recommendation/ALS.scala | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala index 5356790cb5339..d208cfb917f3d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala @@ -255,6 +255,9 @@ class ALS private ( rank, lambda, alpha, YtY) previousProducts.unpersist() logInfo("Re-computing U given I (Iteration %d/%d)".format(iter, iterations)) + if (sc.checkpointDir.isDefined && (iter % 3 == 0)) { + products.checkpoint() + } products.setName(s"products-$iter").persist() val XtX = Some(sc.broadcast(computeYtY(products))) val previousUsers = users @@ -268,6 +271,9 @@ class ALS private ( logInfo("Re-computing I given U (Iteration %d/%d)".format(iter, iterations)) products = updateFeatures(numProductBlocks, users, userOutLinks, productInLinks, rank, lambda, alpha, YtY = None) + if (sc.checkpointDir.isDefined && (iter % 3 == 0)) { + products.checkpoint() + } products.setName(s"products-$iter") logInfo("Re-computing U given I (Iteration %d/%d)".format(iter, iterations)) users = updateFeatures(numUserBlocks, products, productOutLinks, userInLinks, From ff511bacf223e19244f5f6114d60af7dcadeda4d Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 30 Jul 2014 11:45:24 -0700 Subject: [PATCH 032/170] [SPARK-2746] Set SBT_MAVEN_PROFILES only when it is not set explicitly by the user. Author: Reynold Xin Closes #1655 from rxin/SBT_MAVEN_PROFILES and squashes the following commits: b268c4b [Reynold Xin] [SPARK-2746] Set SBT_MAVEN_PROFILES only when it is not set explicitly by the user. --- dev/run-tests | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/dev/run-tests b/dev/run-tests index 795d16a4d983d..c95ef8a5743fc 100755 --- a/dev/run-tests +++ b/dev/run-tests @@ -21,7 +21,10 @@ FWDIR="$(cd `dirname $0`/..; pwd)" cd $FWDIR -export SBT_MAVEN_PROFILES="-Pyarn -Phadoop-2.3 -Dhadoop.version=2.3.0" +if [ -z "$SBT_MAVEN_PROFILES" ]; then + export SBT_MAVEN_PROFILES="-Pyarn -Phadoop-2.3 -Dhadoop.version=2.3.0" +fi +echo "SBT_MAVEN_PROFILES=\"$SBT_MAVEN_PROFILES\"" # Remove work directory rm -rf ./work From f2eb84fe737e6b06f5625640b209cf02f80732cf Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 30 Jul 2014 12:24:35 -0700 Subject: [PATCH 033/170] Wrap FWDIR in quotes. --- dev/run-tests | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dev/run-tests b/dev/run-tests index c95ef8a5743fc..f2b523b996617 100755 --- a/dev/run-tests +++ b/dev/run-tests @@ -19,7 +19,7 @@ # Go to the Spark project root directory FWDIR="$(cd `dirname $0`/..; pwd)" -cd $FWDIR +cd "$FWDIR" if [ -z "$SBT_MAVEN_PROFILES" ]; then export SBT_MAVEN_PROFILES="-Pyarn -Phadoop-2.3 -Dhadoop.version=2.3.0" From 95cf203936c412bc689bd2345fec7f9ad3648c25 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 30 Jul 2014 12:33:42 -0700 Subject: [PATCH 034/170] Wrap FWDIR in quotes in dev/check-license. --- dev/check-license | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dev/check-license b/dev/check-license index fbd2dd465bb18..7a603bf0180ad 100755 --- a/dev/check-license +++ b/dev/check-license @@ -51,7 +51,7 @@ acquire_rat_jar () { # Go to the Spark project root directory FWDIR="$(cd `dirname $0`/..; pwd)" -cd $FWDIR +cd "$FWDIR" if test -x "$JAVA_HOME/bin/java"; then declare java_cmd="$JAVA_HOME/bin/java" From 0feb349ea07361f0363117404ffc9797c2c80dd1 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 30 Jul 2014 13:04:20 -0700 Subject: [PATCH 035/170] More wrapping FWDIR in quotes. --- dev/mima | 2 +- dev/run-tests-jenkins | 2 +- make-distribution.sh | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/dev/mima b/dev/mima index 7857294f61caf..4c3e65039b160 100755 --- a/dev/mima +++ b/dev/mima @@ -22,7 +22,7 @@ set -e # Go to the Spark project root directory FWDIR="$(cd `dirname $0`/..; pwd)" -cd $FWDIR +cd "$FWDIR" echo -e "q\n" | sbt/sbt oldDeps/update diff --git a/dev/run-tests-jenkins b/dev/run-tests-jenkins index 8dda671e976ce..3076eb847b420 100755 --- a/dev/run-tests-jenkins +++ b/dev/run-tests-jenkins @@ -22,7 +22,7 @@ # Go to the Spark project root directory FWDIR="$(cd `dirname $0`/..; pwd)" -cd $FWDIR +cd "$FWDIR" COMMENTS_URL="https://api.github.com/repos/apache/spark/issues/$ghprbPullId/comments" diff --git a/make-distribution.sh b/make-distribution.sh index c08093f46b61f..0a3283ecec6f8 100755 --- a/make-distribution.sh +++ b/make-distribution.sh @@ -150,7 +150,7 @@ else fi # Build uber fat JAR -cd $FWDIR +cd "$FWDIR" export MAVEN_OPTS="-Xmx2g -XX:MaxPermSize=512M -XX:ReservedCodeCacheSize=512m" From 2248891a43d93cf2c05580211faf1e4f8dc7932d Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Wed, 30 Jul 2014 13:11:09 -0700 Subject: [PATCH 036/170] [SQL] Fix compiling of catalyst docs. Author: Michael Armbrust Closes #1653 from marmbrus/fixDocs and squashes the following commits: 0aa1feb [Michael Armbrust] Fix compiling of catalyst docs. --- project/SparkBuild.scala | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index e2dab0f9f79ea..672343fbbed2e 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -194,7 +194,10 @@ object Flume { object Catalyst { lazy val settings = Seq( - addCompilerPlugin("org.scalamacros" % "paradise" % "2.0.1" cross CrossVersion.full)) + addCompilerPlugin("org.scalamacros" % "paradise" % "2.0.1" cross CrossVersion.full), + // Quasiquotes break compiling scala doc... + // TODO: Investigate fixing this. + sources in (Compile, doc) ~= (_ filter (_.getName contains "codegen"))) } object SQL { From 437dc8c5b54f0dcf9564c1fb07e8dce9e771c8cd Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 30 Jul 2014 13:17:14 -0700 Subject: [PATCH 037/170] dev/check-license wrap folders in quotes. --- dev/check-license | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/dev/check-license b/dev/check-license index 7a603bf0180ad..00bb20c133b7d 100755 --- a/dev/check-license +++ b/dev/check-license @@ -27,7 +27,7 @@ acquire_rat_jar () { if [[ ! -f "$rat_jar" ]]; then # Download rat launch jar if it hasn't been downloaded yet - if [ ! -f ${JAR} ]; then + if [ ! -f "$JAR" ]; then # Download printf "Attempting to fetch rat\n" JAR_DL=${JAR}.part @@ -40,10 +40,10 @@ acquire_rat_jar () { exit -1 fi fi - if [ ! -f ${JAR} ]; then - # We failed to download - printf "Our attempt to download rat locally to ${JAR} failed. Please install rat manually.\n" - exit -1 + if [ ! -f "$JAR" ]; then + # We failed to download + printf "Our attempt to download rat locally to ${JAR} failed. Please install rat manually.\n" + exit -1 fi printf "Launching rat from ${JAR}\n" fi From 94d1f46fc43c0cb85125f757fb40db9271caf1f4 Mon Sep 17 00:00:00 2001 From: Kan Zhang Date: Wed, 30 Jul 2014 13:19:05 -0700 Subject: [PATCH 038/170] [SPARK-2024] Add saveAsSequenceFile to PySpark JIRA issue: https://issues.apache.org/jira/browse/SPARK-2024 This PR is a followup to #455 and adds capabilities for saving PySpark RDDs using SequenceFile or any Hadoop OutputFormats. * Added RDD methods ```saveAsSequenceFile```, ```saveAsHadoopFile``` and ```saveAsHadoopDataset```, for both old and new MapReduce APIs. * Default converter for converting common data types to Writables. Users may specify custom converters to convert to desired data types. * No out-of-box support for reading/writing arrays, since ArrayWritable itself doesn't have a no-arg constructor for creating an empty instance upon reading. Users need to provide ArrayWritable subtypes. Custom converters for converting arrays to suitable ArrayWritable subtypes are also needed when writing. When reading, the default converter will convert any custom ArrayWritable subtypes to ```Object[]``` and they get pickled to Python tuples. * Added HBase and Cassandra output examples to show how custom output formats and converters can be used. cc MLnick mateiz ahirreddy pwendell Author: Kan Zhang Closes #1338 from kanzhang/SPARK-2024 and squashes the following commits: c01e3ef [Kan Zhang] [SPARK-2024] code formatting 6591e37 [Kan Zhang] [SPARK-2024] renaming pickled -> pickledRDD d998ad6 [Kan Zhang] [SPARK-2024] refectoring to get method params below 10 57a7a5e [Kan Zhang] [SPARK-2024] correcting typo 75ca5bd [Kan Zhang] [SPARK-2024] Better type checking for batch serialized RDD 0bdec55 [Kan Zhang] [SPARK-2024] Refactoring newly added tests 9f39ff4 [Kan Zhang] [SPARK-2024] Adding 2 saveAsHadoopDataset tests 0c134f3 [Kan Zhang] [SPARK-2024] Test refactoring and adding couple unbatched cases 7a176df [Kan Zhang] [SPARK-2024] Add saveAsSequenceFile to PySpark --- .../spark/api/python/PythonHadoopUtil.scala | 82 ++++- .../apache/spark/api/python/PythonRDD.scala | 247 +++++++++++--- .../apache/spark/api/python/SerDeUtil.scala | 61 +++- .../WriteInputFormatTestDataGenerator.scala | 69 +++- docs/programming-guide.md | 52 ++- .../src/main/python/cassandra_outputformat.py | 83 +++++ examples/src/main/python/hbase_inputformat.py | 3 +- .../src/main/python/hbase_outputformat.py | 65 ++++ .../CassandraConverters.scala | 24 +- .../pythonconverters/HBaseConverter.scala | 33 -- .../pythonconverters/HBaseConverters.scala | 70 ++++ python/pyspark/context.py | 51 ++- python/pyspark/rdd.py | 114 +++++++ python/pyspark/tests.py | 317 +++++++++++++++++- 14 files changed, 1085 insertions(+), 186 deletions(-) create mode 100644 examples/src/main/python/cassandra_outputformat.py create mode 100644 examples/src/main/python/hbase_outputformat.py delete mode 100644 examples/src/main/scala/org/apache/spark/examples/pythonconverters/HBaseConverter.scala create mode 100644 examples/src/main/scala/org/apache/spark/examples/pythonconverters/HBaseConverters.scala diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonHadoopUtil.scala b/core/src/main/scala/org/apache/spark/api/python/PythonHadoopUtil.scala index adaa1ef6cf9ff..f3b05e1243045 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonHadoopUtil.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonHadoopUtil.scala @@ -17,8 +17,9 @@ package org.apache.spark.api.python +import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.RDD -import org.apache.spark.Logging +import org.apache.spark.{Logging, SerializableWritable, SparkException} import org.apache.hadoop.conf.Configuration import org.apache.hadoop.io._ import scala.util.{Failure, Success, Try} @@ -31,13 +32,14 @@ import org.apache.spark.annotation.Experimental * transformation code by overriding the convert method. */ @Experimental -trait Converter[T, U] extends Serializable { +trait Converter[T, + U] extends Serializable { def convert(obj: T): U } private[python] object Converter extends Logging { - def getInstance(converterClass: Option[String]): Converter[Any, Any] = { + def getInstance(converterClass: Option[String], + defaultConverter: Converter[Any, Any]): Converter[Any, Any] = { converterClass.map { cc => Try { val c = Class.forName(cc).newInstance().asInstanceOf[Converter[Any, Any]] @@ -49,7 +51,7 @@ private[python] object Converter extends Logging { logError(s"Failed to load converter: $cc") throw err } - }.getOrElse { new DefaultConverter } + }.getOrElse { defaultConverter } } } @@ -57,7 +59,9 @@ private[python] object Converter extends Logging { * A converter that handles conversion of common [[org.apache.hadoop.io.Writable]] objects. * Other objects are passed through without conversion. */ -private[python] class DefaultConverter extends Converter[Any, Any] { +private[python] class WritableToJavaConverter( + conf: Broadcast[SerializableWritable[Configuration]], + batchSize: Int) extends Converter[Any, Any] { /** * Converts a [[org.apache.hadoop.io.Writable]] to the underlying primitive, String or @@ -72,17 +76,30 @@ private[python] class DefaultConverter extends Converter[Any, Any] { case fw: FloatWritable => fw.get() case t: Text => t.toString case bw: BooleanWritable => bw.get() - case byw: BytesWritable => byw.getBytes + case byw: BytesWritable => + val bytes = new Array[Byte](byw.getLength) + System.arraycopy(byw.getBytes(), 0, bytes, 0, byw.getLength) + bytes case n: NullWritable => null - case aw: ArrayWritable => aw.get().map(convertWritable(_)) - case mw: MapWritable => mapAsJavaMap(mw.map { case (k, v) => - (convertWritable(k), convertWritable(v)) - }.toMap) + case aw: ArrayWritable => + // Due to erasure, all arrays appear as Object[] and they get pickled to Python tuples. + // Since we can't determine element types for empty arrays, we will not attempt to + // convert to primitive arrays (which get pickled to Python arrays). Users may want + // write custom converters for arrays if they know the element types a priori. + aw.get().map(convertWritable(_)) + case mw: MapWritable => + val map = new java.util.HashMap[Any, Any]() + mw.foreach { case (k, v) => + map.put(convertWritable(k), convertWritable(v)) + } + map + case w: Writable => + if (batchSize > 1) WritableUtils.clone(w, conf.value.value) else w case other => other } } - def convert(obj: Any): Any = { + override def convert(obj: Any): Any = { obj match { case writable: Writable => convertWritable(writable) @@ -92,6 +109,47 @@ private[python] class DefaultConverter extends Converter[Any, Any] { } } +/** + * A converter that converts common types to [[org.apache.hadoop.io.Writable]]. Note that array + * types are not supported since the user needs to subclass [[org.apache.hadoop.io.ArrayWritable]] + * to set the type properly. See [[org.apache.spark.api.python.DoubleArrayWritable]] and + * [[org.apache.spark.api.python.DoubleArrayToWritableConverter]] for an example. They are used in + * PySpark RDD `saveAsNewAPIHadoopFile` doctest. + */ +private[python] class JavaToWritableConverter extends Converter[Any, Writable] { + + /** + * Converts common data types to [[org.apache.hadoop.io.Writable]]. Note that array types are not + * supported out-of-the-box. + */ + private def convertToWritable(obj: Any): Writable = { + import collection.JavaConversions._ + obj match { + case i: java.lang.Integer => new IntWritable(i) + case d: java.lang.Double => new DoubleWritable(d) + case l: java.lang.Long => new LongWritable(l) + case f: java.lang.Float => new FloatWritable(f) + case s: java.lang.String => new Text(s) + case b: java.lang.Boolean => new BooleanWritable(b) + case aob: Array[Byte] => new BytesWritable(aob) + case null => NullWritable.get() + case map: java.util.Map[_, _] => + val mapWritable = new MapWritable() + map.foreach { case (k, v) => + mapWritable.put(convertToWritable(k), convertToWritable(v)) + } + mapWritable + case other => throw new SparkException( + s"Data of type ${other.getClass.getName} cannot be used") + } + } + + override def convert(obj: Any): Writable = obj match { + case writable: Writable => writable + case other => convertToWritable(other) + } +} + /** Utilities for working with Python objects <-> Hadoop-related objects */ private[python] object PythonHadoopUtil { @@ -118,7 +176,7 @@ private[python] object PythonHadoopUtil { /** * Converts an RDD of key-value pairs, where key and/or value could be instances of - * [[org.apache.hadoop.io.Writable]], into an RDD[(K, V)] + * [[org.apache.hadoop.io.Writable]], into an RDD of base types, or vice versa. */ def convertRDD[K, V](rdd: RDD[(K, V)], keyConverter: Converter[Any, Any], diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index f551a59ee3fe8..a9d758bf998c3 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -23,15 +23,18 @@ import java.nio.charset.Charset import java.util.{List => JList, ArrayList => JArrayList, Map => JMap, Collections} import scala.collection.JavaConversions._ +import scala.language.existentials import scala.reflect.ClassTag import scala.util.Try import net.razorvine.pickle.{Pickler, Unpickler} import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.mapred.{InputFormat, JobConf} -import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat} +import org.apache.hadoop.io.compress.CompressionCodec +import org.apache.hadoop.mapred.{InputFormat, OutputFormat, JobConf} +import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat, OutputFormat => NewOutputFormat} import org.apache.spark._ +import org.apache.spark.SparkContext._ import org.apache.spark.api.java.{JavaSparkContext, JavaPairRDD, JavaRDD} import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.RDD @@ -365,19 +368,17 @@ private[spark] object PythonRDD extends Logging { valueClassMaybeNull: String, keyConverterClass: String, valueConverterClass: String, - minSplits: Int) = { + minSplits: Int, + batchSize: Int) = { val keyClass = Option(keyClassMaybeNull).getOrElse("org.apache.hadoop.io.Text") val valueClass = Option(valueClassMaybeNull).getOrElse("org.apache.hadoop.io.Text") - implicit val kcm = ClassTag(Class.forName(keyClass)).asInstanceOf[ClassTag[K]] - implicit val vcm = ClassTag(Class.forName(valueClass)).asInstanceOf[ClassTag[V]] - val kc = kcm.runtimeClass.asInstanceOf[Class[K]] - val vc = vcm.runtimeClass.asInstanceOf[Class[V]] - + val kc = Class.forName(keyClass).asInstanceOf[Class[K]] + val vc = Class.forName(valueClass).asInstanceOf[Class[V]] val rdd = sc.sc.sequenceFile[K, V](path, kc, vc, minSplits) - val keyConverter = Converter.getInstance(Option(keyConverterClass)) - val valueConverter = Converter.getInstance(Option(valueConverterClass)) - val converted = PythonHadoopUtil.convertRDD[K, V](rdd, keyConverter, valueConverter) - JavaRDD.fromRDD(SerDeUtil.rddToPython(converted)) + val confBroadcasted = sc.sc.broadcast(new SerializableWritable(sc.hadoopConfiguration())) + val converted = convertRDD(rdd, keyConverterClass, valueConverterClass, + new WritableToJavaConverter(confBroadcasted, batchSize)) + JavaRDD.fromRDD(SerDeUtil.pairRDDToPython(converted, batchSize)) } /** @@ -394,17 +395,16 @@ private[spark] object PythonRDD extends Logging { valueClass: String, keyConverterClass: String, valueConverterClass: String, - confAsMap: java.util.HashMap[String, String]) = { - val conf = PythonHadoopUtil.mapToConf(confAsMap) - val baseConf = sc.hadoopConfiguration() - val mergedConf = PythonHadoopUtil.mergeConfs(baseConf, conf) + confAsMap: java.util.HashMap[String, String], + batchSize: Int) = { + val mergedConf = getMergedConf(confAsMap, sc.hadoopConfiguration()) val rdd = newAPIHadoopRDDFromClassNames[K, V, F](sc, Some(path), inputFormatClass, keyClass, valueClass, mergedConf) - val keyConverter = Converter.getInstance(Option(keyConverterClass)) - val valueConverter = Converter.getInstance(Option(valueConverterClass)) - val converted = PythonHadoopUtil.convertRDD[K, V](rdd, keyConverter, valueConverter) - JavaRDD.fromRDD(SerDeUtil.rddToPython(converted)) + val confBroadcasted = sc.sc.broadcast(new SerializableWritable(mergedConf)) + val converted = convertRDD(rdd, keyConverterClass, valueConverterClass, + new WritableToJavaConverter(confBroadcasted, batchSize)) + JavaRDD.fromRDD(SerDeUtil.pairRDDToPython(converted, batchSize)) } /** @@ -421,15 +421,16 @@ private[spark] object PythonRDD extends Logging { valueClass: String, keyConverterClass: String, valueConverterClass: String, - confAsMap: java.util.HashMap[String, String]) = { + confAsMap: java.util.HashMap[String, String], + batchSize: Int) = { val conf = PythonHadoopUtil.mapToConf(confAsMap) val rdd = newAPIHadoopRDDFromClassNames[K, V, F](sc, None, inputFormatClass, keyClass, valueClass, conf) - val keyConverter = Converter.getInstance(Option(keyConverterClass)) - val valueConverter = Converter.getInstance(Option(valueConverterClass)) - val converted = PythonHadoopUtil.convertRDD[K, V](rdd, keyConverter, valueConverter) - JavaRDD.fromRDD(SerDeUtil.rddToPython(converted)) + val confBroadcasted = sc.sc.broadcast(new SerializableWritable(conf)) + val converted = convertRDD(rdd, keyConverterClass, valueConverterClass, + new WritableToJavaConverter(confBroadcasted, batchSize)) + JavaRDD.fromRDD(SerDeUtil.pairRDDToPython(converted, batchSize)) } private def newAPIHadoopRDDFromClassNames[K, V, F <: NewInputFormat[K, V]]( @@ -439,18 +440,14 @@ private[spark] object PythonRDD extends Logging { keyClass: String, valueClass: String, conf: Configuration) = { - implicit val kcm = ClassTag(Class.forName(keyClass)).asInstanceOf[ClassTag[K]] - implicit val vcm = ClassTag(Class.forName(valueClass)).asInstanceOf[ClassTag[V]] - implicit val fcm = ClassTag(Class.forName(inputFormatClass)).asInstanceOf[ClassTag[F]] - val kc = kcm.runtimeClass.asInstanceOf[Class[K]] - val vc = vcm.runtimeClass.asInstanceOf[Class[V]] - val fc = fcm.runtimeClass.asInstanceOf[Class[F]] - val rdd = if (path.isDefined) { + val kc = Class.forName(keyClass).asInstanceOf[Class[K]] + val vc = Class.forName(valueClass).asInstanceOf[Class[V]] + val fc = Class.forName(inputFormatClass).asInstanceOf[Class[F]] + if (path.isDefined) { sc.sc.newAPIHadoopFile[K, V, F](path.get, fc, kc, vc, conf) } else { sc.sc.newAPIHadoopRDD[K, V, F](conf, fc, kc, vc) } - rdd } /** @@ -467,17 +464,16 @@ private[spark] object PythonRDD extends Logging { valueClass: String, keyConverterClass: String, valueConverterClass: String, - confAsMap: java.util.HashMap[String, String]) = { - val conf = PythonHadoopUtil.mapToConf(confAsMap) - val baseConf = sc.hadoopConfiguration() - val mergedConf = PythonHadoopUtil.mergeConfs(baseConf, conf) + confAsMap: java.util.HashMap[String, String], + batchSize: Int) = { + val mergedConf = getMergedConf(confAsMap, sc.hadoopConfiguration()) val rdd = hadoopRDDFromClassNames[K, V, F](sc, Some(path), inputFormatClass, keyClass, valueClass, mergedConf) - val keyConverter = Converter.getInstance(Option(keyConverterClass)) - val valueConverter = Converter.getInstance(Option(valueConverterClass)) - val converted = PythonHadoopUtil.convertRDD[K, V](rdd, keyConverter, valueConverter) - JavaRDD.fromRDD(SerDeUtil.rddToPython(converted)) + val confBroadcasted = sc.sc.broadcast(new SerializableWritable(mergedConf)) + val converted = convertRDD(rdd, keyConverterClass, valueConverterClass, + new WritableToJavaConverter(confBroadcasted, batchSize)) + JavaRDD.fromRDD(SerDeUtil.pairRDDToPython(converted, batchSize)) } /** @@ -494,15 +490,16 @@ private[spark] object PythonRDD extends Logging { valueClass: String, keyConverterClass: String, valueConverterClass: String, - confAsMap: java.util.HashMap[String, String]) = { + confAsMap: java.util.HashMap[String, String], + batchSize: Int) = { val conf = PythonHadoopUtil.mapToConf(confAsMap) val rdd = hadoopRDDFromClassNames[K, V, F](sc, None, inputFormatClass, keyClass, valueClass, conf) - val keyConverter = Converter.getInstance(Option(keyConverterClass)) - val valueConverter = Converter.getInstance(Option(valueConverterClass)) - val converted = PythonHadoopUtil.convertRDD[K, V](rdd, keyConverter, valueConverter) - JavaRDD.fromRDD(SerDeUtil.rddToPython(converted)) + val confBroadcasted = sc.sc.broadcast(new SerializableWritable(conf)) + val converted = convertRDD(rdd, keyConverterClass, valueConverterClass, + new WritableToJavaConverter(confBroadcasted, batchSize)) + JavaRDD.fromRDD(SerDeUtil.pairRDDToPython(converted, batchSize)) } private def hadoopRDDFromClassNames[K, V, F <: InputFormat[K, V]]( @@ -512,18 +509,14 @@ private[spark] object PythonRDD extends Logging { keyClass: String, valueClass: String, conf: Configuration) = { - implicit val kcm = ClassTag(Class.forName(keyClass)).asInstanceOf[ClassTag[K]] - implicit val vcm = ClassTag(Class.forName(valueClass)).asInstanceOf[ClassTag[V]] - implicit val fcm = ClassTag(Class.forName(inputFormatClass)).asInstanceOf[ClassTag[F]] - val kc = kcm.runtimeClass.asInstanceOf[Class[K]] - val vc = vcm.runtimeClass.asInstanceOf[Class[V]] - val fc = fcm.runtimeClass.asInstanceOf[Class[F]] - val rdd = if (path.isDefined) { + val kc = Class.forName(keyClass).asInstanceOf[Class[K]] + val vc = Class.forName(valueClass).asInstanceOf[Class[V]] + val fc = Class.forName(inputFormatClass).asInstanceOf[Class[F]] + if (path.isDefined) { sc.sc.hadoopFile(path.get, fc, kc, vc) } else { sc.sc.hadoopRDD(new JobConf(conf), fc, kc, vc) } - rdd } def writeUTF(str: String, dataOut: DataOutputStream) { @@ -562,6 +555,152 @@ private[spark] object PythonRDD extends Logging { } } + private def getMergedConf(confAsMap: java.util.HashMap[String, String], + baseConf: Configuration): Configuration = { + val conf = PythonHadoopUtil.mapToConf(confAsMap) + PythonHadoopUtil.mergeConfs(baseConf, conf) + } + + private def inferKeyValueTypes[K, V](rdd: RDD[(K, V)], keyConverterClass: String = null, + valueConverterClass: String = null): (Class[_], Class[_]) = { + // Peek at an element to figure out key/value types. Since Writables are not serializable, + // we cannot call first() on the converted RDD. Instead, we call first() on the original RDD + // and then convert locally. + val (key, value) = rdd.first() + val (kc, vc) = getKeyValueConverters(keyConverterClass, valueConverterClass, + new JavaToWritableConverter) + (kc.convert(key).getClass, vc.convert(value).getClass) + } + + private def getKeyValueTypes(keyClass: String, valueClass: String): + Option[(Class[_], Class[_])] = { + for { + k <- Option(keyClass) + v <- Option(valueClass) + } yield (Class.forName(k), Class.forName(v)) + } + + private def getKeyValueConverters(keyConverterClass: String, valueConverterClass: String, + defaultConverter: Converter[Any, Any]): (Converter[Any, Any], Converter[Any, Any]) = { + val keyConverter = Converter.getInstance(Option(keyConverterClass), defaultConverter) + val valueConverter = Converter.getInstance(Option(valueConverterClass), defaultConverter) + (keyConverter, valueConverter) + } + + /** + * Convert an RDD of key-value pairs from internal types to serializable types suitable for + * output, or vice versa. + */ + private def convertRDD[K, V](rdd: RDD[(K, V)], + keyConverterClass: String, + valueConverterClass: String, + defaultConverter: Converter[Any, Any]): RDD[(Any, Any)] = { + val (kc, vc) = getKeyValueConverters(keyConverterClass, valueConverterClass, + defaultConverter) + PythonHadoopUtil.convertRDD(rdd, kc, vc) + } + + /** + * Output a Python RDD of key-value pairs as a Hadoop SequenceFile using the Writable types + * we convert from the RDD's key and value types. Note that keys and values can't be + * [[org.apache.hadoop.io.Writable]] types already, since Writables are not Java + * `Serializable` and we can't peek at them. The `path` can be on any Hadoop file system. + */ + def saveAsSequenceFile[K, V, C <: CompressionCodec]( + pyRDD: JavaRDD[Array[Byte]], + batchSerialized: Boolean, + path: String, + compressionCodecClass: String) = { + saveAsHadoopFile( + pyRDD, batchSerialized, path, "org.apache.hadoop.mapred.SequenceFileOutputFormat", + null, null, null, null, new java.util.HashMap(), compressionCodecClass) + } + + /** + * Output a Python RDD of key-value pairs to any Hadoop file system, using old Hadoop + * `OutputFormat` in mapred package. Keys and values are converted to suitable output + * types using either user specified converters or, if not specified, + * [[org.apache.spark.api.python.JavaToWritableConverter]]. Post-conversion types + * `keyClass` and `valueClass` are automatically inferred if not specified. The passed-in + * `confAsMap` is merged with the default Hadoop conf associated with the SparkContext of + * this RDD. + */ + def saveAsHadoopFile[K, V, F <: OutputFormat[_, _], C <: CompressionCodec]( + pyRDD: JavaRDD[Array[Byte]], + batchSerialized: Boolean, + path: String, + outputFormatClass: String, + keyClass: String, + valueClass: String, + keyConverterClass: String, + valueConverterClass: String, + confAsMap: java.util.HashMap[String, String], + compressionCodecClass: String) = { + val rdd = SerDeUtil.pythonToPairRDD(pyRDD, batchSerialized) + val (kc, vc) = getKeyValueTypes(keyClass, valueClass).getOrElse( + inferKeyValueTypes(rdd, keyConverterClass, valueConverterClass)) + val mergedConf = getMergedConf(confAsMap, pyRDD.context.hadoopConfiguration) + val codec = Option(compressionCodecClass).map(Class.forName(_).asInstanceOf[Class[C]]) + val converted = convertRDD(rdd, keyConverterClass, valueConverterClass, + new JavaToWritableConverter) + val fc = Class.forName(outputFormatClass).asInstanceOf[Class[F]] + converted.saveAsHadoopFile(path, kc, vc, fc, new JobConf(mergedConf), codec=codec) + } + + /** + * Output a Python RDD of key-value pairs to any Hadoop file system, using new Hadoop + * `OutputFormat` in mapreduce package. Keys and values are converted to suitable output + * types using either user specified converters or, if not specified, + * [[org.apache.spark.api.python.JavaToWritableConverter]]. Post-conversion types + * `keyClass` and `valueClass` are automatically inferred if not specified. The passed-in + * `confAsMap` is merged with the default Hadoop conf associated with the SparkContext of + * this RDD. + */ + def saveAsNewAPIHadoopFile[K, V, F <: NewOutputFormat[_, _]]( + pyRDD: JavaRDD[Array[Byte]], + batchSerialized: Boolean, + path: String, + outputFormatClass: String, + keyClass: String, + valueClass: String, + keyConverterClass: String, + valueConverterClass: String, + confAsMap: java.util.HashMap[String, String]) = { + val rdd = SerDeUtil.pythonToPairRDD(pyRDD, batchSerialized) + val (kc, vc) = getKeyValueTypes(keyClass, valueClass).getOrElse( + inferKeyValueTypes(rdd, keyConverterClass, valueConverterClass)) + val mergedConf = getMergedConf(confAsMap, pyRDD.context.hadoopConfiguration) + val converted = convertRDD(rdd, keyConverterClass, valueConverterClass, + new JavaToWritableConverter) + val fc = Class.forName(outputFormatClass).asInstanceOf[Class[F]] + converted.saveAsNewAPIHadoopFile(path, kc, vc, fc, mergedConf) + } + + /** + * Output a Python RDD of key-value pairs to any Hadoop file system, using a Hadoop conf + * converted from the passed-in `confAsMap`. The conf should set relevant output params ( + * e.g., output path, output format, etc), in the same way as it would be configured for + * a Hadoop MapReduce job. Both old and new Hadoop OutputFormat APIs are supported + * (mapred vs. mapreduce). Keys/values are converted for output using either user specified + * converters or, by default, [[org.apache.spark.api.python.JavaToWritableConverter]]. + */ + def saveAsHadoopDataset[K, V]( + pyRDD: JavaRDD[Array[Byte]], + batchSerialized: Boolean, + confAsMap: java.util.HashMap[String, String], + keyConverterClass: String, + valueConverterClass: String, + useNewAPI: Boolean) = { + val conf = PythonHadoopUtil.mapToConf(confAsMap) + val converted = convertRDD(SerDeUtil.pythonToPairRDD(pyRDD, batchSerialized), + keyConverterClass, valueConverterClass, new JavaToWritableConverter) + if (useNewAPI) { + converted.saveAsNewAPIHadoopDataset(conf) + } else { + converted.saveAsHadoopDataset(new JobConf(conf)) + } + } + /** * Convert and RDD of Java objects to and RDD of serialized Python objects, that is usable by * PySpark. diff --git a/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala b/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala index 9a012e7254901..efc9009c088a8 100644 --- a/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala +++ b/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala @@ -17,13 +17,14 @@ package org.apache.spark.api.python -import scala.util.Try -import org.apache.spark.rdd.RDD -import org.apache.spark.Logging -import scala.util.Success +import scala.collection.JavaConversions._ import scala.util.Failure -import net.razorvine.pickle.Pickler +import scala.util.Try +import net.razorvine.pickle.{Unpickler, Pickler} + +import org.apache.spark.{Logging, SparkException} +import org.apache.spark.rdd.RDD /** Utilities for serialization / deserialization between Python and Java, using Pickle. */ private[python] object SerDeUtil extends Logging { @@ -65,20 +66,52 @@ private[python] object SerDeUtil extends Logging { * by PySpark. By default, if serialization fails, toString is called and the string * representation is serialized */ - def rddToPython(rdd: RDD[(Any, Any)]): RDD[Array[Byte]] = { + def pairRDDToPython(rdd: RDD[(Any, Any)], batchSize: Int): RDD[Array[Byte]] = { val (keyFailed, valueFailed) = checkPickle(rdd.first()) rdd.mapPartitions { iter => val pickle = new Pickler - iter.map { case (k, v) => - if (keyFailed && valueFailed) { - pickle.dumps(Array(k.toString, v.toString)) - } else if (keyFailed) { - pickle.dumps(Array(k.toString, v)) - } else if (!keyFailed && valueFailed) { - pickle.dumps(Array(k, v.toString)) + val cleaned = iter.map { case (k, v) => + val key = if (keyFailed) k.toString else k + val value = if (valueFailed) v.toString else v + Array[Any](key, value) + } + if (batchSize > 1) { + cleaned.grouped(batchSize).map(batched => pickle.dumps(seqAsJavaList(batched))) + } else { + cleaned.map(pickle.dumps(_)) + } + } + } + + /** + * Convert an RDD of serialized Python tuple (K, V) to RDD[(K, V)]. + */ + def pythonToPairRDD[K, V](pyRDD: RDD[Array[Byte]], batchSerialized: Boolean): RDD[(K, V)] = { + def isPair(obj: Any): Boolean = { + Option(obj.getClass.getComponentType).map(!_.isPrimitive).getOrElse(false) && + obj.asInstanceOf[Array[_]].length == 2 + } + pyRDD.mapPartitions { iter => + val unpickle = new Unpickler + val unpickled = + if (batchSerialized) { + iter.flatMap { batch => + unpickle.loads(batch) match { + case objs: java.util.List[_] => collectionAsScalaIterable(objs) + case other => throw new SparkException( + s"Unexpected type ${other.getClass.getName} for batch serialized Python RDD") + } + } } else { - pickle.dumps(Array(k, v)) + iter.map(unpickle.loads(_)) } + unpickled.map { + case obj if isPair(obj) => + // we only accept (K, V) + val arr = obj.asInstanceOf[Array[_]] + (arr.head.asInstanceOf[K], arr.last.asInstanceOf[V]) + case other => throw new SparkException( + s"RDD element of type ${other.getClass.getName} cannot be used") } } } diff --git a/core/src/main/scala/org/apache/spark/api/python/WriteInputFormatTestDataGenerator.scala b/core/src/main/scala/org/apache/spark/api/python/WriteInputFormatTestDataGenerator.scala index f0e3fb9aff5a0..d11db978b842e 100644 --- a/core/src/main/scala/org/apache/spark/api/python/WriteInputFormatTestDataGenerator.scala +++ b/core/src/main/scala/org/apache/spark/api/python/WriteInputFormatTestDataGenerator.scala @@ -17,15 +17,16 @@ package org.apache.spark.api.python -import org.apache.spark.SparkContext -import org.apache.hadoop.io._ -import scala.Array import java.io.{DataOutput, DataInput} +import java.nio.charset.Charset + +import org.apache.hadoop.io._ import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat import org.apache.spark.api.java.JavaSparkContext +import org.apache.spark.{SparkContext, SparkException} /** - * A class to test MsgPack serialization on the Scala side, that will be deserialized + * A class to test Pyrolite serialization on the Scala side, that will be deserialized * in Python * @param str * @param int @@ -54,7 +55,13 @@ case class TestWritable(var str: String, var int: Int, var double: Double) exten } } -class TestConverter extends Converter[Any, Any] { +private[python] class TestInputKeyConverter extends Converter[Any, Any] { + override def convert(obj: Any) = { + obj.asInstanceOf[IntWritable].get().toChar + } +} + +private[python] class TestInputValueConverter extends Converter[Any, Any] { import collection.JavaConversions._ override def convert(obj: Any) = { val m = obj.asInstanceOf[MapWritable] @@ -62,6 +69,38 @@ class TestConverter extends Converter[Any, Any] { } } +private[python] class TestOutputKeyConverter extends Converter[Any, Any] { + override def convert(obj: Any) = { + new Text(obj.asInstanceOf[Int].toString) + } +} + +private[python] class TestOutputValueConverter extends Converter[Any, Any] { + import collection.JavaConversions._ + override def convert(obj: Any) = { + new DoubleWritable(obj.asInstanceOf[java.util.Map[Double, _]].keySet().head) + } +} + +private[python] class DoubleArrayWritable extends ArrayWritable(classOf[DoubleWritable]) + +private[python] class DoubleArrayToWritableConverter extends Converter[Any, Writable] { + override def convert(obj: Any) = obj match { + case arr if arr.getClass.isArray && arr.getClass.getComponentType == classOf[Double] => + val daw = new DoubleArrayWritable + daw.set(arr.asInstanceOf[Array[Double]].map(new DoubleWritable(_))) + daw + case other => throw new SparkException(s"Data of type $other is not supported") + } +} + +private[python] class WritableToDoubleArrayConverter extends Converter[Any, Array[Double]] { + override def convert(obj: Any): Array[Double] = obj match { + case daw : DoubleArrayWritable => daw.get().map(_.asInstanceOf[DoubleWritable].get()) + case other => throw new SparkException(s"Data of type $other is not supported") + } +} + /** * This object contains method to generate SequenceFile test data and write it to a * given directory (probably a temp directory) @@ -97,7 +136,8 @@ object WriteInputFormatTestDataGenerator { sc.parallelize(intKeys).saveAsSequenceFile(intPath) sc.parallelize(intKeys.map{ case (k, v) => (k.toDouble, v) }).saveAsSequenceFile(doublePath) sc.parallelize(intKeys.map{ case (k, v) => (k.toString, v) }).saveAsSequenceFile(textPath) - sc.parallelize(intKeys.map{ case (k, v) => (k, v.getBytes) }).saveAsSequenceFile(bytesPath) + sc.parallelize(intKeys.map{ case (k, v) => (k, v.getBytes(Charset.forName("UTF-8"))) } + ).saveAsSequenceFile(bytesPath) val bools = Seq((1, true), (2, true), (2, false), (3, true), (2, false), (1, false)) sc.parallelize(bools).saveAsSequenceFile(boolPath) sc.parallelize(intKeys).map{ case (k, v) => @@ -106,19 +146,20 @@ object WriteInputFormatTestDataGenerator { // Create test data for ArrayWritable val data = Seq( - (1, Array(1.0, 2.0, 3.0)), + (1, Array()), (2, Array(3.0, 4.0, 5.0)), (3, Array(4.0, 5.0, 6.0)) ) sc.parallelize(data, numSlices = 2) .map{ case (k, v) => - (new IntWritable(k), new ArrayWritable(classOf[DoubleWritable], v.map(new DoubleWritable(_)))) - }.saveAsNewAPIHadoopFile[SequenceFileOutputFormat[IntWritable, ArrayWritable]](arrPath) + val va = new DoubleArrayWritable + va.set(v.map(new DoubleWritable(_))) + (new IntWritable(k), va) + }.saveAsNewAPIHadoopFile[SequenceFileOutputFormat[IntWritable, DoubleArrayWritable]](arrPath) // Create test data for MapWritable, with keys DoubleWritable and values Text val mapData = Seq( - (1, Map(2.0 -> "aa")), - (2, Map(3.0 -> "bb")), + (1, Map()), (2, Map(1.0 -> "cc")), (3, Map(2.0 -> "dd")), (2, Map(1.0 -> "aa")), @@ -126,9 +167,9 @@ object WriteInputFormatTestDataGenerator { ) sc.parallelize(mapData, numSlices = 2).map{ case (i, m) => val mw = new MapWritable() - val k = m.keys.head - val v = m.values.head - mw.put(new DoubleWritable(k), new Text(v)) + m.foreach { case (k, v) => + mw.put(new DoubleWritable(k), new Text(v)) + } (new IntWritable(i), mw) }.saveAsSequenceFile(mapPath) diff --git a/docs/programming-guide.md b/docs/programming-guide.md index 90c69713019f2..a88bf27add883 100644 --- a/docs/programming-guide.md +++ b/docs/programming-guide.md @@ -383,16 +383,16 @@ Apart from text files, Spark's Python API also supports several other data forma * `RDD.saveAsPickleFile` and `SparkContext.pickleFile` support saving an RDD in a simple format consisting of pickled Python objects. Batching is used on pickle serialization, with default batch size 10. -* Details on reading `SequenceFile` and arbitrary Hadoop `InputFormat` are given below. - -### SequenceFile and Hadoop InputFormats +* SequenceFile and Hadoop Input/Output Formats **Note** this feature is currently marked ```Experimental``` and is intended for advanced users. It may be replaced in future with read/write support based on SparkSQL, in which case SparkSQL is the preferred approach. -#### Writable Support +**Writable Support** -PySpark SequenceFile support loads an RDD within Java, and pickles the resulting Java objects using -[Pyrolite](https://github.com/irmen/Pyrolite/). The following Writables are automatically converted: +PySpark SequenceFile support loads an RDD of key-value pairs within Java, converts Writables to base Java types, and pickles the +resulting Java objects using [Pyrolite](https://github.com/irmen/Pyrolite/). When saving an RDD of key-value pairs to SequenceFile, +PySpark does the reverse. It unpickles Python objects into Java objects and then converts them to Writables. The following +Writables are automatically converted: @@ -403,32 +403,30 @@ PySpark SequenceFile support loads an RDD within Java, and pickles the resulting - - -
Writable TypePython Type
BooleanWritablebool
BytesWritablebytearray
NullWritableNone
ArrayWritablelist of primitives, or tuple of objects
MapWritabledict
Custom Class conforming to Java Bean conventionsdict of public properties (via JavaBean getters and setters) + __class__ for the class type
-#### Loading SequenceFiles +Arrays are not handled out-of-the-box. Users need to specify custom `ArrayWritable` subtypes when reading or writing. When writing, +users also need to specify custom converters that convert arrays to custom `ArrayWritable` subtypes. When reading, the default +converter will convert custom `ArrayWritable` subtypes to Java `Object[]`, which then get pickled to Python tuples. To get +Python `array.array` for arrays of primitive types, users need to specify custom converters. + +**Saving and Loading SequenceFiles** -Similarly to text files, SequenceFiles can be loaded by specifying the path. The key and value +Similarly to text files, SequenceFiles can be saved and loaded by specifying the path. The key and value classes can be specified, but for standard Writables this is not required. {% highlight python %} ->>> rdd = sc.sequenceFile("path/to/sequencefile/of/doubles") ->>> rdd.collect() # this example has DoubleWritable keys and Text values -[(1.0, u'aa'), - (2.0, u'bb'), - (2.0, u'aa'), - (3.0, u'cc'), - (2.0, u'bb'), - (1.0, u'aa')] +>>> rdd = sc.parallelize(range(1, 4)).map(lambda x: (x, "a" * x )) +>>> rdd.saveAsSequenceFile("path/to/file") +>>> sorted(sc.sequenceFile("path/to/file").collect()) +[(1, u'a'), (2, u'aa'), (3, u'aaa')] {% endhighlight %} -#### Loading Other Hadoop InputFormats +**Saving and Loading Other Hadoop Input/Output Formats** -PySpark can also read any Hadoop InputFormat, for both 'new' and 'old' Hadoop APIs. If required, -a Hadoop configuration can be passed in as a Python dict. Here is an example using the +PySpark can also read any Hadoop InputFormat or write any Hadoop OutputFormat, for both 'new' and 'old' Hadoop MapReduce APIs. +If required, a Hadoop configuration can be passed in as a Python dict. Here is an example using the Elasticsearch ESInputFormat: {% highlight python %} @@ -447,8 +445,7 @@ Note that, if the InputFormat simply depends on a Hadoop configuration and/or in the key and value classes can easily be converted according to the above table, then this approach should work well for such cases. -If you have custom serialized binary data (such as loading data from Cassandra / HBase) or custom -classes that don't conform to the JavaBean requirements, then you will first need to +If you have custom serialized binary data (such as loading data from Cassandra / HBase), then you will first need to transform that data on the Scala/Java side to something which can be handled by Pyrolite's pickler. A [Converter](api/scala/index.html#org.apache.spark.api.python.Converter) trait is provided for this. Simply extend this trait and implement your transformation code in the ```convert``` @@ -456,11 +453,8 @@ method. Remember to ensure that this class, along with any dependencies required classpath. See the [Python examples]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/python) and -the [Converter examples]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/scala/pythonconverters) -for examples of using HBase and Cassandra ```InputFormat```. - -Future support for writing data out as ```SequenceFileOutputFormat``` and other ```OutputFormats```, -is forthcoming. +the [Converter examples]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/scala/org/apache/spark/examples/pythonconverters) +for examples of using Cassandra / HBase ```InputFormat``` and ```OutputFormat``` with custom converters. diff --git a/examples/src/main/python/cassandra_outputformat.py b/examples/src/main/python/cassandra_outputformat.py new file mode 100644 index 0000000000000..1dfbf98604425 --- /dev/null +++ b/examples/src/main/python/cassandra_outputformat.py @@ -0,0 +1,83 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import sys + +from pyspark import SparkContext + +""" +Create data in Cassandra fist +(following: https://wiki.apache.org/cassandra/GettingStarted) + +cqlsh> CREATE KEYSPACE test + ... WITH REPLICATION = { 'class' : 'SimpleStrategy', 'replication_factor' : 1 }; +cqlsh> use test; +cqlsh:test> CREATE TABLE users ( + ... user_id int PRIMARY KEY, + ... fname text, + ... lname text + ... ); + +> cassandra_outputformat test users 1745 john smith +> cassandra_outputformat test users 1744 john doe +> cassandra_outputformat test users 1746 john smith + +cqlsh:test> SELECT * FROM users; + + user_id | fname | lname +---------+-------+------- + 1745 | john | smith + 1744 | john | doe + 1746 | john | smith +""" +if __name__ == "__main__": + if len(sys.argv) != 7: + print >> sys.stderr, """ + Usage: cassandra_outputformat + + Run with example jar: + ./bin/spark-submit --driver-class-path /path/to/example/jar /path/to/examples/cassandra_outputformat.py + Assumes you have created the following table in Cassandra already, + running on , in . + + cqlsh:> CREATE TABLE ( + ... user_id int PRIMARY KEY, + ... fname text, + ... lname text + ... ); + """ + exit(-1) + + host = sys.argv[1] + keyspace = sys.argv[2] + cf = sys.argv[3] + sc = SparkContext(appName="CassandraOutputFormat") + + conf = {"cassandra.output.thrift.address":host, + "cassandra.output.thrift.port":"9160", + "cassandra.output.keyspace":keyspace, + "cassandra.output.partitioner.class":"Murmur3Partitioner", + "cassandra.output.cql":"UPDATE " + keyspace + "." + cf + " SET fname = ?, lname = ?", + "mapreduce.output.basename":cf, + "mapreduce.outputformat.class":"org.apache.cassandra.hadoop.cql3.CqlOutputFormat", + "mapreduce.job.output.key.class":"java.util.Map", + "mapreduce.job.output.value.class":"java.util.List"} + key = {"user_id" : int(sys.argv[4])} + sc.parallelize([(key, sys.argv[5:])]).saveAsNewAPIHadoopDataset( + conf=conf, + keyConverter="org.apache.spark.examples.pythonconverters.ToCassandraCQLKeyConverter", + valueConverter="org.apache.spark.examples.pythonconverters.ToCassandraCQLValueConverter") diff --git a/examples/src/main/python/hbase_inputformat.py b/examples/src/main/python/hbase_inputformat.py index 3289d9880a0f5..c9fa8e171c2a1 100644 --- a/examples/src/main/python/hbase_inputformat.py +++ b/examples/src/main/python/hbase_inputformat.py @@ -65,7 +65,8 @@ "org.apache.hadoop.hbase.mapreduce.TableInputFormat", "org.apache.hadoop.hbase.io.ImmutableBytesWritable", "org.apache.hadoop.hbase.client.Result", - valueConverter="org.apache.spark.examples.pythonconverters.HBaseConverter", + keyConverter="org.apache.spark.examples.pythonconverters.ImmutableBytesWritableToStringConverter", + valueConverter="org.apache.spark.examples.pythonconverters.HBaseResultToStringConverter", conf=conf) output = hbase_rdd.collect() for (k, v) in output: diff --git a/examples/src/main/python/hbase_outputformat.py b/examples/src/main/python/hbase_outputformat.py new file mode 100644 index 0000000000000..5e11548fd13f7 --- /dev/null +++ b/examples/src/main/python/hbase_outputformat.py @@ -0,0 +1,65 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import sys + +from pyspark import SparkContext + +""" +Create test table in HBase first: + +hbase(main):001:0> create 'test', 'f1' +0 row(s) in 0.7840 seconds + +> hbase_outputformat test row1 f1 q1 value1 +> hbase_outputformat test row2 f1 q1 value2 +> hbase_outputformat test row3 f1 q1 value3 +> hbase_outputformat test row4 f1 q1 value4 + +hbase(main):002:0> scan 'test' +ROW COLUMN+CELL + row1 column=f1:q1, timestamp=1405659615726, value=value1 + row2 column=f1:q1, timestamp=1405659626803, value=value2 + row3 column=f1:q1, timestamp=1405659640106, value=value3 + row4 column=f1:q1, timestamp=1405659650292, value=value4 +4 row(s) in 0.0780 seconds +""" +if __name__ == "__main__": + if len(sys.argv) != 7: + print >> sys.stderr, """ + Usage: hbase_outputformat + + Run with example jar: + ./bin/spark-submit --driver-class-path /path/to/example/jar /path/to/examples/hbase_outputformat.py + Assumes you have created
with column family in HBase running on already + """ + exit(-1) + + host = sys.argv[1] + table = sys.argv[2] + sc = SparkContext(appName="HBaseOutputFormat") + + conf = {"hbase.zookeeper.quorum": host, + "hbase.mapred.outputtable": table, + "mapreduce.outputformat.class" : "org.apache.hadoop.hbase.mapreduce.TableOutputFormat", + "mapreduce.job.output.key.class" : "org.apache.hadoop.hbase.io.ImmutableBytesWritable", + "mapreduce.job.output.value.class" : "org.apache.hadoop.io.Writable"} + + sc.parallelize([sys.argv[3:]]).map(lambda x: (x[0], x)).saveAsNewAPIHadoopDataset( + conf=conf, + keyConverter="org.apache.spark.examples.pythonconverters.StringToImmutableBytesWritableConverter", + valueConverter="org.apache.spark.examples.pythonconverters.StringListToPutConverter") diff --git a/examples/src/main/scala/org/apache/spark/examples/pythonconverters/CassandraConverters.scala b/examples/src/main/scala/org/apache/spark/examples/pythonconverters/CassandraConverters.scala index 29a65c7a5f295..83feb5703b908 100644 --- a/examples/src/main/scala/org/apache/spark/examples/pythonconverters/CassandraConverters.scala +++ b/examples/src/main/scala/org/apache/spark/examples/pythonconverters/CassandraConverters.scala @@ -20,7 +20,7 @@ package org.apache.spark.examples.pythonconverters import org.apache.spark.api.python.Converter import java.nio.ByteBuffer import org.apache.cassandra.utils.ByteBufferUtil -import collection.JavaConversions.{mapAsJavaMap, mapAsScalaMap} +import collection.JavaConversions._ /** @@ -44,3 +44,25 @@ class CassandraCQLValueConverter extends Converter[Any, java.util.Map[String, St mapAsJavaMap(result.mapValues(bb => ByteBufferUtil.string(bb))) } } + +/** + * Implementation of [[org.apache.spark.api.python.Converter]] that converts a + * Map[String, Int] to Cassandra key + */ +class ToCassandraCQLKeyConverter extends Converter[Any, java.util.Map[String, ByteBuffer]] { + override def convert(obj: Any): java.util.Map[String, ByteBuffer] = { + val input = obj.asInstanceOf[java.util.Map[String, Int]] + mapAsJavaMap(input.mapValues(i => ByteBufferUtil.bytes(i))) + } +} + +/** + * Implementation of [[org.apache.spark.api.python.Converter]] that converts a + * List[String] to Cassandra value + */ +class ToCassandraCQLValueConverter extends Converter[Any, java.util.List[ByteBuffer]] { + override def convert(obj: Any): java.util.List[ByteBuffer] = { + val input = obj.asInstanceOf[java.util.List[String]] + seqAsJavaList(input.map(s => ByteBufferUtil.bytes(s))) + } +} diff --git a/examples/src/main/scala/org/apache/spark/examples/pythonconverters/HBaseConverter.scala b/examples/src/main/scala/org/apache/spark/examples/pythonconverters/HBaseConverter.scala deleted file mode 100644 index 42ae960bd64a1..0000000000000 --- a/examples/src/main/scala/org/apache/spark/examples/pythonconverters/HBaseConverter.scala +++ /dev/null @@ -1,33 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.examples.pythonconverters - -import org.apache.spark.api.python.Converter -import org.apache.hadoop.hbase.client.Result -import org.apache.hadoop.hbase.util.Bytes - -/** - * Implementation of [[org.apache.spark.api.python.Converter]] that converts a HBase Result - * to a String - */ -class HBaseConverter extends Converter[Any, String] { - override def convert(obj: Any): String = { - val result = obj.asInstanceOf[Result] - Bytes.toStringBinary(result.value()) - } -} diff --git a/examples/src/main/scala/org/apache/spark/examples/pythonconverters/HBaseConverters.scala b/examples/src/main/scala/org/apache/spark/examples/pythonconverters/HBaseConverters.scala new file mode 100644 index 0000000000000..273bee0a8b30f --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/pythonconverters/HBaseConverters.scala @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.pythonconverters + +import scala.collection.JavaConversions._ + +import org.apache.spark.api.python.Converter +import org.apache.hadoop.hbase.client.{Put, Result} +import org.apache.hadoop.hbase.io.ImmutableBytesWritable +import org.apache.hadoop.hbase.util.Bytes + +/** + * Implementation of [[org.apache.spark.api.python.Converter]] that converts an + * HBase Result to a String + */ +class HBaseResultToStringConverter extends Converter[Any, String] { + override def convert(obj: Any): String = { + val result = obj.asInstanceOf[Result] + Bytes.toStringBinary(result.value()) + } +} + +/** + * Implementation of [[org.apache.spark.api.python.Converter]] that converts an + * ImmutableBytesWritable to a String + */ +class ImmutableBytesWritableToStringConverter extends Converter[Any, String] { + override def convert(obj: Any): String = { + val key = obj.asInstanceOf[ImmutableBytesWritable] + Bytes.toStringBinary(key.get()) + } +} + +/** + * Implementation of [[org.apache.spark.api.python.Converter]] that converts a + * String to an ImmutableBytesWritable + */ +class StringToImmutableBytesWritableConverter extends Converter[Any, ImmutableBytesWritable] { + override def convert(obj: Any): ImmutableBytesWritable = { + val bytes = Bytes.toBytes(obj.asInstanceOf[String]) + new ImmutableBytesWritable(bytes) + } +} + +/** + * Implementation of [[org.apache.spark.api.python.Converter]] that converts a + * list of Strings to HBase Put + */ +class StringListToPutConverter extends Converter[Any, Put] { + override def convert(obj: Any): Put = { + val output = obj.asInstanceOf[java.util.ArrayList[String]].map(Bytes.toBytes(_)).toArray + val put = new Put(output(0)) + put.add(output(1), output(2), output(3)) + } +} diff --git a/python/pyspark/context.py b/python/pyspark/context.py index 830a6ee03f2a6..7b0f8d83aedc5 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -60,6 +60,7 @@ class SparkContext(object): _active_spark_context = None _lock = Lock() _python_includes = None # zip and egg files that need to be added to PYTHONPATH + _default_batch_size_for_serialized_input = 10 def __init__(self, master=None, appName=None, sparkHome=None, pyFiles=None, environment=None, batchSize=1024, serializer=PickleSerializer(), conf=None, @@ -378,7 +379,7 @@ def _dictToJavaMap(self, d): return jm def sequenceFile(self, path, keyClass=None, valueClass=None, keyConverter=None, - valueConverter=None, minSplits=None): + valueConverter=None, minSplits=None, batchSize=None): """ Read a Hadoop SequenceFile with arbitrary key and value Writable class from HDFS, a local file system (available on all nodes), or any Hadoop-supported file system URI. @@ -398,14 +399,18 @@ def sequenceFile(self, path, keyClass=None, valueClass=None, keyConverter=None, @param valueConverter: @param minSplits: minimum splits in dataset (default min(2, sc.defaultParallelism)) + @param batchSize: The number of Python objects represented as a single + Java object. (default sc._default_batch_size_for_serialized_input) """ minSplits = minSplits or min(self.defaultParallelism, 2) + batchSize = max(1, batchSize or self._default_batch_size_for_serialized_input) + ser = BatchedSerializer(PickleSerializer()) if (batchSize > 1) else PickleSerializer() jrdd = self._jvm.PythonRDD.sequenceFile(self._jsc, path, keyClass, valueClass, - keyConverter, valueConverter, minSplits) - return RDD(jrdd, self, PickleSerializer()) + keyConverter, valueConverter, minSplits, batchSize) + return RDD(jrdd, self, ser) def newAPIHadoopFile(self, path, inputFormatClass, keyClass, valueClass, keyConverter=None, - valueConverter=None, conf=None): + valueConverter=None, conf=None, batchSize=None): """ Read a 'new API' Hadoop InputFormat with arbitrary key and value class from HDFS, a local file system (available on all nodes), or any Hadoop-supported file system URI. @@ -425,14 +430,18 @@ def newAPIHadoopFile(self, path, inputFormatClass, keyClass, valueClass, keyConv @param valueConverter: (None by default) @param conf: Hadoop configuration, passed in as a dict (None by default) + @param batchSize: The number of Python objects represented as a single + Java object. (default sc._default_batch_size_for_serialized_input) """ jconf = self._dictToJavaMap(conf) + batchSize = max(1, batchSize or self._default_batch_size_for_serialized_input) + ser = BatchedSerializer(PickleSerializer()) if (batchSize > 1) else PickleSerializer() jrdd = self._jvm.PythonRDD.newAPIHadoopFile(self._jsc, path, inputFormatClass, keyClass, - valueClass, keyConverter, valueConverter, jconf) - return RDD(jrdd, self, PickleSerializer()) + valueClass, keyConverter, valueConverter, jconf, batchSize) + return RDD(jrdd, self, ser) def newAPIHadoopRDD(self, inputFormatClass, keyClass, valueClass, keyConverter=None, - valueConverter=None, conf=None): + valueConverter=None, conf=None, batchSize=None): """ Read a 'new API' Hadoop InputFormat with arbitrary key and value class, from an arbitrary Hadoop configuration, which is passed in as a Python dict. @@ -449,14 +458,18 @@ def newAPIHadoopRDD(self, inputFormatClass, keyClass, valueClass, keyConverter=N @param valueConverter: (None by default) @param conf: Hadoop configuration, passed in as a dict (None by default) + @param batchSize: The number of Python objects represented as a single + Java object. (default sc._default_batch_size_for_serialized_input) """ jconf = self._dictToJavaMap(conf) + batchSize = max(1, batchSize or self._default_batch_size_for_serialized_input) + ser = BatchedSerializer(PickleSerializer()) if (batchSize > 1) else PickleSerializer() jrdd = self._jvm.PythonRDD.newAPIHadoopRDD(self._jsc, inputFormatClass, keyClass, - valueClass, keyConverter, valueConverter, jconf) - return RDD(jrdd, self, PickleSerializer()) + valueClass, keyConverter, valueConverter, jconf, batchSize) + return RDD(jrdd, self, ser) def hadoopFile(self, path, inputFormatClass, keyClass, valueClass, keyConverter=None, - valueConverter=None, conf=None): + valueConverter=None, conf=None, batchSize=None): """ Read an 'old' Hadoop InputFormat with arbitrary key and value class from HDFS, a local file system (available on all nodes), or any Hadoop-supported file system URI. @@ -476,14 +489,18 @@ def hadoopFile(self, path, inputFormatClass, keyClass, valueClass, keyConverter= @param valueConverter: (None by default) @param conf: Hadoop configuration, passed in as a dict (None by default) + @param batchSize: The number of Python objects represented as a single + Java object. (default sc._default_batch_size_for_serialized_input) """ jconf = self._dictToJavaMap(conf) + batchSize = max(1, batchSize or self._default_batch_size_for_serialized_input) + ser = BatchedSerializer(PickleSerializer()) if (batchSize > 1) else PickleSerializer() jrdd = self._jvm.PythonRDD.hadoopFile(self._jsc, path, inputFormatClass, keyClass, - valueClass, keyConverter, valueConverter, jconf) - return RDD(jrdd, self, PickleSerializer()) + valueClass, keyConverter, valueConverter, jconf, batchSize) + return RDD(jrdd, self, ser) def hadoopRDD(self, inputFormatClass, keyClass, valueClass, keyConverter=None, - valueConverter=None, conf=None): + valueConverter=None, conf=None, batchSize=None): """ Read an 'old' Hadoop InputFormat with arbitrary key and value class, from an arbitrary Hadoop configuration, which is passed in as a Python dict. @@ -500,11 +517,15 @@ def hadoopRDD(self, inputFormatClass, keyClass, valueClass, keyConverter=None, @param valueConverter: (None by default) @param conf: Hadoop configuration, passed in as a dict (None by default) + @param batchSize: The number of Python objects represented as a single + Java object. (default sc._default_batch_size_for_serialized_input) """ jconf = self._dictToJavaMap(conf) + batchSize = max(1, batchSize or self._default_batch_size_for_serialized_input) + ser = BatchedSerializer(PickleSerializer()) if (batchSize > 1) else PickleSerializer() jrdd = self._jvm.PythonRDD.hadoopRDD(self._jsc, inputFormatClass, keyClass, valueClass, - keyConverter, valueConverter, jconf) - return RDD(jrdd, self, PickleSerializer()) + keyConverter, valueConverter, jconf, batchSize) + return RDD(jrdd, self, ser) def _checkpointFile(self, name, input_deserializer): jrdd = self._jsc.checkpointFile(name) diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index b84d976114f0d..e8fcc900efb24 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -231,6 +231,13 @@ def __init__(self, jrdd, ctx, jrdd_deserializer): self._jrdd_deserializer = jrdd_deserializer self._id = jrdd.id() + def _toPickleSerialization(self): + if (self._jrdd_deserializer == PickleSerializer() or + self._jrdd_deserializer == BatchedSerializer(PickleSerializer())): + return self + else: + return self._reserialize(BatchedSerializer(PickleSerializer(), 10)) + def id(self): """ A unique ID for this RDD (within its SparkContext). @@ -1030,6 +1037,113 @@ def first(self): """ return self.take(1)[0] + def saveAsNewAPIHadoopDataset(self, conf, keyConverter=None, valueConverter=None): + """ + Output a Python RDD of key-value pairs (of form C{RDD[(K, V)]}) to any Hadoop file + system, using the new Hadoop OutputFormat API (mapreduce package). Keys/values are + converted for output using either user specified converters or, by default, + L{org.apache.spark.api.python.JavaToWritableConverter}. + + @param conf: Hadoop job configuration, passed in as a dict + @param keyConverter: (None by default) + @param valueConverter: (None by default) + """ + jconf = self.ctx._dictToJavaMap(conf) + pickledRDD = self._toPickleSerialization() + batched = isinstance(pickledRDD._jrdd_deserializer, BatchedSerializer) + self.ctx._jvm.PythonRDD.saveAsHadoopDataset(pickledRDD._jrdd, batched, jconf, + keyConverter, valueConverter, True) + + def saveAsNewAPIHadoopFile(self, path, outputFormatClass, keyClass=None, valueClass=None, + keyConverter=None, valueConverter=None, conf=None): + """ + Output a Python RDD of key-value pairs (of form C{RDD[(K, V)]}) to any Hadoop file + system, using the new Hadoop OutputFormat API (mapreduce package). Key and value types + will be inferred if not specified. Keys and values are converted for output using either + user specified converters or L{org.apache.spark.api.python.JavaToWritableConverter}. The + C{conf} is applied on top of the base Hadoop conf associated with the SparkContext + of this RDD to create a merged Hadoop MapReduce job configuration for saving the data. + + @param path: path to Hadoop file + @param outputFormatClass: fully qualified classname of Hadoop OutputFormat + (e.g. "org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat") + @param keyClass: fully qualified classname of key Writable class + (e.g. "org.apache.hadoop.io.IntWritable", None by default) + @param valueClass: fully qualified classname of value Writable class + (e.g. "org.apache.hadoop.io.Text", None by default) + @param keyConverter: (None by default) + @param valueConverter: (None by default) + @param conf: Hadoop job configuration, passed in as a dict (None by default) + """ + jconf = self.ctx._dictToJavaMap(conf) + pickledRDD = self._toPickleSerialization() + batched = isinstance(pickledRDD._jrdd_deserializer, BatchedSerializer) + self.ctx._jvm.PythonRDD.saveAsNewAPIHadoopFile(pickledRDD._jrdd, batched, path, + outputFormatClass, keyClass, valueClass, keyConverter, valueConverter, jconf) + + def saveAsHadoopDataset(self, conf, keyConverter=None, valueConverter=None): + """ + Output a Python RDD of key-value pairs (of form C{RDD[(K, V)]}) to any Hadoop file + system, using the old Hadoop OutputFormat API (mapred package). Keys/values are + converted for output using either user specified converters or, by default, + L{org.apache.spark.api.python.JavaToWritableConverter}. + + @param conf: Hadoop job configuration, passed in as a dict + @param keyConverter: (None by default) + @param valueConverter: (None by default) + """ + jconf = self.ctx._dictToJavaMap(conf) + pickledRDD = self._toPickleSerialization() + batched = isinstance(pickledRDD._jrdd_deserializer, BatchedSerializer) + self.ctx._jvm.PythonRDD.saveAsHadoopDataset(pickledRDD._jrdd, batched, jconf, + keyConverter, valueConverter, False) + + def saveAsHadoopFile(self, path, outputFormatClass, keyClass=None, valueClass=None, + keyConverter=None, valueConverter=None, conf=None, + compressionCodecClass=None): + """ + Output a Python RDD of key-value pairs (of form C{RDD[(K, V)]}) to any Hadoop file + system, using the old Hadoop OutputFormat API (mapred package). Key and value types + will be inferred if not specified. Keys and values are converted for output using either + user specified converters or L{org.apache.spark.api.python.JavaToWritableConverter}. The + C{conf} is applied on top of the base Hadoop conf associated with the SparkContext + of this RDD to create a merged Hadoop MapReduce job configuration for saving the data. + + @param path: path to Hadoop file + @param outputFormatClass: fully qualified classname of Hadoop OutputFormat + (e.g. "org.apache.hadoop.mapred.SequenceFileOutputFormat") + @param keyClass: fully qualified classname of key Writable class + (e.g. "org.apache.hadoop.io.IntWritable", None by default) + @param valueClass: fully qualified classname of value Writable class + (e.g. "org.apache.hadoop.io.Text", None by default) + @param keyConverter: (None by default) + @param valueConverter: (None by default) + @param conf: (None by default) + @param compressionCodecClass: (None by default) + """ + jconf = self.ctx._dictToJavaMap(conf) + pickledRDD = self._toPickleSerialization() + batched = isinstance(pickledRDD._jrdd_deserializer, BatchedSerializer) + self.ctx._jvm.PythonRDD.saveAsHadoopFile(pickledRDD._jrdd, batched, path, + outputFormatClass, keyClass, valueClass, keyConverter, valueConverter, + jconf, compressionCodecClass) + + def saveAsSequenceFile(self, path, compressionCodecClass=None): + """ + Output a Python RDD of key-value pairs (of form C{RDD[(K, V)]}) to any Hadoop file + system, using the L{org.apache.hadoop.io.Writable} types that we convert from the + RDD's key and value types. The mechanism is as follows: + 1. Pyrolite is used to convert pickled Python RDD into RDD of Java objects. + 2. Keys and values of this Java RDD are converted to Writables and written out. + + @param path: path to sequence file + @param compressionCodecClass: (None by default) + """ + pickledRDD = self._toPickleSerialization() + batched = isinstance(pickledRDD._jrdd_deserializer, BatchedSerializer) + self.ctx._jvm.PythonRDD.saveAsSequenceFile(pickledRDD._jrdd, batched, + path, compressionCodecClass) + def saveAsPickleFile(self, path, batchSize=10): """ Save this RDD as a SequenceFile of serialized objects. The serializer diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 8486c8595b5a4..c29deb9574ea2 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -19,6 +19,7 @@ Unit tests for PySpark; additional tests are implemented as doctests in individual modules. """ +from array import array from fileinput import input from glob import glob import os @@ -327,6 +328,17 @@ def test_sequencefiles(self): ed = [(1.0, u'aa'), (1.0, u'aa'), (2.0, u'aa'), (2.0, u'bb'), (2.0, u'bb'), (3.0, u'cc')] self.assertEqual(doubles, ed) + bytes = sorted(self.sc.sequenceFile(basepath + "/sftestdata/sfbytes/", + "org.apache.hadoop.io.IntWritable", + "org.apache.hadoop.io.BytesWritable").collect()) + ebs = [(1, bytearray('aa', 'utf-8')), + (1, bytearray('aa', 'utf-8')), + (2, bytearray('aa', 'utf-8')), + (2, bytearray('bb', 'utf-8')), + (2, bytearray('bb', 'utf-8')), + (3, bytearray('cc', 'utf-8'))] + self.assertEqual(bytes, ebs) + text = sorted(self.sc.sequenceFile(basepath + "/sftestdata/sftext/", "org.apache.hadoop.io.Text", "org.apache.hadoop.io.Text").collect()) @@ -353,14 +365,34 @@ def test_sequencefiles(self): maps = sorted(self.sc.sequenceFile(basepath + "/sftestdata/sfmap/", "org.apache.hadoop.io.IntWritable", "org.apache.hadoop.io.MapWritable").collect()) - em = [(1, {2.0: u'aa'}), + em = [(1, {}), (1, {3.0: u'bb'}), (2, {1.0: u'aa'}), (2, {1.0: u'cc'}), - (2, {3.0: u'bb'}), (3, {2.0: u'dd'})] self.assertEqual(maps, em) + # arrays get pickled to tuples by default + tuples = sorted(self.sc.sequenceFile( + basepath + "/sftestdata/sfarray/", + "org.apache.hadoop.io.IntWritable", + "org.apache.spark.api.python.DoubleArrayWritable").collect()) + et = [(1, ()), + (2, (3.0, 4.0, 5.0)), + (3, (4.0, 5.0, 6.0))] + self.assertEqual(tuples, et) + + # with custom converters, primitive arrays can stay as arrays + arrays = sorted(self.sc.sequenceFile( + basepath + "/sftestdata/sfarray/", + "org.apache.hadoop.io.IntWritable", + "org.apache.spark.api.python.DoubleArrayWritable", + valueConverter="org.apache.spark.api.python.WritableToDoubleArrayConverter").collect()) + ea = [(1, array('d')), + (2, array('d', [3.0, 4.0, 5.0])), + (3, array('d', [4.0, 5.0, 6.0]))] + self.assertEqual(arrays, ea) + clazz = sorted(self.sc.sequenceFile(basepath + "/sftestdata/sfclass/", "org.apache.hadoop.io.Text", "org.apache.spark.api.python.TestWritable").collect()) @@ -369,6 +401,12 @@ def test_sequencefiles(self): u'double': 54.0, u'int': 123, u'str': u'test1'}) self.assertEqual(clazz[0], ec) + unbatched_clazz = sorted(self.sc.sequenceFile(basepath + "/sftestdata/sfclass/", + "org.apache.hadoop.io.Text", + "org.apache.spark.api.python.TestWritable", + batchSize=1).collect()) + self.assertEqual(unbatched_clazz[0], ec) + def test_oldhadoop(self): basepath = self.tempdir.name ints = sorted(self.sc.hadoopFile(basepath + "/sftestdata/sfint/", @@ -379,10 +417,11 @@ def test_oldhadoop(self): self.assertEqual(ints, ei) hellopath = os.path.join(SPARK_HOME, "python/test_support/hello.txt") - hello = self.sc.hadoopFile(hellopath, - "org.apache.hadoop.mapred.TextInputFormat", - "org.apache.hadoop.io.LongWritable", - "org.apache.hadoop.io.Text").collect() + oldconf = {"mapred.input.dir" : hellopath} + hello = self.sc.hadoopRDD("org.apache.hadoop.mapred.TextInputFormat", + "org.apache.hadoop.io.LongWritable", + "org.apache.hadoop.io.Text", + conf=oldconf).collect() result = [(0, u'Hello World!')] self.assertEqual(hello, result) @@ -397,10 +436,11 @@ def test_newhadoop(self): self.assertEqual(ints, ei) hellopath = os.path.join(SPARK_HOME, "python/test_support/hello.txt") - hello = self.sc.newAPIHadoopFile(hellopath, - "org.apache.hadoop.mapreduce.lib.input.TextInputFormat", - "org.apache.hadoop.io.LongWritable", - "org.apache.hadoop.io.Text").collect() + newconf = {"mapred.input.dir" : hellopath} + hello = self.sc.newAPIHadoopRDD("org.apache.hadoop.mapreduce.lib.input.TextInputFormat", + "org.apache.hadoop.io.LongWritable", + "org.apache.hadoop.io.Text", + conf=newconf).collect() result = [(0, u'Hello World!')] self.assertEqual(hello, result) @@ -435,16 +475,267 @@ def test_bad_inputs(self): "org.apache.hadoop.io.IntWritable", "org.apache.hadoop.io.Text")) - def test_converter(self): + def test_converters(self): + # use of custom converters basepath = self.tempdir.name maps = sorted(self.sc.sequenceFile( basepath + "/sftestdata/sfmap/", "org.apache.hadoop.io.IntWritable", "org.apache.hadoop.io.MapWritable", - valueConverter="org.apache.spark.api.python.TestConverter").collect()) - em = [(1, [2.0]), (1, [3.0]), (2, [1.0]), (2, [1.0]), (2, [3.0]), (3, [2.0])] + keyConverter="org.apache.spark.api.python.TestInputKeyConverter", + valueConverter="org.apache.spark.api.python.TestInputValueConverter").collect()) + em = [(u'\x01', []), + (u'\x01', [3.0]), + (u'\x02', [1.0]), + (u'\x02', [1.0]), + (u'\x03', [2.0])] + self.assertEqual(maps, em) + +class TestOutputFormat(PySparkTestCase): + + def setUp(self): + PySparkTestCase.setUp(self) + self.tempdir = tempfile.NamedTemporaryFile(delete=False) + os.unlink(self.tempdir.name) + + def tearDown(self): + PySparkTestCase.tearDown(self) + shutil.rmtree(self.tempdir.name, ignore_errors=True) + + def test_sequencefiles(self): + basepath = self.tempdir.name + ei = [(1, u'aa'), (1, u'aa'), (2, u'aa'), (2, u'bb'), (2, u'bb'), (3, u'cc')] + self.sc.parallelize(ei).saveAsSequenceFile(basepath + "/sfint/") + ints = sorted(self.sc.sequenceFile(basepath + "/sfint/").collect()) + self.assertEqual(ints, ei) + + ed = [(1.0, u'aa'), (1.0, u'aa'), (2.0, u'aa'), (2.0, u'bb'), (2.0, u'bb'), (3.0, u'cc')] + self.sc.parallelize(ed).saveAsSequenceFile(basepath + "/sfdouble/") + doubles = sorted(self.sc.sequenceFile(basepath + "/sfdouble/").collect()) + self.assertEqual(doubles, ed) + + ebs = [(1, bytearray(b'\x00\x07spam\x08')), (2, bytearray(b'\x00\x07spam\x08'))] + self.sc.parallelize(ebs).saveAsSequenceFile(basepath + "/sfbytes/") + bytes = sorted(self.sc.sequenceFile(basepath + "/sfbytes/").collect()) + self.assertEqual(bytes, ebs) + + et = [(u'1', u'aa'), + (u'2', u'bb'), + (u'3', u'cc')] + self.sc.parallelize(et).saveAsSequenceFile(basepath + "/sftext/") + text = sorted(self.sc.sequenceFile(basepath + "/sftext/").collect()) + self.assertEqual(text, et) + + eb = [(1, False), (1, True), (2, False), (2, False), (2, True), (3, True)] + self.sc.parallelize(eb).saveAsSequenceFile(basepath + "/sfbool/") + bools = sorted(self.sc.sequenceFile(basepath + "/sfbool/").collect()) + self.assertEqual(bools, eb) + + en = [(1, None), (1, None), (2, None), (2, None), (2, None), (3, None)] + self.sc.parallelize(en).saveAsSequenceFile(basepath + "/sfnull/") + nulls = sorted(self.sc.sequenceFile(basepath + "/sfnull/").collect()) + self.assertEqual(nulls, en) + + em = [(1, {}), + (1, {3.0: u'bb'}), + (2, {1.0: u'aa'}), + (2, {1.0: u'cc'}), + (3, {2.0: u'dd'})] + self.sc.parallelize(em).saveAsSequenceFile(basepath + "/sfmap/") + maps = sorted(self.sc.sequenceFile(basepath + "/sfmap/").collect()) self.assertEqual(maps, em) + def test_oldhadoop(self): + basepath = self.tempdir.name + dict_data = [(1, {}), + (1, {"row1" : 1.0}), + (2, {"row2" : 2.0})] + self.sc.parallelize(dict_data).saveAsHadoopFile( + basepath + "/oldhadoop/", + "org.apache.hadoop.mapred.SequenceFileOutputFormat", + "org.apache.hadoop.io.IntWritable", + "org.apache.hadoop.io.MapWritable") + result = sorted(self.sc.hadoopFile( + basepath + "/oldhadoop/", + "org.apache.hadoop.mapred.SequenceFileInputFormat", + "org.apache.hadoop.io.IntWritable", + "org.apache.hadoop.io.MapWritable").collect()) + self.assertEqual(result, dict_data) + + conf = { + "mapred.output.format.class" : "org.apache.hadoop.mapred.SequenceFileOutputFormat", + "mapred.output.key.class" : "org.apache.hadoop.io.IntWritable", + "mapred.output.value.class" : "org.apache.hadoop.io.MapWritable", + "mapred.output.dir" : basepath + "/olddataset/"} + self.sc.parallelize(dict_data).saveAsHadoopDataset(conf) + input_conf = {"mapred.input.dir" : basepath + "/olddataset/"} + old_dataset = sorted(self.sc.hadoopRDD( + "org.apache.hadoop.mapred.SequenceFileInputFormat", + "org.apache.hadoop.io.IntWritable", + "org.apache.hadoop.io.MapWritable", + conf=input_conf).collect()) + self.assertEqual(old_dataset, dict_data) + + def test_newhadoop(self): + basepath = self.tempdir.name + # use custom ArrayWritable types and converters to handle arrays + array_data = [(1, array('d')), + (1, array('d', [1.0, 2.0, 3.0])), + (2, array('d', [3.0, 4.0, 5.0]))] + self.sc.parallelize(array_data).saveAsNewAPIHadoopFile( + basepath + "/newhadoop/", + "org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat", + "org.apache.hadoop.io.IntWritable", + "org.apache.spark.api.python.DoubleArrayWritable", + valueConverter="org.apache.spark.api.python.DoubleArrayToWritableConverter") + result = sorted(self.sc.newAPIHadoopFile( + basepath + "/newhadoop/", + "org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat", + "org.apache.hadoop.io.IntWritable", + "org.apache.spark.api.python.DoubleArrayWritable", + valueConverter="org.apache.spark.api.python.WritableToDoubleArrayConverter").collect()) + self.assertEqual(result, array_data) + + conf = {"mapreduce.outputformat.class" : + "org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat", + "mapred.output.key.class" : "org.apache.hadoop.io.IntWritable", + "mapred.output.value.class" : "org.apache.spark.api.python.DoubleArrayWritable", + "mapred.output.dir" : basepath + "/newdataset/"} + self.sc.parallelize(array_data).saveAsNewAPIHadoopDataset(conf, + valueConverter="org.apache.spark.api.python.DoubleArrayToWritableConverter") + input_conf = {"mapred.input.dir" : basepath + "/newdataset/"} + new_dataset = sorted(self.sc.newAPIHadoopRDD( + "org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat", + "org.apache.hadoop.io.IntWritable", + "org.apache.spark.api.python.DoubleArrayWritable", + valueConverter="org.apache.spark.api.python.WritableToDoubleArrayConverter", + conf=input_conf).collect()) + self.assertEqual(new_dataset, array_data) + + def test_newolderror(self): + basepath = self.tempdir.name + rdd = self.sc.parallelize(range(1, 4)).map(lambda x: (x, "a" * x )) + self.assertRaises(Exception, lambda: rdd.saveAsHadoopFile( + basepath + "/newolderror/saveAsHadoopFile/", + "org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat")) + self.assertRaises(Exception, lambda: rdd.saveAsNewAPIHadoopFile( + basepath + "/newolderror/saveAsNewAPIHadoopFile/", + "org.apache.hadoop.mapred.SequenceFileOutputFormat")) + + def test_bad_inputs(self): + basepath = self.tempdir.name + rdd = self.sc.parallelize(range(1, 4)).map(lambda x: (x, "a" * x )) + self.assertRaises(Exception, lambda: rdd.saveAsHadoopFile( + basepath + "/badinputs/saveAsHadoopFile/", + "org.apache.hadoop.mapred.NotValidOutputFormat")) + self.assertRaises(Exception, lambda: rdd.saveAsNewAPIHadoopFile( + basepath + "/badinputs/saveAsNewAPIHadoopFile/", + "org.apache.hadoop.mapreduce.lib.output.NotValidOutputFormat")) + + def test_converters(self): + # use of custom converters + basepath = self.tempdir.name + data = [(1, {3.0: u'bb'}), + (2, {1.0: u'aa'}), + (3, {2.0: u'dd'})] + self.sc.parallelize(data).saveAsNewAPIHadoopFile( + basepath + "/converters/", + "org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat", + keyConverter="org.apache.spark.api.python.TestOutputKeyConverter", + valueConverter="org.apache.spark.api.python.TestOutputValueConverter") + converted = sorted(self.sc.sequenceFile(basepath + "/converters/").collect()) + expected = [(u'1', 3.0), + (u'2', 1.0), + (u'3', 2.0)] + self.assertEqual(converted, expected) + + def test_reserialization(self): + basepath = self.tempdir.name + x = range(1, 5) + y = range(1001, 1005) + data = zip(x, y) + rdd = self.sc.parallelize(x).zip(self.sc.parallelize(y)) + rdd.saveAsSequenceFile(basepath + "/reserialize/sequence") + result1 = sorted(self.sc.sequenceFile(basepath + "/reserialize/sequence").collect()) + self.assertEqual(result1, data) + + rdd.saveAsHadoopFile(basepath + "/reserialize/hadoop", + "org.apache.hadoop.mapred.SequenceFileOutputFormat") + result2 = sorted(self.sc.sequenceFile(basepath + "/reserialize/hadoop").collect()) + self.assertEqual(result2, data) + + rdd.saveAsNewAPIHadoopFile(basepath + "/reserialize/newhadoop", + "org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat") + result3 = sorted(self.sc.sequenceFile(basepath + "/reserialize/newhadoop").collect()) + self.assertEqual(result3, data) + + conf4 = { + "mapred.output.format.class" : "org.apache.hadoop.mapred.SequenceFileOutputFormat", + "mapred.output.key.class" : "org.apache.hadoop.io.IntWritable", + "mapred.output.value.class" : "org.apache.hadoop.io.IntWritable", + "mapred.output.dir" : basepath + "/reserialize/dataset"} + rdd.saveAsHadoopDataset(conf4) + result4 = sorted(self.sc.sequenceFile(basepath + "/reserialize/dataset").collect()) + self.assertEqual(result4, data) + + conf5 = {"mapreduce.outputformat.class" : + "org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat", + "mapred.output.key.class" : "org.apache.hadoop.io.IntWritable", + "mapred.output.value.class" : "org.apache.hadoop.io.IntWritable", + "mapred.output.dir" : basepath + "/reserialize/newdataset"} + rdd.saveAsNewAPIHadoopDataset(conf5) + result5 = sorted(self.sc.sequenceFile(basepath + "/reserialize/newdataset").collect()) + self.assertEqual(result5, data) + + def test_unbatched_save_and_read(self): + basepath = self.tempdir.name + ei = [(1, u'aa'), (1, u'aa'), (2, u'aa'), (2, u'bb'), (2, u'bb'), (3, u'cc')] + self.sc.parallelize(ei, numSlices=len(ei)).saveAsSequenceFile( + basepath + "/unbatched/") + + unbatched_sequence = sorted(self.sc.sequenceFile(basepath + "/unbatched/", + batchSize=1).collect()) + self.assertEqual(unbatched_sequence, ei) + + unbatched_hadoopFile = sorted(self.sc.hadoopFile(basepath + "/unbatched/", + "org.apache.hadoop.mapred.SequenceFileInputFormat", + "org.apache.hadoop.io.IntWritable", + "org.apache.hadoop.io.Text", + batchSize=1).collect()) + self.assertEqual(unbatched_hadoopFile, ei) + + unbatched_newAPIHadoopFile = sorted(self.sc.newAPIHadoopFile(basepath + "/unbatched/", + "org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat", + "org.apache.hadoop.io.IntWritable", + "org.apache.hadoop.io.Text", + batchSize=1).collect()) + self.assertEqual(unbatched_newAPIHadoopFile, ei) + + oldconf = {"mapred.input.dir" : basepath + "/unbatched/"} + unbatched_hadoopRDD = sorted(self.sc.hadoopRDD( + "org.apache.hadoop.mapred.SequenceFileInputFormat", + "org.apache.hadoop.io.IntWritable", + "org.apache.hadoop.io.Text", + conf=oldconf, + batchSize=1).collect()) + self.assertEqual(unbatched_hadoopRDD, ei) + + newconf = {"mapred.input.dir" : basepath + "/unbatched/"} + unbatched_newAPIHadoopRDD = sorted(self.sc.newAPIHadoopRDD( + "org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat", + "org.apache.hadoop.io.IntWritable", + "org.apache.hadoop.io.Text", + conf=newconf, + batchSize=1).collect()) + self.assertEqual(unbatched_newAPIHadoopRDD, ei) + + def test_malformed_RDD(self): + basepath = self.tempdir.name + # non-batch-serialized RDD[[(K, V)]] should be rejected + data = [[(1, "a")], [(2, "aa")], [(3, "aaa")]] + rdd = self.sc.parallelize(data, numSlices=len(data)) + self.assertRaises(Exception, lambda: rdd.saveAsSequenceFile( + basepath + "/malformed/sequence")) class TestDaemon(unittest.TestCase): def connect(self, port): From 7c7ce54522015315c909e111d6c2cff83e9fb501 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 30 Jul 2014 13:42:43 -0700 Subject: [PATCH 039/170] Wrap JAR_DL in dev/check-license. --- dev/check-license | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/dev/check-license b/dev/check-license index 00bb20c133b7d..625ec161bc571 100755 --- a/dev/check-license +++ b/dev/check-license @@ -32,9 +32,9 @@ acquire_rat_jar () { printf "Attempting to fetch rat\n" JAR_DL=${JAR}.part if hash curl 2>/dev/null; then - (curl --progress-bar ${URL1} > ${JAR_DL} || curl --progress-bar ${URL2} > ${JAR_DL}) && mv ${JAR_DL} ${JAR} + (curl --progress-bar ${URL1} > "$JAR_DL" || curl --progress-bar ${URL2} > "$JAR_DL") && mv "$JAR_DL" "$JAR" elif hash wget 2>/dev/null; then - (wget --progress=bar ${URL1} -O ${JAR_DL} || wget --progress=bar ${URL2} -O ${JAR_DL}) && mv ${JAR_DL} ${JAR} + (wget --progress=bar ${URL1} -O "$JAR_DL" || wget --progress=bar ${URL2} -O "$JAR_DL") && mv "$JAR_DL" "$JAR" else printf "You do not have curl or wget installed, please install rat manually.\n" exit -1 From 1097327538ec3870544f406775efcfe7722e48be Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 30 Jul 2014 14:08:24 -0700 Subject: [PATCH 040/170] Set AMPLAB_JENKINS_BUILD_PROFILE. --- dev/run-tests | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/dev/run-tests b/dev/run-tests index f2b523b996617..fb50fb380b15e 100755 --- a/dev/run-tests +++ b/dev/run-tests @@ -21,6 +21,18 @@ FWDIR="$(cd `dirname $0`/..; pwd)" cd "$FWDIR" +if [ -n "$AMPLAB_JENKINS_BUILD_PROFILE" ]; then + if [ "$AMPLAB_JENKINS_BUILD_PROFILE" = "hadoop1.0" ]; then + export SBT_MAVEN_PROFILES="-Dhadoop.version=1.0.4" + elif [ "$AMPLAB_JENKINS_BUILD_PROFILE" = "hadoop2.0" ]; then + export SBT_MAVEN_PROFILES="-Dhadoop.version=2.0.0-mr1-cdh4.1.1" + elif [ "$AMPLAB_JENKINS_BUILD_PROFILE" = "hadoop2.2" ]; then + export SBT_MAVEN_PROFILES="-Pyarn -Dhadoop.version=2.2.0" + elif [ "$AMPLAB_JENKINS_BUILD_PROFILE" = "hadoop2.3" ]; then + export SBT_MAVEN_PROFILES="-Pyarn -Phadoop-2.3 -Dhadoop.version=2.3.0" + fi +fi + if [ -z "$SBT_MAVEN_PROFILES" ]; then export SBT_MAVEN_PROFILES="-Pyarn -Phadoop-2.3 -Dhadoop.version=2.3.0" fi From 2f4b17056fdcba26fd3a7503b858364b883ab0b0 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 30 Jul 2014 14:31:20 -0700 Subject: [PATCH 041/170] Properly pass SBT_MAVEN_PROFILES into sbt. --- dev/run-tests | 27 +++++++++++++-------------- 1 file changed, 13 insertions(+), 14 deletions(-) diff --git a/dev/run-tests b/dev/run-tests index fb50fb380b15e..daa85bc750c07 100755 --- a/dev/run-tests +++ b/dev/run-tests @@ -23,20 +23,20 @@ cd "$FWDIR" if [ -n "$AMPLAB_JENKINS_BUILD_PROFILE" ]; then if [ "$AMPLAB_JENKINS_BUILD_PROFILE" = "hadoop1.0" ]; then - export SBT_MAVEN_PROFILES="-Dhadoop.version=1.0.4" + export SBT_MAVEN_PROFILES_ARGS="-Dhadoop.version=1.0.4" elif [ "$AMPLAB_JENKINS_BUILD_PROFILE" = "hadoop2.0" ]; then - export SBT_MAVEN_PROFILES="-Dhadoop.version=2.0.0-mr1-cdh4.1.1" + export SBT_MAVEN_PROFILES_ARGS="-Dhadoop.version=2.0.0-mr1-cdh4.1.1" elif [ "$AMPLAB_JENKINS_BUILD_PROFILE" = "hadoop2.2" ]; then - export SBT_MAVEN_PROFILES="-Pyarn -Dhadoop.version=2.2.0" + export SBT_MAVEN_PROFILES_ARGS="-Pyarn -Dhadoop.version=2.2.0" elif [ "$AMPLAB_JENKINS_BUILD_PROFILE" = "hadoop2.3" ]; then - export SBT_MAVEN_PROFILES="-Pyarn -Phadoop-2.3 -Dhadoop.version=2.3.0" + export SBT_MAVEN_PROFILES_ARGS="-Pyarn -Phadoop-2.3 -Dhadoop.version=2.3.0" fi fi -if [ -z "$SBT_MAVEN_PROFILES" ]; then - export SBT_MAVEN_PROFILES="-Pyarn -Phadoop-2.3 -Dhadoop.version=2.3.0" +if [ -z "$SBT_MAVEN_PROFILES_ARGS" ]; then + export SBT_MAVEN_PROFILES_ARGS="-Pyarn -Phadoop-2.3 -Dhadoop.version=2.3.0" fi -echo "SBT_MAVEN_PROFILES=\"$SBT_MAVEN_PROFILES\"" +echo "SBT_MAVEN_PROFILES_ARGS=\"$SBT_MAVEN_PROFILES_ARGS\"" # Remove work directory rm -rf ./work @@ -76,16 +76,15 @@ dev/scalastyle echo "=========================================================================" echo "Running Spark unit tests" echo "=========================================================================" + +if [ -n "$_RUN_SQL_TESTS" ]; then + SBT_MAVEN_PROFILES_ARGS="$SBT_MAVEN_PROFILES_ARGS -Phive -Phive-thriftserver" +fi # echo "q" is needed because sbt on encountering a build file with failure # (either resolution or compilation) prompts the user for input either q, r, # etc to quit or retry. This echo is there to make it not block. -if [ -n "$_RUN_SQL_TESTS" ]; then - echo -e "q\n" | SBT_MAVEN_PROFILES="$SBT_MAVEN_PROFILES -Phive -Phive-thriftserver" sbt/sbt clean package \ - assembly/assembly test | grep -v -e "info.*Resolving" -e "warn.*Merging" -e "info.*Including" -else - echo -e "q\n" | sbt/sbt clean package assembly/assembly test | \ - grep -v -e "info.*Resolving" -e "warn.*Merging" -e "info.*Including" -fi +echo -e "q\n" | sbt/sbt $SBT_MAVEN_PROFILES_ARGS clean package assembly/assembly test | \ + grep -v -e "info.*Resolving" -e "warn.*Merging" -e "info.*Including" echo "=========================================================================" echo "Running PySpark tests" From 6ab96a6fd0db7731c8c5d6478d9e28b619581687 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Wed, 30 Jul 2014 15:04:33 -0700 Subject: [PATCH 042/170] SPARK-2749 [BUILD]. Spark SQL Java tests aren't compiling in Jenkins' Maven builds; missing junit:junit dep The Maven-based builds in the build matrix have been failing for a few days: https://amplab.cs.berkeley.edu/jenkins/view/Spark/ On inspection, it looks like the Spark SQL Java tests don't compile: https://amplab.cs.berkeley.edu/jenkins/view/Spark/job/Spark-Master-Maven-pre-YARN/hadoop.version=1.0.4,label=centos/244/consoleFull I confirmed it by repeating the command vs master: `mvn -Dhadoop.version=1.0.4 -Dlabel=centos -DskipTests clean package` The problem is that this module doesn't depend on JUnit. In fact, none of the modules do, but `com.novocode:junit-interface` (the SBT-JUnit bridge) pulls it in, in most places. However this module doesn't depend on `com.novocode:junit-interface` Adding the `junit:junit` dependency fixes the compile problem. In fact, the other modules with Java tests should probably depend on it explicitly instead of happening to get it via `com.novocode:junit-interface`, since that is a bit SBT/Scala-specific (and I am not even sure it's needed). Author: Sean Owen Closes #1660 from srowen/SPARK-2749 and squashes the following commits: 858ff7c [Sean Owen] Add explicit junit dep to other modules with Java tests for robustness 9636794 [Sean Owen] Add junit dep so that Spark SQL Java tests compile --- core/pom.xml | 5 +++++ external/flume/pom.xml | 5 +++++ external/kafka/pom.xml | 5 +++++ external/mqtt/pom.xml | 5 +++++ external/twitter/pom.xml | 5 +++++ external/zeromq/pom.xml | 5 +++++ extras/java8-tests/pom.xml | 5 +++++ mllib/pom.xml | 5 +++++ sql/core/pom.xml | 5 +++++ streaming/pom.xml | 5 +++++ 10 files changed, 50 insertions(+) diff --git a/core/pom.xml b/core/pom.xml index 4f061099a477d..04d4b9cc1068e 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -262,6 +262,11 @@ asm test + + junit + junit + test + com.novocode junit-interface diff --git a/external/flume/pom.xml b/external/flume/pom.xml index 9f680b27c3308..c532705f3950c 100644 --- a/external/flume/pom.xml +++ b/external/flume/pom.xml @@ -72,6 +72,11 @@ scalacheck_${scala.binary.version} test + + junit + junit + test + com.novocode junit-interface diff --git a/external/kafka/pom.xml b/external/kafka/pom.xml index 25a5c0a4d7d77..daf03360bc5f5 100644 --- a/external/kafka/pom.xml +++ b/external/kafka/pom.xml @@ -80,6 +80,11 @@ scalacheck_${scala.binary.version} test + + junit + junit + test + com.novocode junit-interface diff --git a/external/mqtt/pom.xml b/external/mqtt/pom.xml index f31ed655f6779..dc48a08c93de2 100644 --- a/external/mqtt/pom.xml +++ b/external/mqtt/pom.xml @@ -67,6 +67,11 @@ scalacheck_${scala.binary.version} test + + junit + junit + test + com.novocode junit-interface diff --git a/external/twitter/pom.xml b/external/twitter/pom.xml index 56bb24c2a072e..b93ad016f84f0 100644 --- a/external/twitter/pom.xml +++ b/external/twitter/pom.xml @@ -62,6 +62,11 @@ scalacheck_${scala.binary.version} test + + junit + junit + test + com.novocode junit-interface diff --git a/external/zeromq/pom.xml b/external/zeromq/pom.xml index 54b0242c54e78..22c1fff23d9a2 100644 --- a/external/zeromq/pom.xml +++ b/external/zeromq/pom.xml @@ -62,6 +62,11 @@ scalacheck_${scala.binary.version} test + + junit + junit + test + com.novocode junit-interface diff --git a/extras/java8-tests/pom.xml b/extras/java8-tests/pom.xml index 3eade411b38b7..5308bb4e440ea 100644 --- a/extras/java8-tests/pom.xml +++ b/extras/java8-tests/pom.xml @@ -50,6 +50,11 @@ ${project.version} test-jar + + junit + junit + test + com.novocode junit-interface diff --git a/mllib/pom.xml b/mllib/pom.xml index f27cf520dc9fa..cb0fa7b97cb15 100644 --- a/mllib/pom.xml +++ b/mllib/pom.xml @@ -72,6 +72,11 @@ scalacheck_${scala.binary.version} test + + junit + junit + test + com.novocode junit-interface diff --git a/sql/core/pom.xml b/sql/core/pom.xml index 3a038a2db6173..c8016e41256d5 100644 --- a/sql/core/pom.xml +++ b/sql/core/pom.xml @@ -68,6 +68,11 @@ jackson-databind 2.3.0 + + junit + junit + test + org.scalatest scalatest_${scala.binary.version} diff --git a/streaming/pom.xml b/streaming/pom.xml index b99f306b8f2cc..1072f74aea0d9 100644 --- a/streaming/pom.xml +++ b/streaming/pom.xml @@ -58,6 +58,11 @@ scalacheck_${scala.binary.version} test + + junit + junit + test + com.novocode junit-interface From 2ac37db7ac8f7ec5c99f3bfe459f8e2ac240961f Mon Sep 17 00:00:00 2001 From: Brock Noland Date: Wed, 30 Jul 2014 17:04:30 -0700 Subject: [PATCH 043/170] SPARK-2741 - Publish version of spark assembly which does not contain Hive Provide a version of the Spark tarball which does not package Hive. This is meant for HIve + Spark users. Author: Brock Noland Closes #1667 from brockn/master and squashes the following commits: 5beafb2 [Brock Noland] SPARK-2741 - Publish version of spark assembly which does not contain Hive --- dev/create-release/create-release.sh | 2 ++ 1 file changed, 2 insertions(+) diff --git a/dev/create-release/create-release.sh b/dev/create-release/create-release.sh index 33de24d1ae6d7..af46572e6602b 100755 --- a/dev/create-release/create-release.sh +++ b/dev/create-release/create-release.sh @@ -115,6 +115,8 @@ make_binary_release "hadoop1" "-Phive -Phive-thriftserver -Dhadoop.version=1.0.4 make_binary_release "cdh4" "-Phive -Phive-thriftserver -Dhadoop.version=2.0.0-mr1-cdh4.2.0" make_binary_release "hadoop2" \ "-Phive -Phive-thriftserver -Pyarn -Phadoop-2.2 -Dhadoop.version=2.2.0 -Pyarn.version=2.2.0" +make_binary_release "hadoop2-without-hive" \ + "-Pyarn -Phadoop-2.2 -Dhadoop.version=2.2.0 -Pyarn.version=2.2.0" # Copy data echo "Copying release tarballs" From 88a519db90d66ee5a1455ef4fcc1ad2a687e3d0b Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Wed, 30 Jul 2014 17:30:51 -0700 Subject: [PATCH 044/170] [SPARK-2734][SQL] Remove tables from cache when DROP TABLE is run. Author: Michael Armbrust Closes #1650 from marmbrus/dropCached and squashes the following commits: e6ab80b [Michael Armbrust] Support if exists. 83426c6 [Michael Armbrust] Remove tables from cache when DROP TABLE is run. --- .../org/apache/spark/sql/hive/HiveQl.scala | 9 +++- .../spark/sql/hive/HiveStrategies.scala | 2 + .../spark/sql/hive/execution/DropTable.scala | 48 +++++++++++++++++++ .../spark/sql/hive/CachedTableSuite.scala | 16 +++++++ 4 files changed, 74 insertions(+), 1 deletion(-) create mode 100644 sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/DropTable.scala diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index d18ccf8167487..3d2eb1eefaeda 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -44,6 +44,8 @@ private[hive] case class SourceCommand(filePath: String) extends Command private[hive] case class AddFile(filePath: String) extends Command +private[hive] case class DropTable(tableName: String, ifExists: Boolean) extends Command + /** Provides a mapping from HiveQL statements to catalyst logical plans and expression trees. */ private[hive] object HiveQl { protected val nativeCommands = Seq( @@ -96,7 +98,6 @@ private[hive] object HiveQl { "TOK_CREATEINDEX", "TOK_DROPDATABASE", "TOK_DROPINDEX", - "TOK_DROPTABLE", "TOK_MSCK", // TODO(marmbrus): Figure out how view are expanded by hive, as we might need to handle this. @@ -377,6 +378,12 @@ private[hive] object HiveQl { } protected def nodeToPlan(node: Node): LogicalPlan = node match { + // Special drop table that also uncaches. + case Token("TOK_DROPTABLE", + Token("TOK_TABNAME", tableNameParts) :: + ifExists) => + val tableName = tableNameParts.map { case Token(p, Nil) => p }.mkString(".") + DropTable(tableName, ifExists.nonEmpty) // Just fake explain for any of the native commands. case Token("TOK_EXPLAIN", explainArgs) if noExplainCommands.contains(explainArgs.head.getText) => diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala index 4d0fab4140b21..2175c5f3835a6 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala @@ -81,6 +81,8 @@ private[hive] trait HiveStrategies { case logical.NativeCommand(sql) => NativeCommand(sql, plan.output)(context) :: Nil + case DropTable(tableName, ifExists) => execution.DropTable(tableName, ifExists) :: Nil + case describe: logical.DescribeCommand => val resolvedTable = context.executePlan(describe.table).analyzed resolvedTable match { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/DropTable.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/DropTable.scala new file mode 100644 index 0000000000000..9cd0c86c6c796 --- /dev/null +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/DropTable.scala @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.execution + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.expressions.Row +import org.apache.spark.sql.execution.{Command, LeafNode} +import org.apache.spark.sql.hive.HiveContext + +/** + * :: DeveloperApi :: + * Drops a table from the metastore and removes it if it is cached. + */ +@DeveloperApi +case class DropTable(tableName: String, ifExists: Boolean) extends LeafNode with Command { + + def hiveContext = sqlContext.asInstanceOf[HiveContext] + + def output = Seq.empty + + override protected[sql] lazy val sideEffectResult: Seq[Any] = { + val ifExistsClause = if (ifExists) "IF EXISTS " else "" + hiveContext.runSqlHive(s"DROP TABLE $ifExistsClause$tableName") + hiveContext.catalog.unregisterTable(None, tableName) + Seq.empty + } + + override def execute(): RDD[Row] = { + sideEffectResult + sparkContext.emptyRDD[Row] + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala index 3132d0112c708..08da6405a17c6 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala @@ -23,6 +23,8 @@ import org.apache.spark.sql.hive.execution.HiveComparisonTest import org.apache.spark.sql.hive.test.TestHive class CachedTableSuite extends HiveComparisonTest { + import TestHive._ + TestHive.loadTestTable("src") test("cache table") { @@ -32,6 +34,20 @@ class CachedTableSuite extends HiveComparisonTest { createQueryTest("read from cached table", "SELECT * FROM src LIMIT 1", reset = false) + test("Drop cached table") { + hql("CREATE TABLE test(a INT)") + cacheTable("test") + hql("SELECT * FROM test").collect() + hql("DROP TABLE test") + intercept[org.apache.hadoop.hive.ql.metadata.InvalidTableException] { + hql("SELECT * FROM test").collect() + } + } + + test("DROP nonexistant table") { + hql("DROP TABLE IF EXISTS nonexistantTable") + } + test("check that table is cached and uncache") { TestHive.table("src").queryExecution.analyzed match { case _ : InMemoryRelation => // Found evidence of caching From e9b275b7697e7ad3b52b157d3274acc17ca8d828 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Wed, 30 Jul 2014 17:34:32 -0700 Subject: [PATCH 045/170] SPARK-2341 [MLLIB] loadLibSVMFile doesn't handle regression datasets Per discussion at https://issues.apache.org/jira/browse/SPARK-2341 , this is a look at deprecating the multiclass parameter. Thoughts welcome of course. Author: Sean Owen Closes #1663 from srowen/SPARK-2341 and squashes the following commits: 8a3abd7 [Sean Owen] Suppress MIMA error for removed package private classes 18a8c8e [Sean Owen] Updates from review 83d0092 [Sean Owen] Deprecated methods with multiclass, and instead always parse target as a double (ie. multiclass = true) --- .../examples/mllib/LinearRegression.scala | 2 +- .../examples/mllib/SparseNaiveBayes.scala | 4 +- .../spark/mllib/util/LabelParsers.scala | 56 ------------------- .../org/apache/spark/mllib/util/MLUtils.scala | 52 ++++++----------- .../spark/mllib/util/LabelParsersSuite.scala | 41 -------------- .../spark/mllib/util/MLUtilsSuite.scala | 14 ++--- project/MimaExcludes.scala | 8 +++ python/pyspark/mllib/util.py | 23 ++++---- 8 files changed, 46 insertions(+), 154 deletions(-) delete mode 100644 mllib/src/main/scala/org/apache/spark/mllib/util/LabelParsers.scala delete mode 100644 mllib/src/test/scala/org/apache/spark/mllib/util/LabelParsersSuite.scala diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/LinearRegression.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/LinearRegression.scala index 4811bb70e4b28..05b7d66f8dffd 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/LinearRegression.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/LinearRegression.scala @@ -91,7 +91,7 @@ object LinearRegression extends App { Logger.getRootLogger.setLevel(Level.WARN) - val examples = MLUtils.loadLibSVMFile(sc, params.input, multiclass = true).cache() + val examples = MLUtils.loadLibSVMFile(sc, params.input).cache() val splits = examples.randomSplit(Array(0.8, 0.2)) val training = splits(0).cache() diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/SparseNaiveBayes.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/SparseNaiveBayes.scala index 537e68a0991aa..88acd9dbb0878 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/SparseNaiveBayes.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/SparseNaiveBayes.scala @@ -22,7 +22,7 @@ import scopt.OptionParser import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.mllib.classification.NaiveBayes -import org.apache.spark.mllib.util.{MLUtils, MulticlassLabelParser} +import org.apache.spark.mllib.util.MLUtils /** * An example naive Bayes app. Run with @@ -76,7 +76,7 @@ object SparseNaiveBayes { if (params.minPartitions > 0) params.minPartitions else sc.defaultMinPartitions val examples = - MLUtils.loadLibSVMFile(sc, params.input, multiclass = true, params.numFeatures, minPartitions) + MLUtils.loadLibSVMFile(sc, params.input, params.numFeatures, minPartitions) // Cache examples because it will be used in both training and evaluation. examples.cache() diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/LabelParsers.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/LabelParsers.scala deleted file mode 100644 index e25bf18b780bf..0000000000000 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/LabelParsers.scala +++ /dev/null @@ -1,56 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.mllib.util - -/** Trait for label parsers. */ -private trait LabelParser extends Serializable { - /** Parses a string label into a double label. */ - def parse(labelString: String): Double -} - -/** Factory methods for label parsers. */ -private object LabelParser { - def getInstance(multiclass: Boolean): LabelParser = { - if (multiclass) MulticlassLabelParser else BinaryLabelParser - } -} - -/** - * Label parser for binary labels, which outputs 1.0 (positive) if the value is greater than 0.5, - * or 0.0 (negative) otherwise. So it works with +1/-1 labeling and +1/0 labeling. - */ -private object BinaryLabelParser extends LabelParser { - /** Gets the default instance of BinaryLabelParser. */ - def getInstance(): LabelParser = this - - /** - * Parses the input label into positive (1.0) if the value is greater than 0.5, - * or negative (0.0) otherwise. - */ - override def parse(labelString: String): Double = if (labelString.toDouble > 0.5) 1.0 else 0.0 -} - -/** - * Label parser for multiclass labels, which converts the input label to double. - */ -private object MulticlassLabelParser extends LabelParser { - /** Gets the default instance of MulticlassLabelParser. */ - def getInstance(): LabelParser = this - - override def parse(labelString: String): Double = labelString.toDouble -} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala index 30de24ad89f98..dc10a194783ed 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala @@ -55,7 +55,6 @@ object MLUtils { * * @param sc Spark context * @param path file or directory path in any Hadoop-supported file system URI - * @param labelParser parser for labels * @param numFeatures number of features, which will be determined from the input data if a * nonpositive value is given. This is useful when the dataset is already split * into multiple files and you want to load them separately, because some @@ -64,10 +63,9 @@ object MLUtils { * @param minPartitions min number of partitions * @return labeled data stored as an RDD[LabeledPoint] */ - private def loadLibSVMFile( + def loadLibSVMFile( sc: SparkContext, path: String, - labelParser: LabelParser, numFeatures: Int, minPartitions: Int): RDD[LabeledPoint] = { val parsed = sc.textFile(path, minPartitions) @@ -75,7 +73,7 @@ object MLUtils { .filter(line => !(line.isEmpty || line.startsWith("#"))) .map { line => val items = line.split(' ') - val label = labelParser.parse(items.head) + val label = items.head.toDouble val (indices, values) = items.tail.map { item => val indexAndValue = item.split(':') val index = indexAndValue(0).toInt - 1 // Convert 1-based indices to 0-based. @@ -102,64 +100,46 @@ object MLUtils { // Convenient methods for `loadLibSVMFile`. - /** - * Loads labeled data in the LIBSVM format into an RDD[LabeledPoint]. - * The LIBSVM format is a text-based format used by LIBSVM and LIBLINEAR. - * Each line represents a labeled sparse feature vector using the following format: - * {{{label index1:value1 index2:value2 ...}}} - * where the indices are one-based and in ascending order. - * This method parses each line into a [[org.apache.spark.mllib.regression.LabeledPoint]], - * where the feature indices are converted to zero-based. - * - * @param sc Spark context - * @param path file or directory path in any Hadoop-supported file system URI - * @param multiclass whether the input labels contain more than two classes. If false, any label - * with value greater than 0.5 will be mapped to 1.0, or 0.0 otherwise. So it - * works for both +1/-1 and 1/0 cases. If true, the double value parsed directly - * from the label string will be used as the label value. - * @param numFeatures number of features, which will be determined from the input data if a - * nonpositive value is given. This is useful when the dataset is already split - * into multiple files and you want to load them separately, because some - * features may not present in certain files, which leads to inconsistent - * feature dimensions. - * @param minPartitions min number of partitions - * @return labeled data stored as an RDD[LabeledPoint] - */ - def loadLibSVMFile( + @deprecated("use method without multiclass argument, which no longer has effect", "1.1.0") + def loadLibSVMFile( sc: SparkContext, path: String, multiclass: Boolean, numFeatures: Int, minPartitions: Int): RDD[LabeledPoint] = - loadLibSVMFile(sc, path, LabelParser.getInstance(multiclass), numFeatures, minPartitions) + loadLibSVMFile(sc, path, numFeatures, minPartitions) /** * Loads labeled data in the LIBSVM format into an RDD[LabeledPoint], with the default number of * partitions. */ + def loadLibSVMFile( + sc: SparkContext, + path: String, + numFeatures: Int): RDD[LabeledPoint] = + loadLibSVMFile(sc, path, numFeatures, sc.defaultMinPartitions) + + @deprecated("use method without multiclass argument, which no longer has effect", "1.1.0") def loadLibSVMFile( sc: SparkContext, path: String, multiclass: Boolean, numFeatures: Int): RDD[LabeledPoint] = - loadLibSVMFile(sc, path, multiclass, numFeatures, sc.defaultMinPartitions) + loadLibSVMFile(sc, path, numFeatures) - /** - * Loads labeled data in the LIBSVM format into an RDD[LabeledPoint], with the number of features - * determined automatically and the default number of partitions. - */ + @deprecated("use method without multiclass argument, which no longer has effect", "1.1.0") def loadLibSVMFile( sc: SparkContext, path: String, multiclass: Boolean): RDD[LabeledPoint] = - loadLibSVMFile(sc, path, multiclass, -1, sc.defaultMinPartitions) + loadLibSVMFile(sc, path) /** * Loads binary labeled data in the LIBSVM format into an RDD[LabeledPoint], with number of * features determined automatically and the default number of partitions. */ def loadLibSVMFile(sc: SparkContext, path: String): RDD[LabeledPoint] = - loadLibSVMFile(sc, path, multiclass = false, -1, sc.defaultMinPartitions) + loadLibSVMFile(sc, path, -1) /** * Save labeled data in LIBSVM format. diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/LabelParsersSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/LabelParsersSuite.scala deleted file mode 100644 index ac85677f2f014..0000000000000 --- a/mllib/src/test/scala/org/apache/spark/mllib/util/LabelParsersSuite.scala +++ /dev/null @@ -1,41 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.mllib.util - -import org.scalatest.FunSuite - -class LabelParsersSuite extends FunSuite { - test("binary label parser") { - for (parser <- Seq(BinaryLabelParser, BinaryLabelParser.getInstance())) { - assert(parser.parse("+1") === 1.0) - assert(parser.parse("1") === 1.0) - assert(parser.parse("0") === 0.0) - assert(parser.parse("-1") === 0.0) - } - } - - test("multiclass label parser") { - for (parser <- Seq(MulticlassLabelParser, MulticlassLabelParser.getInstance())) { - assert(parser.parse("0") == 0.0) - assert(parser.parse("+1") === 1.0) - assert(parser.parse("1") === 1.0) - assert(parser.parse("2") === 2.0) - assert(parser.parse("3") === 3.0) - } - } -} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala index c14870fb969a8..8ef2bb1bf6a78 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala @@ -63,9 +63,9 @@ class MLUtilsSuite extends FunSuite with LocalSparkContext { test("loadLibSVMFile") { val lines = """ - |+1 1:1.0 3:2.0 5:3.0 - |-1 - |-1 2:4.0 4:5.0 6:6.0 + |1 1:1.0 3:2.0 5:3.0 + |0 + |0 2:4.0 4:5.0 6:6.0 """.stripMargin val tempDir = Files.createTempDir() tempDir.deleteOnExit() @@ -73,7 +73,7 @@ class MLUtilsSuite extends FunSuite with LocalSparkContext { Files.write(lines, file, Charsets.US_ASCII) val path = tempDir.toURI.toString - val pointsWithNumFeatures = loadLibSVMFile(sc, path, multiclass = false, 6).collect() + val pointsWithNumFeatures = loadLibSVMFile(sc, path, 6).collect() val pointsWithoutNumFeatures = loadLibSVMFile(sc, path).collect() for (points <- Seq(pointsWithNumFeatures, pointsWithoutNumFeatures)) { @@ -86,11 +86,11 @@ class MLUtilsSuite extends FunSuite with LocalSparkContext { assert(points(2).features === Vectors.sparse(6, Seq((1, 4.0), (3, 5.0), (5, 6.0)))) } - val multiclassPoints = loadLibSVMFile(sc, path, multiclass = true).collect() + val multiclassPoints = loadLibSVMFile(sc, path).collect() assert(multiclassPoints.length === 3) assert(multiclassPoints(0).label === 1.0) - assert(multiclassPoints(1).label === -1.0) - assert(multiclassPoints(2).label === -1.0) + assert(multiclassPoints(1).label === 0.0) + assert(multiclassPoints(2).label === 0.0) Utils.deleteRecursively(tempDir) } diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 5ff88f0dd1cac..5a835f58207cf 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -97,6 +97,14 @@ object MimaExcludes { "org.apache.spark.mllib.tree.impurity.Entropy.calculate"), ProblemFilters.exclude[IncompatibleMethTypeProblem]( "org.apache.spark.mllib.tree.impurity.Variance.calculate") + ) ++ + Seq ( // Package-private classes removed in SPARK-2341 + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.mllib.util.BinaryLabelParser"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.mllib.util.BinaryLabelParser$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.mllib.util.LabelParser"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.mllib.util.LabelParser$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.mllib.util.MulticlassLabelParser"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.mllib.util.MulticlassLabelParser$") ) case v if v.startsWith("1.0") => Seq( diff --git a/python/pyspark/mllib/util.py b/python/pyspark/mllib/util.py index a707a9dcd5b49..d94900cefdb77 100644 --- a/python/pyspark/mllib/util.py +++ b/python/pyspark/mllib/util.py @@ -29,15 +29,18 @@ class MLUtils: Helper methods to load, save and pre-process data used in MLlib. """ + @deprecated @staticmethod def _parse_libsvm_line(line, multiclass): + return _parse_libsvm_line(line) + + @staticmethod + def _parse_libsvm_line(line): """ Parses a line in LIBSVM format into (label, indices, values). """ items = line.split(None) label = float(items[0]) - if not multiclass: - label = 1.0 if label > 0.5 else 0.0 nnz = len(items) - 1 indices = np.zeros(nnz, dtype=np.int32) values = np.zeros(nnz) @@ -64,8 +67,13 @@ def _convert_labeled_point_to_libsvm(p): " but got " % type(v)) return " ".join(items) + @deprecated @staticmethod def loadLibSVMFile(sc, path, multiclass=False, numFeatures=-1, minPartitions=None): + return loadLibSVMFile(sc, path, numFeatures, minPartitions) + + @staticmethod + def loadLibSVMFile(sc, path, numFeatures=-1, minPartitions=None): """ Loads labeled data in the LIBSVM format into an RDD of LabeledPoint. The LIBSVM format is a text-based format used by @@ -81,13 +89,6 @@ def loadLibSVMFile(sc, path, multiclass=False, numFeatures=-1, minPartitions=Non @param sc: Spark context @param path: file or directory path in any Hadoop-supported file system URI - @param multiclass: whether the input labels contain more than - two classes. If false, any label with value - greater than 0.5 will be mapped to 1.0, or - 0.0 otherwise. So it works for both +1/-1 and - 1/0 cases. If true, the double value parsed - directly from the label string will be used - as the label value. @param numFeatures: number of features, which will be determined from the input data if a nonpositive value is given. This is useful when the dataset is @@ -105,7 +106,7 @@ def loadLibSVMFile(sc, path, multiclass=False, numFeatures=-1, minPartitions=Non >>> tempFile.write("+1 1:1.0 3:2.0 5:3.0\\n-1\\n-1 2:4.0 4:5.0 6:6.0") >>> tempFile.flush() >>> examples = MLUtils.loadLibSVMFile(sc, tempFile.name).collect() - >>> multiclass_examples = MLUtils.loadLibSVMFile(sc, tempFile.name, True).collect() + >>> multiclass_examples = MLUtils.loadLibSVMFile(sc, tempFile.name).collect() >>> tempFile.close() >>> type(examples[0]) == LabeledPoint True @@ -124,7 +125,7 @@ def loadLibSVMFile(sc, path, multiclass=False, numFeatures=-1, minPartitions=Non """ lines = sc.textFile(path, minPartitions) - parsed = lines.map(lambda l: MLUtils._parse_libsvm_line(l, multiclass)) + parsed = lines.map(lambda l: MLUtils._parse_libsvm_line(l)) if numFeatures <= 0: parsed.cache() numFeatures = parsed.map(lambda x: 0 if x[1].size == 0 else x[1][-1]).reduce(max) + 1 From da501766834453c9ac7095c7e8c930151f87cf11 Mon Sep 17 00:00:00 2001 From: strat0sphere Date: Wed, 30 Jul 2014 17:57:50 -0700 Subject: [PATCH 046/170] Update DecisionTreeRunner.scala Author: strat0sphere Closes #1676 from strat0sphere/patch-1 and squashes the following commits: 044d2fa [strat0sphere] Update DecisionTreeRunner.scala --- .../org/apache/spark/examples/mllib/DecisionTreeRunner.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala index 43f13fe24f0d0..6db9bf3cf5be6 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala @@ -33,7 +33,7 @@ import org.apache.spark.rdd.RDD /** * An example runner for decision tree. Run with * {{{ - * ./bin/spark-example org.apache.spark.examples.mllib.DecisionTreeRunner [options] + * ./bin/run-example org.apache.spark.examples.mllib.DecisionTreeRunner [options] * }}} * If you use it as a template to create your own app, please use `spark-submit` to submit your app. */ From e966284409f9355e1169960e73a2215617c8cb22 Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Wed, 30 Jul 2014 18:07:59 -0700 Subject: [PATCH 047/170] SPARK-2045 Sort-based shuffle This adds a new ShuffleManager based on sorting, as described in https://issues.apache.org/jira/browse/SPARK-2045. The bulk of the code is in an ExternalSorter class that is similar to ExternalAppendOnlyMap, but sorts key-value pairs by partition ID and can be used to create a single sorted file with a map task's output. (Longer-term I think this can take on the remaining functionality in ExternalAppendOnlyMap and replace it so we don't have code duplication.) The main TODOs still left are: - [x] enabling ExternalSorter to merge across spilled files - [x] with an Ordering - [x] without an Ordering, using the keys' hash codes - [x] adding more tests (e.g. a version of our shuffle suite that runs on this) - [x] rebasing on top of the size-tracking refactoring in #1165 when that is merged - [x] disabling spilling if spark.shuffle.spill is set to false Despite this though, this seems to work pretty well (running successfully in cases where the hash shuffle would OOM, such as 1000 reduce tasks on executors with only 1G memory), and it seems to be comparable in speed or faster than hash-based shuffle (it will create much fewer files for the OS to keep track of). So I'm posting it to get some early feedback. After these TODOs are done, I'd also like to enable ExternalSorter to sort data within each partition by a key as well, which will allow us to use it to implement external spilling in reduce tasks in `sortByKey`. Author: Matei Zaharia Closes #1499 from mateiz/sort-based-shuffle and squashes the following commits: bd841f9 [Matei Zaharia] Various review comments d1c137fd [Matei Zaharia] Various review comments a611159 [Matei Zaharia] Compile fixes due to rebase 62c56c8 [Matei Zaharia] Fix ShuffledRDD sometimes not returning Tuple2s. f617432 [Matei Zaharia] Fix a failing test (seems to be due to change in SizeTracker logic) 9464d5f [Matei Zaharia] Simplify code and fix conflicts after latest rebase 0174149 [Matei Zaharia] Add cleanup behavior and cleanup tests for sort-based shuffle eb4ee0d [Matei Zaharia] Remove customizable element type in ShuffledRDD fa2e8db [Matei Zaharia] Allow nextBatchStream to be called after we're done looking at all streams a34b352 [Matei Zaharia] Fix tracking of indices within a partition in SpillReader, and add test 03e1006 [Matei Zaharia] Add a SortShuffleSuite that runs ShuffleSuite with sort-based shuffle 3c7ff1f [Matei Zaharia] Obey the spark.shuffle.spill setting in ExternalSorter ad65fbd [Matei Zaharia] Rebase on top of Aaron's Sorter change, and use Sorter in our buffer 44d2a93 [Matei Zaharia] Use estimateSize instead of atGrowThreshold to test collection sizes 5686f71 [Matei Zaharia] Optimize merging phase for in-memory only data: 5461cbb [Matei Zaharia] Review comments and more tests (e.g. tests with 1 element per partition) e9ad356 [Matei Zaharia] Update ContextCleanerSuite to make sure shuffle cleanup tests use hash shuffle (since they were written for it) c72362a [Matei Zaharia] Added bug fix and test for when iterators are empty de1fb40 [Matei Zaharia] Make trait SizeTrackingCollection private[spark] 4988d16 [Matei Zaharia] tweak c1b7572 [Matei Zaharia] Small optimization ba7db7f [Matei Zaharia] Handle null keys in hash-based comparator, and add tests for collisions ef4e397 [Matei Zaharia] Support for partial aggregation even without an Ordering 4b7a5ce [Matei Zaharia] More tests, and ability to sort data if a total ordering is given e1f84be [Matei Zaharia] Fix disk block manager test 5a40a1c [Matei Zaharia] More tests 614f1b4 [Matei Zaharia] Add spill metrics to map tasks cc52caf [Matei Zaharia] Add more error handling and tests for error cases bbf359d [Matei Zaharia] More work 3a56341 [Matei Zaharia] More partial work towards sort-based shuffle 7a0895d [Matei Zaharia] Some more partial work towards sort-based shuffle b615476 [Matei Zaharia] Scaffolding for sort-based shuffle --- .../scala/org/apache/spark/Aggregator.scala | 24 +- .../scala/org/apache/spark/SparkContext.scala | 8 +- .../apache/spark/api/java/JavaPairRDD.scala | 2 +- .../org/apache/spark/rdd/CoGroupedRDD.scala | 7 +- .../spark/rdd/OrderedRDDFunctions.scala | 14 +- .../apache/spark/rdd/PairRDDFunctions.scala | 4 +- .../main/scala/org/apache/spark/rdd/RDD.scala | 8 +- .../org/apache/spark/rdd/ShuffledRDD.scala | 17 +- .../shuffle/hash/HashShuffleManager.scala | 2 +- .../shuffle/hash/HashShuffleReader.scala | 5 +- .../shuffle/hash/HashShuffleWriter.scala | 6 +- .../shuffle/sort/SortShuffleManager.scala | 80 +++ .../shuffle/sort/SortShuffleWriter.scala | 165 +++++ .../org/apache/spark/storage/BlockId.scala | 11 +- .../spark/storage/DiskBlockManager.scala | 38 +- .../spark/storage/ShuffleBlockManager.scala | 29 +- .../collection/ExternalAppendOnlyMap.scala | 36 +- .../util/collection/ExternalSorter.scala | 662 ++++++++++++++++++ .../SizeTrackingAppendOnlyMap.scala | 5 +- .../collection/SizeTrackingPairBuffer.scala | 86 +++ .../SizeTrackingPairCollection.scala | 34 + .../org/apache/spark/CheckpointSuite.scala | 2 +- .../apache/spark/ContextCleanerSuite.scala | 186 +++-- .../org/apache/spark/ShuffleNettySuite.scala | 2 +- .../scala/org/apache/spark/ShuffleSuite.scala | 26 +- .../org/apache/spark/SortShuffleSuite.scala | 34 + .../scala/org/apache/spark/rdd/RDDSuite.scala | 6 +- .../ExternalAppendOnlyMapSuite.scala | 25 +- .../util/collection/ExternalSorterSuite.scala | 566 +++++++++++++++ .../util/collection/FixedHashObject.scala | 25 + .../graphx/impl/MessageToPartition.scala | 2 +- .../graphx/impl/RoutingTablePartition.scala | 2 +- project/SparkBuild.scala | 1 + .../apache/spark/sql/execution/Exchange.scala | 6 +- .../spark/sql/execution/basicOperators.scala | 2 +- 35 files changed, 1969 insertions(+), 159 deletions(-) create mode 100644 core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala create mode 100644 core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala create mode 100644 core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala create mode 100644 core/src/main/scala/org/apache/spark/util/collection/SizeTrackingPairBuffer.scala create mode 100644 core/src/main/scala/org/apache/spark/util/collection/SizeTrackingPairCollection.scala create mode 100644 core/src/test/scala/org/apache/spark/SortShuffleSuite.scala create mode 100644 core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala create mode 100644 core/src/test/scala/org/apache/spark/util/collection/FixedHashObject.scala diff --git a/core/src/main/scala/org/apache/spark/Aggregator.scala b/core/src/main/scala/org/apache/spark/Aggregator.scala index ff0ca11749d42..79c9c451d273d 100644 --- a/core/src/main/scala/org/apache/spark/Aggregator.scala +++ b/core/src/main/scala/org/apache/spark/Aggregator.scala @@ -56,18 +56,23 @@ case class Aggregator[K, V, C] ( } else { val combiners = new ExternalAppendOnlyMap[K, V, C](createCombiner, mergeValue, mergeCombiners) combiners.insertAll(iter) - // TODO: Make this non optional in a future release - Option(context).foreach(c => c.taskMetrics.memoryBytesSpilled = combiners.memoryBytesSpilled) - Option(context).foreach(c => c.taskMetrics.diskBytesSpilled = combiners.diskBytesSpilled) + // Update task metrics if context is not null + // TODO: Make context non optional in a future release + Option(context).foreach { c => + c.taskMetrics.memoryBytesSpilled += combiners.memoryBytesSpilled + c.taskMetrics.diskBytesSpilled += combiners.diskBytesSpilled + } combiners.iterator } } @deprecated("use combineCombinersByKey with TaskContext argument", "0.9.0") - def combineCombinersByKey(iter: Iterator[(K, C)]) : Iterator[(K, C)] = + def combineCombinersByKey(iter: Iterator[_ <: Product2[K, C]]) : Iterator[(K, C)] = combineCombinersByKey(iter, null) - def combineCombinersByKey(iter: Iterator[(K, C)], context: TaskContext) : Iterator[(K, C)] = { + def combineCombinersByKey(iter: Iterator[_ <: Product2[K, C]], context: TaskContext) + : Iterator[(K, C)] = + { if (!externalSorting) { val combiners = new AppendOnlyMap[K,C] var kc: Product2[K, C] = null @@ -85,9 +90,12 @@ case class Aggregator[K, V, C] ( val pair = iter.next() combiners.insert(pair._1, pair._2) } - // TODO: Make this non optional in a future release - Option(context).foreach(c => c.taskMetrics.memoryBytesSpilled = combiners.memoryBytesSpilled) - Option(context).foreach(c => c.taskMetrics.diskBytesSpilled = combiners.diskBytesSpilled) + // Update task metrics if context is not null + // TODO: Make context non-optional in a future release + Option(context).foreach { c => + c.taskMetrics.memoryBytesSpilled += combiners.memoryBytesSpilled + c.taskMetrics.diskBytesSpilled += combiners.diskBytesSpilled + } combiners.iterator } } diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index fb4c86716bb8d..b25f081761a64 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -289,7 +289,7 @@ class SparkContext(config: SparkConf) extends Logging { value <- Option(System.getenv(envKey)).orElse(Option(System.getProperty(propKey)))} { executorEnvs(envKey) = value } - Option(System.getenv("SPARK_PREPEND_CLASSES")).foreach { v => + Option(System.getenv("SPARK_PREPEND_CLASSES")).foreach { v => executorEnvs("SPARK_PREPEND_CLASSES") = v } // The Mesos scheduler backend relies on this environment variable to set executor memory. @@ -1203,10 +1203,10 @@ class SparkContext(config: SparkConf) extends Logging { /** * Clean a closure to make it ready to serialized and send to tasks * (removes unreferenced variables in $outer's, updates REPL variables) - * If checkSerializable is set, clean will also proactively - * check to see if f is serializable and throw a SparkException + * If checkSerializable is set, clean will also proactively + * check to see if f is serializable and throw a SparkException * if not. - * + * * @param f the closure to clean * @param checkSerializable whether or not to immediately check f for serializability * @throws SparkException if checkSerializable is set but f is not diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala index 31bf8dced2638..47708cb2e78bd 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala @@ -122,7 +122,7 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) */ def sample(withReplacement: Boolean, fraction: Double): JavaPairRDD[K, V] = sample(withReplacement, fraction, Utils.random.nextLong) - + /** * Return a sampled subset of this RDD. */ diff --git a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala index 6388ef82cc5db..fabb882cdd4b3 100644 --- a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala @@ -17,10 +17,11 @@ package org.apache.spark.rdd +import scala.language.existentials + import java.io.{IOException, ObjectOutputStream} import scala.collection.mutable.ArrayBuffer -import scala.language.existentials import org.apache.spark.{InterruptibleIterator, Partition, Partitioner, SparkEnv, TaskContext} import org.apache.spark.{Dependency, OneToOneDependency, ShuffleDependency} @@ -157,8 +158,8 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part: for ((it, depNum) <- rddIterators) { map.insertAll(it.map(pair => (pair._1, new CoGroupValue(pair._2, depNum)))) } - context.taskMetrics.memoryBytesSpilled = map.memoryBytesSpilled - context.taskMetrics.diskBytesSpilled = map.diskBytesSpilled + context.taskMetrics.memoryBytesSpilled += map.memoryBytesSpilled + context.taskMetrics.diskBytesSpilled += map.diskBytesSpilled new InterruptibleIterator(context, map.iterator.asInstanceOf[Iterator[(K, Array[Iterable[_]])]]) } diff --git a/core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala index d85f962783931..e98bad2026e32 100644 --- a/core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala @@ -20,6 +20,7 @@ package org.apache.spark.rdd import scala.reflect.ClassTag import org.apache.spark.{Logging, RangePartitioner} +import org.apache.spark.annotation.DeveloperApi /** * Extra functions available on RDDs of (key, value) pairs where the key is sortable through @@ -43,10 +44,10 @@ import org.apache.spark.{Logging, RangePartitioner} */ class OrderedRDDFunctions[K : Ordering : ClassTag, V: ClassTag, - P <: Product2[K, V] : ClassTag]( + P <: Product2[K, V] : ClassTag] @DeveloperApi() ( self: RDD[P]) - extends Logging with Serializable { - + extends Logging with Serializable +{ private val ordering = implicitly[Ordering[K]] /** @@ -55,9 +56,12 @@ class OrderedRDDFunctions[K : Ordering : ClassTag, * (in the `save` case, they will be written to multiple `part-X` files in the filesystem, in * order of the keys). */ - def sortByKey(ascending: Boolean = true, numPartitions: Int = self.partitions.size): RDD[P] = { + // TODO: this currently doesn't work on P other than Tuple2! + def sortByKey(ascending: Boolean = true, numPartitions: Int = self.partitions.size) + : RDD[(K, V)] = + { val part = new RangePartitioner(numPartitions, self, ascending) - new ShuffledRDD[K, V, V, P](self, part) + new ShuffledRDD[K, V, V](self, part) .setKeyOrdering(if (ascending) ordering else ordering.reverse) } } diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala index 1af4e5f0b6d08..93af50c0a9cd1 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala @@ -90,7 +90,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) new InterruptibleIterator(context, aggregator.combineValuesByKey(iter, context)) }, preservesPartitioning = true) } else { - new ShuffledRDD[K, V, C, (K, C)](self, partitioner) + new ShuffledRDD[K, V, C](self, partitioner) .setSerializer(serializer) .setAggregator(aggregator) .setMapSideCombine(mapSideCombine) @@ -425,7 +425,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) if (self.partitioner == Some(partitioner)) { self } else { - new ShuffledRDD[K, V, V, (K, V)](self, partitioner) + new ShuffledRDD[K, V, V](self, partitioner) } } diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index 726b3f2bbeea7..74ac97091fd0b 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -332,7 +332,7 @@ abstract class RDD[T: ClassTag]( val distributePartition = (index: Int, items: Iterator[T]) => { var position = (new Random(index)).nextInt(numPartitions) items.map { t => - // Note that the hash code of the key will just be the key itself. The HashPartitioner + // Note that the hash code of the key will just be the key itself. The HashPartitioner // will mod it with the number of total partitions. position = position + 1 (position, t) @@ -341,7 +341,7 @@ abstract class RDD[T: ClassTag]( // include a shuffle step so that our upstream tasks are still distributed new CoalescedRDD( - new ShuffledRDD[Int, T, T, (Int, T)](mapPartitionsWithIndex(distributePartition), + new ShuffledRDD[Int, T, T](mapPartitionsWithIndex(distributePartition), new HashPartitioner(numPartitions)), numPartitions).values } else { @@ -352,8 +352,8 @@ abstract class RDD[T: ClassTag]( /** * Return a sampled subset of this RDD. */ - def sample(withReplacement: Boolean, - fraction: Double, + def sample(withReplacement: Boolean, + fraction: Double, seed: Long = Utils.random.nextLong): RDD[T] = { require(fraction >= 0.0, "Negative fraction value: " + fraction) if (withReplacement) { diff --git a/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala index bf02f68d0d3d3..d9fe6847254fa 100644 --- a/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala @@ -37,11 +37,12 @@ private[spark] class ShuffledRDDPartition(val idx: Int) extends Partition { * @tparam V the value class. * @tparam C the combiner class. */ +// TODO: Make this return RDD[Product2[K, C]] or have some way to configure mutable pairs @DeveloperApi -class ShuffledRDD[K, V, C, P <: Product2[K, C] : ClassTag]( +class ShuffledRDD[K, V, C]( @transient var prev: RDD[_ <: Product2[K, V]], part: Partitioner) - extends RDD[P](prev.context, Nil) { + extends RDD[(K, C)](prev.context, Nil) { private var serializer: Option[Serializer] = None @@ -52,25 +53,25 @@ class ShuffledRDD[K, V, C, P <: Product2[K, C] : ClassTag]( private var mapSideCombine: Boolean = false /** Set a serializer for this RDD's shuffle, or null to use the default (spark.serializer) */ - def setSerializer(serializer: Serializer): ShuffledRDD[K, V, C, P] = { + def setSerializer(serializer: Serializer): ShuffledRDD[K, V, C] = { this.serializer = Option(serializer) this } /** Set key ordering for RDD's shuffle. */ - def setKeyOrdering(keyOrdering: Ordering[K]): ShuffledRDD[K, V, C, P] = { + def setKeyOrdering(keyOrdering: Ordering[K]): ShuffledRDD[K, V, C] = { this.keyOrdering = Option(keyOrdering) this } /** Set aggregator for RDD's shuffle. */ - def setAggregator(aggregator: Aggregator[K, V, C]): ShuffledRDD[K, V, C, P] = { + def setAggregator(aggregator: Aggregator[K, V, C]): ShuffledRDD[K, V, C] = { this.aggregator = Option(aggregator) this } /** Set mapSideCombine flag for RDD's shuffle. */ - def setMapSideCombine(mapSideCombine: Boolean): ShuffledRDD[K, V, C, P] = { + def setMapSideCombine(mapSideCombine: Boolean): ShuffledRDD[K, V, C] = { this.mapSideCombine = mapSideCombine this } @@ -85,11 +86,11 @@ class ShuffledRDD[K, V, C, P <: Product2[K, C] : ClassTag]( Array.tabulate[Partition](part.numPartitions)(i => new ShuffledRDDPartition(i)) } - override def compute(split: Partition, context: TaskContext): Iterator[P] = { + override def compute(split: Partition, context: TaskContext): Iterator[(K, C)] = { val dep = dependencies.head.asInstanceOf[ShuffleDependency[K, V, C]] SparkEnv.get.shuffleManager.getReader(dep.shuffleHandle, split.index, split.index + 1, context) .read() - .asInstanceOf[Iterator[P]] + .asInstanceOf[Iterator[(K, C)]] } override def clearDependencies() { diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleManager.scala index 5b0940ecce29d..df98d18fa8193 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleManager.scala @@ -24,7 +24,7 @@ import org.apache.spark.shuffle._ * A ShuffleManager using hashing, that creates one output file per reduce partition on each * mapper (possibly reusing these across waves of tasks). */ -class HashShuffleManager(conf: SparkConf) extends ShuffleManager { +private[spark] class HashShuffleManager(conf: SparkConf) extends ShuffleManager { /* Register a shuffle with the manager and obtain a handle for it to pass to tasks. */ override def registerShuffle[K, V, C]( shuffleId: Int, diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala index c8059496a1bdf..e32ad9c036ad4 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala @@ -21,7 +21,7 @@ import org.apache.spark.{InterruptibleIterator, TaskContext} import org.apache.spark.serializer.Serializer import org.apache.spark.shuffle.{BaseShuffleHandle, ShuffleReader} -class HashShuffleReader[K, C]( +private[spark] class HashShuffleReader[K, C]( handle: BaseShuffleHandle[K, _, C], startPartition: Int, endPartition: Int, @@ -47,7 +47,8 @@ class HashShuffleReader[K, C]( } else if (dep.aggregator.isEmpty && dep.mapSideCombine) { throw new IllegalStateException("Aggregator is empty for map-side combine") } else { - iter + // Convert the Product2s to pairs since this is what downstream RDDs currently expect + iter.asInstanceOf[Iterator[Product2[K, C]]].map(pair => (pair._1, pair._2)) } // Sort the output if there is a sort ordering defined. diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala index 9b78228519da4..1923f7c71a48f 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala @@ -24,7 +24,7 @@ import org.apache.spark.serializer.Serializer import org.apache.spark.executor.ShuffleWriteMetrics import org.apache.spark.scheduler.MapStatus -class HashShuffleWriter[K, V]( +private[spark] class HashShuffleWriter[K, V]( handle: BaseShuffleHandle[K, V, _], mapId: Int, context: TaskContext) @@ -33,6 +33,10 @@ class HashShuffleWriter[K, V]( private val dep = handle.dependency private val numOutputSplits = dep.partitioner.numPartitions private val metrics = context.taskMetrics + + // Are we in the process of stopping? Because map tasks can call stop() with success = true + // and then call stop() with success = false if they get an exception, we want to make sure + // we don't try deleting files, etc twice. private var stopping = false private val blockManager = SparkEnv.get.blockManager diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala new file mode 100644 index 0000000000000..6dcca47ea7c0c --- /dev/null +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.sort + +import java.io.{DataInputStream, FileInputStream} + +import org.apache.spark.shuffle._ +import org.apache.spark.{TaskContext, ShuffleDependency} +import org.apache.spark.shuffle.hash.HashShuffleReader +import org.apache.spark.storage.{DiskBlockManager, FileSegment, ShuffleBlockId} + +private[spark] class SortShuffleManager extends ShuffleManager { + /** + * Register a shuffle with the manager and obtain a handle for it to pass to tasks. + */ + override def registerShuffle[K, V, C]( + shuffleId: Int, + numMaps: Int, + dependency: ShuffleDependency[K, V, C]): ShuffleHandle = { + new BaseShuffleHandle(shuffleId, numMaps, dependency) + } + + /** + * Get a reader for a range of reduce partitions (startPartition to endPartition-1, inclusive). + * Called on executors by reduce tasks. + */ + override def getReader[K, C]( + handle: ShuffleHandle, + startPartition: Int, + endPartition: Int, + context: TaskContext): ShuffleReader[K, C] = { + // We currently use the same block store shuffle fetcher as the hash-based shuffle. + new HashShuffleReader( + handle.asInstanceOf[BaseShuffleHandle[K, _, C]], startPartition, endPartition, context) + } + + /** Get a writer for a given partition. Called on executors by map tasks. */ + override def getWriter[K, V](handle: ShuffleHandle, mapId: Int, context: TaskContext) + : ShuffleWriter[K, V] = { + new SortShuffleWriter(handle.asInstanceOf[BaseShuffleHandle[K, V, _]], mapId, context) + } + + /** Remove a shuffle's metadata from the ShuffleManager. */ + override def unregisterShuffle(shuffleId: Int): Unit = {} + + /** Shut down this ShuffleManager. */ + override def stop(): Unit = {} + + /** Get the location of a block in a map output file. Uses the index file we create for it. */ + def getBlockLocation(blockId: ShuffleBlockId, diskManager: DiskBlockManager): FileSegment = { + // The block is actually going to be a range of a single map output file for this map, so + // figure out the ID of the consolidated file, then the offset within that from our index + val consolidatedId = blockId.copy(reduceId = 0) + val indexFile = diskManager.getFile(consolidatedId.name + ".index") + val in = new DataInputStream(new FileInputStream(indexFile)) + try { + in.skip(blockId.reduceId * 8) + val offset = in.readLong() + val nextOffset = in.readLong() + new FileSegment(diskManager.getFile(consolidatedId), offset, nextOffset - offset) + } finally { + in.close() + } + } +} diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala new file mode 100644 index 0000000000000..42fcd07fa18bc --- /dev/null +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala @@ -0,0 +1,165 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.sort + +import java.io.{BufferedOutputStream, File, FileOutputStream, DataOutputStream} + +import org.apache.spark.{MapOutputTracker, SparkEnv, Logging, TaskContext} +import org.apache.spark.executor.ShuffleWriteMetrics +import org.apache.spark.scheduler.MapStatus +import org.apache.spark.serializer.Serializer +import org.apache.spark.shuffle.{ShuffleWriter, BaseShuffleHandle} +import org.apache.spark.storage.ShuffleBlockId +import org.apache.spark.util.collection.ExternalSorter + +private[spark] class SortShuffleWriter[K, V, C]( + handle: BaseShuffleHandle[K, V, C], + mapId: Int, + context: TaskContext) + extends ShuffleWriter[K, V] with Logging { + + private val dep = handle.dependency + private val numPartitions = dep.partitioner.numPartitions + + private val blockManager = SparkEnv.get.blockManager + private val ser = Serializer.getSerializer(dep.serializer.orNull) + + private val conf = SparkEnv.get.conf + private val fileBufferSize = conf.getInt("spark.shuffle.file.buffer.kb", 100) * 1024 + + private var sorter: ExternalSorter[K, V, _] = null + private var outputFile: File = null + + // Are we in the process of stopping? Because map tasks can call stop() with success = true + // and then call stop() with success = false if they get an exception, we want to make sure + // we don't try deleting files, etc twice. + private var stopping = false + + private var mapStatus: MapStatus = null + + /** Write a bunch of records to this task's output */ + override def write(records: Iterator[_ <: Product2[K, V]]): Unit = { + // Get an iterator with the elements for each partition ID + val partitions: Iterator[(Int, Iterator[Product2[K, _]])] = { + if (dep.mapSideCombine) { + if (!dep.aggregator.isDefined) { + throw new IllegalStateException("Aggregator is empty for map-side combine") + } + sorter = new ExternalSorter[K, V, C]( + dep.aggregator, Some(dep.partitioner), dep.keyOrdering, dep.serializer) + sorter.write(records) + sorter.partitionedIterator + } else { + // In this case we pass neither an aggregator nor an ordering to the sorter, because we + // don't care whether the keys get sorted in each partition; that will be done on the + // reduce side if the operation being run is sortByKey. + sorter = new ExternalSorter[K, V, V]( + None, Some(dep.partitioner), None, dep.serializer) + sorter.write(records) + sorter.partitionedIterator + } + } + + // Create a single shuffle file with reduce ID 0 that we'll write all results to. We'll later + // serve different ranges of this file using an index file that we create at the end. + val blockId = ShuffleBlockId(dep.shuffleId, mapId, 0) + outputFile = blockManager.diskBlockManager.getFile(blockId) + + // Track location of each range in the output file + val offsets = new Array[Long](numPartitions + 1) + val lengths = new Array[Long](numPartitions) + + // Statistics + var totalBytes = 0L + var totalTime = 0L + + for ((id, elements) <- partitions) { + if (elements.hasNext) { + val writer = blockManager.getDiskWriter(blockId, outputFile, ser, fileBufferSize) + for (elem <- elements) { + writer.write(elem) + } + writer.commit() + writer.close() + val segment = writer.fileSegment() + offsets(id + 1) = segment.offset + segment.length + lengths(id) = segment.length + totalTime += writer.timeWriting() + totalBytes += segment.length + } else { + // The partition is empty; don't create a new writer to avoid writing headers, etc + offsets(id + 1) = offsets(id) + } + } + + val shuffleMetrics = new ShuffleWriteMetrics + shuffleMetrics.shuffleBytesWritten = totalBytes + shuffleMetrics.shuffleWriteTime = totalTime + context.taskMetrics.shuffleWriteMetrics = Some(shuffleMetrics) + context.taskMetrics.memoryBytesSpilled += sorter.memoryBytesSpilled + context.taskMetrics.diskBytesSpilled += sorter.diskBytesSpilled + + // Write an index file with the offsets of each block, plus a final offset at the end for the + // end of the output file. This will be used by SortShuffleManager.getBlockLocation to figure + // out where each block begins and ends. + + val diskBlockManager = blockManager.diskBlockManager + val indexFile = diskBlockManager.getFile(blockId.name + ".index") + val out = new DataOutputStream(new BufferedOutputStream(new FileOutputStream(indexFile))) + try { + var i = 0 + while (i < numPartitions + 1) { + out.writeLong(offsets(i)) + i += 1 + } + } finally { + out.close() + } + + // Register our map output with the ShuffleBlockManager, which handles cleaning it over time + blockManager.shuffleBlockManager.addCompletedMap(dep.shuffleId, mapId, numPartitions) + + mapStatus = new MapStatus(blockManager.blockManagerId, + lengths.map(MapOutputTracker.compressSize)) + } + + /** Close this writer, passing along whether the map completed */ + override def stop(success: Boolean): Option[MapStatus] = { + try { + if (stopping) { + return None + } + stopping = true + if (success) { + return Option(mapStatus) + } else { + // The map task failed, so delete our output file if we created one + if (outputFile != null) { + outputFile.delete() + } + return None + } + } finally { + // Clean up our sorter, which may have its own intermediate files + if (sorter != null) { + sorter.stop() + sorter = null + } + } + } +} diff --git a/core/src/main/scala/org/apache/spark/storage/BlockId.scala b/core/src/main/scala/org/apache/spark/storage/BlockId.scala index 42ec181b00bb3..c1756ac905417 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockId.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockId.scala @@ -54,11 +54,15 @@ case class RDDBlockId(rddId: Int, splitIndex: Int) extends BlockId { } @DeveloperApi -case class ShuffleBlockId(shuffleId: Int, mapId: Int, reduceId: Int) - extends BlockId { +case class ShuffleBlockId(shuffleId: Int, mapId: Int, reduceId: Int) extends BlockId { def name = "shuffle_" + shuffleId + "_" + mapId + "_" + reduceId } +@DeveloperApi +case class ShuffleIndexBlockId(shuffleId: Int, mapId: Int, reduceId: Int) extends BlockId { + def name = "shuffle_" + shuffleId + "_" + mapId + "_" + reduceId + ".index" +} + @DeveloperApi case class BroadcastBlockId(broadcastId: Long, field: String = "") extends BlockId { def name = "broadcast_" + broadcastId + (if (field == "") "" else "_" + field) @@ -88,6 +92,7 @@ private[spark] case class TestBlockId(id: String) extends BlockId { object BlockId { val RDD = "rdd_([0-9]+)_([0-9]+)".r val SHUFFLE = "shuffle_([0-9]+)_([0-9]+)_([0-9]+)".r + val SHUFFLE_INDEX = "shuffle_([0-9]+)_([0-9]+)_([0-9]+).index".r val BROADCAST = "broadcast_([0-9]+)([_A-Za-z0-9]*)".r val TASKRESULT = "taskresult_([0-9]+)".r val STREAM = "input-([0-9]+)-([0-9]+)".r @@ -99,6 +104,8 @@ object BlockId { RDDBlockId(rddId.toInt, splitIndex.toInt) case SHUFFLE(shuffleId, mapId, reduceId) => ShuffleBlockId(shuffleId.toInt, mapId.toInt, reduceId.toInt) + case SHUFFLE_INDEX(shuffleId, mapId, reduceId) => + ShuffleIndexBlockId(shuffleId.toInt, mapId.toInt, reduceId.toInt) case BROADCAST(broadcastId, field) => BroadcastBlockId(broadcastId.toLong, field.stripPrefix("_")) case TASKRESULT(taskId) => diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala index 2e7ed7538e6e5..4d66ccea211fa 100644 --- a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala @@ -21,10 +21,11 @@ import java.io.File import java.text.SimpleDateFormat import java.util.{Date, Random, UUID} -import org.apache.spark.Logging +import org.apache.spark.{SparkEnv, Logging} import org.apache.spark.executor.ExecutorExitCode import org.apache.spark.network.netty.{PathResolver, ShuffleSender} import org.apache.spark.util.Utils +import org.apache.spark.shuffle.sort.SortShuffleManager /** * Creates and maintains the logical mapping between logical blocks and physical on-disk @@ -34,11 +35,13 @@ import org.apache.spark.util.Utils * * @param rootDirs The directories to use for storing block files. Data will be hashed among these. */ -private[spark] class DiskBlockManager(shuffleManager: ShuffleBlockManager, rootDirs: String) +private[spark] class DiskBlockManager(shuffleBlockManager: ShuffleBlockManager, rootDirs: String) extends PathResolver with Logging { private val MAX_DIR_CREATION_ATTEMPTS: Int = 10 - private val subDirsPerLocalDir = shuffleManager.conf.getInt("spark.diskStore.subDirectories", 64) + + private val subDirsPerLocalDir = + shuffleBlockManager.conf.getInt("spark.diskStore.subDirectories", 64) /* Create one local directory for each path mentioned in spark.local.dir; then, inside this * directory, create multiple subdirectories that we will hash files into, in order to avoid @@ -54,13 +57,19 @@ private[spark] class DiskBlockManager(shuffleManager: ShuffleBlockManager, rootD addShutdownHook() /** - * Returns the physical file segment in which the given BlockId is located. - * If the BlockId has been mapped to a specific FileSegment, that will be returned. - * Otherwise, we assume the Block is mapped to a whole file identified by the BlockId directly. + * Returns the physical file segment in which the given BlockId is located. If the BlockId has + * been mapped to a specific FileSegment by the shuffle layer, that will be returned. + * Otherwise, we assume the Block is mapped to the whole file identified by the BlockId. */ def getBlockLocation(blockId: BlockId): FileSegment = { - if (blockId.isShuffle && shuffleManager.consolidateShuffleFiles) { - shuffleManager.getBlockLocation(blockId.asInstanceOf[ShuffleBlockId]) + val env = SparkEnv.get // NOTE: can be null in unit tests + if (blockId.isShuffle && env != null && env.shuffleManager.isInstanceOf[SortShuffleManager]) { + // For sort-based shuffle, let it figure out its blocks + val sortShuffleManager = env.shuffleManager.asInstanceOf[SortShuffleManager] + sortShuffleManager.getBlockLocation(blockId.asInstanceOf[ShuffleBlockId], this) + } else if (blockId.isShuffle && shuffleBlockManager.consolidateShuffleFiles) { + // For hash-based shuffle with consolidated files, ShuffleBlockManager takes care of this + shuffleBlockManager.getBlockLocation(blockId.asInstanceOf[ShuffleBlockId]) } else { val file = getFile(blockId.name) new FileSegment(file, 0, file.length()) @@ -99,13 +108,18 @@ private[spark] class DiskBlockManager(shuffleManager: ShuffleBlockManager, rootD getBlockLocation(blockId).file.exists() } - /** List all the blocks currently stored on disk by the disk manager. */ - def getAllBlocks(): Seq[BlockId] = { + /** List all the files currently stored on disk by the disk manager. */ + def getAllFiles(): Seq[File] = { // Get all the files inside the array of array of directories subDirs.flatten.filter(_ != null).flatMap { dir => - val files = dir.list() + val files = dir.listFiles() if (files != null) files else Seq.empty - }.map(BlockId.apply) + } + } + + /** List all the blocks currently stored on disk by the disk manager. */ + def getAllBlocks(): Seq[BlockId] = { + getAllFiles().map(f => BlockId(f.getName)) } /** Produces a unique block id and File suitable for intermediate results. */ diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala index 35910e552fe86..7beb55c411e71 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala @@ -28,6 +28,7 @@ import org.apache.spark.serializer.Serializer import org.apache.spark.storage.ShuffleBlockManager.ShuffleFileGroup import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, TimeStampedHashMap} import org.apache.spark.util.collection.{PrimitiveKeyOpenHashMap, PrimitiveVector} +import org.apache.spark.shuffle.sort.SortShuffleManager /** A group of writers for a ShuffleMapTask, one writer per reducer. */ private[spark] trait ShuffleWriterGroup { @@ -58,6 +59,7 @@ private[spark] trait ShuffleWriterGroup { * each block stored in each file. In order to find the location of a shuffle block, we search the * files within a ShuffleFileGroups associated with the block's reducer. */ +// TODO: Factor this into a separate class for each ShuffleManager implementation private[spark] class ShuffleBlockManager(blockManager: BlockManager) extends Logging { def conf = blockManager.conf @@ -67,6 +69,10 @@ class ShuffleBlockManager(blockManager: BlockManager) extends Logging { val consolidateShuffleFiles = conf.getBoolean("spark.shuffle.consolidateFiles", false) + // Are we using sort-based shuffle? + val sortBasedShuffle = + conf.get("spark.shuffle.manager", "") == classOf[SortShuffleManager].getName + private val bufferSize = conf.getInt("spark.shuffle.file.buffer.kb", 100) * 1024 /** @@ -91,6 +97,20 @@ class ShuffleBlockManager(blockManager: BlockManager) extends Logging { private val metadataCleaner = new MetadataCleaner(MetadataCleanerType.SHUFFLE_BLOCK_MANAGER, this.cleanup, conf) + /** + * Register a completed map without getting a ShuffleWriterGroup. Used by sort-based shuffle + * because it just writes a single file by itself. + */ + def addCompletedMap(shuffleId: Int, mapId: Int, numBuckets: Int): Unit = { + shuffleStates.putIfAbsent(shuffleId, new ShuffleState(numBuckets)) + val shuffleState = shuffleStates(shuffleId) + shuffleState.completedMapTasks.add(mapId) + } + + /** + * Get a ShuffleWriterGroup for the given map task, which will register it as complete + * when the writers are closed successfully + */ def forMapTask(shuffleId: Int, mapId: Int, numBuckets: Int, serializer: Serializer) = { new ShuffleWriterGroup { shuffleStates.putIfAbsent(shuffleId, new ShuffleState(numBuckets)) @@ -182,7 +202,14 @@ class ShuffleBlockManager(blockManager: BlockManager) extends Logging { private def removeShuffleBlocks(shuffleId: ShuffleId): Boolean = { shuffleStates.get(shuffleId) match { case Some(state) => - if (consolidateShuffleFiles) { + if (sortBasedShuffle) { + // There's a single block ID for each map, plus an index file for it + for (mapId <- state.completedMapTasks) { + val blockId = new ShuffleBlockId(shuffleId, mapId, 0) + blockManager.diskBlockManager.getFile(blockId).delete() + blockManager.diskBlockManager.getFile(blockId.name + ".index").delete() + } + } else if (consolidateShuffleFiles) { for (fileGroup <- state.allFileGroups; file <- fileGroup.files) { file.delete() } diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala index 6f263c39d1435..b34512ef9eb60 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala @@ -79,12 +79,16 @@ class ExternalAppendOnlyMap[K, V, C]( (Runtime.getRuntime.maxMemory * memoryFraction * safetyFraction).toLong } - // Number of pairs in the in-memory map - private var numPairsInMemory = 0L + // Number of pairs inserted since last spill; note that we count them even if a value is merged + // with a previous key in case we're doing something like groupBy where the result grows + private var elementsRead = 0L // Number of in-memory pairs inserted before tracking the map's shuffle memory usage private val trackMemoryThreshold = 1000 + // How much of the shared memory pool this collection has claimed + private var myMemoryThreshold = 0L + /** * Size of object batches when reading/writing from serializers. * @@ -106,7 +110,6 @@ class ExternalAppendOnlyMap[K, V, C]( private val fileBufferSize = sparkConf.getInt("spark.shuffle.file.buffer.kb", 100) * 1024 private val keyComparator = new HashComparator[K] private val ser = serializer.newInstance() - private val threadId = Thread.currentThread().getId /** * Insert the given key and value into the map. @@ -134,31 +137,35 @@ class ExternalAppendOnlyMap[K, V, C]( while (entries.hasNext) { curEntry = entries.next() - if (numPairsInMemory > trackMemoryThreshold && currentMap.atGrowThreshold) { - val mapSize = currentMap.estimateSize() + if (elementsRead > trackMemoryThreshold && elementsRead % 32 == 0 && + currentMap.estimateSize() >= myMemoryThreshold) + { + val currentSize = currentMap.estimateSize() var shouldSpill = false val shuffleMemoryMap = SparkEnv.get.shuffleMemoryMap // Atomically check whether there is sufficient memory in the global pool for // this map to grow and, if possible, allocate the required amount shuffleMemoryMap.synchronized { + val threadId = Thread.currentThread().getId val previouslyOccupiedMemory = shuffleMemoryMap.get(threadId) val availableMemory = maxMemoryThreshold - (shuffleMemoryMap.values.sum - previouslyOccupiedMemory.getOrElse(0L)) - // Assume map growth factor is 2x - shouldSpill = availableMemory < mapSize * 2 + // Try to allocate at least 2x more memory, otherwise spill + shouldSpill = availableMemory < currentSize * 2 if (!shouldSpill) { - shuffleMemoryMap(threadId) = mapSize * 2 + shuffleMemoryMap(threadId) = currentSize * 2 + myMemoryThreshold = currentSize * 2 } } // Do not synchronize spills if (shouldSpill) { - spill(mapSize) + spill(currentSize) } } currentMap.changeValue(curEntry._1, update) - numPairsInMemory += 1 + elementsRead += 1 } } @@ -178,9 +185,10 @@ class ExternalAppendOnlyMap[K, V, C]( /** * Sort the existing contents of the in-memory map and spill them to a temporary file on disk. */ - private def spill(mapSize: Long) { + private def spill(mapSize: Long): Unit = { spillCount += 1 - logWarning("Thread %d spilling in-memory map of %d MB to disk (%d time%s so far)" + val threadId = Thread.currentThread().getId + logInfo("Thread %d spilling in-memory map of %d MB to disk (%d time%s so far)" .format(threadId, mapSize / (1024 * 1024), spillCount, if (spillCount > 1) "s" else "")) val (blockId, file) = diskBlockManager.createTempBlock() var writer = blockManager.getDiskWriter(blockId, file, serializer, fileBufferSize) @@ -227,7 +235,9 @@ class ExternalAppendOnlyMap[K, V, C]( shuffleMemoryMap.synchronized { shuffleMemoryMap(Thread.currentThread().getId) = 0 } - numPairsInMemory = 0 + myMemoryThreshold = 0 + + elementsRead = 0 _memoryBytesSpilled += mapSize } diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala new file mode 100644 index 0000000000000..54c3310744136 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala @@ -0,0 +1,662 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.collection + +import java.io._ +import java.util.Comparator + +import scala.collection.mutable.ArrayBuffer +import scala.collection.mutable + +import com.google.common.io.ByteStreams + +import org.apache.spark.{Aggregator, SparkEnv, Logging, Partitioner} +import org.apache.spark.serializer.Serializer +import org.apache.spark.storage.BlockId + +/** + * Sorts and potentially merges a number of key-value pairs of type (K, V) to produce key-combiner + * pairs of type (K, C). Uses a Partitioner to first group the keys into partitions, and then + * optionally sorts keys within each partition using a custom Comparator. Can output a single + * partitioned file with a different byte range for each partition, suitable for shuffle fetches. + * + * If combining is disabled, the type C must equal V -- we'll cast the objects at the end. + * + * @param aggregator optional Aggregator with combine functions to use for merging data + * @param partitioner optional Partitioner; if given, sort by partition ID and then key + * @param ordering optional Ordering to sort keys within each partition; should be a total ordering + * @param serializer serializer to use when spilling to disk + * + * Note that if an Ordering is given, we'll always sort using it, so only provide it if you really + * want the output keys to be sorted. In a map task without map-side combine for example, you + * probably want to pass None as the ordering to avoid extra sorting. On the other hand, if you do + * want to do combining, having an Ordering is more efficient than not having it. + * + * At a high level, this class works as follows: + * + * - We repeatedly fill up buffers of in-memory data, using either a SizeTrackingAppendOnlyMap if + * we want to combine by key, or an simple SizeTrackingBuffer if we don't. Inside these buffers, + * we sort elements of type ((Int, K), C) where the Int is the partition ID. This is done to + * avoid calling the partitioner multiple times on the same key (e.g. for RangePartitioner). + * + * - When each buffer reaches our memory limit, we spill it to a file. This file is sorted first + * by partition ID and possibly second by key or by hash code of the key, if we want to do + * aggregation. For each file, we track how many objects were in each partition in memory, so we + * don't have to write out the partition ID for every element. + * + * - When the user requests an iterator, the spilled files are merged, along with any remaining + * in-memory data, using the same sort order defined above (unless both sorting and aggregation + * are disabled). If we need to aggregate by key, we either use a total ordering from the + * ordering parameter, or read the keys with the same hash code and compare them with each other + * for equality to merge values. + * + * - Users are expected to call stop() at the end to delete all the intermediate files. + */ +private[spark] class ExternalSorter[K, V, C]( + aggregator: Option[Aggregator[K, V, C]] = None, + partitioner: Option[Partitioner] = None, + ordering: Option[Ordering[K]] = None, + serializer: Option[Serializer] = None) extends Logging { + + private val numPartitions = partitioner.map(_.numPartitions).getOrElse(1) + private val shouldPartition = numPartitions > 1 + + private val blockManager = SparkEnv.get.blockManager + private val diskBlockManager = blockManager.diskBlockManager + private val ser = Serializer.getSerializer(serializer) + private val serInstance = ser.newInstance() + + private val conf = SparkEnv.get.conf + private val spillingEnabled = conf.getBoolean("spark.shuffle.spill", true) + private val fileBufferSize = conf.getInt("spark.shuffle.file.buffer.kb", 100) * 1024 + + // Size of object batches when reading/writing from serializers. + // + // Objects are written in batches, with each batch using its own serialization stream. This + // cuts down on the size of reference-tracking maps constructed when deserializing a stream. + // + // NOTE: Setting this too low can cause excessive copying when serializing, since some serializers + // grow internal data structures by growing + copying every time the number of objects doubles. + private val serializerBatchSize = conf.getLong("spark.shuffle.spill.batchSize", 10000) + + private def getPartition(key: K): Int = { + if (shouldPartition) partitioner.get.getPartition(key) else 0 + } + + // Data structures to store in-memory objects before we spill. Depending on whether we have an + // Aggregator set, we either put objects into an AppendOnlyMap where we combine them, or we + // store them in an array buffer. + private var map = new SizeTrackingAppendOnlyMap[(Int, K), C] + private var buffer = new SizeTrackingPairBuffer[(Int, K), C] + + // Number of pairs read from input since last spill; note that we count them even if a value is + // merged with a previous key in case we're doing something like groupBy where the result grows + private var elementsRead = 0L + + // What threshold of elementsRead we start estimating map size at. + private val trackMemoryThreshold = 1000 + + // Spilling statistics + private var spillCount = 0 + private var _memoryBytesSpilled = 0L + private var _diskBytesSpilled = 0L + + // Collective memory threshold shared across all running tasks + private val maxMemoryThreshold = { + val memoryFraction = conf.getDouble("spark.shuffle.memoryFraction", 0.2) + val safetyFraction = conf.getDouble("spark.shuffle.safetyFraction", 0.8) + (Runtime.getRuntime.maxMemory * memoryFraction * safetyFraction).toLong + } + + // How much of the shared memory pool this collection has claimed + private var myMemoryThreshold = 0L + + // A comparator for keys K that orders them within a partition to allow aggregation or sorting. + // Can be a partial ordering by hash code if a total ordering is not provided through by the + // user. (A partial ordering means that equal keys have comparator.compare(k, k) = 0, but some + // non-equal keys also have this, so we need to do a later pass to find truly equal keys). + // Note that we ignore this if no aggregator and no ordering are given. + private val keyComparator: Comparator[K] = ordering.getOrElse(new Comparator[K] { + override def compare(a: K, b: K): Int = { + val h1 = if (a == null) 0 else a.hashCode() + val h2 = if (b == null) 0 else b.hashCode() + h1 - h2 + } + }) + + // A comparator for (Int, K) elements that orders them by partition and then possibly by key + private val partitionKeyComparator: Comparator[(Int, K)] = { + if (ordering.isDefined || aggregator.isDefined) { + // Sort by partition ID then key comparator + new Comparator[(Int, K)] { + override def compare(a: (Int, K), b: (Int, K)): Int = { + val partitionDiff = a._1 - b._1 + if (partitionDiff != 0) { + partitionDiff + } else { + keyComparator.compare(a._2, b._2) + } + } + } + } else { + // Just sort it by partition ID + new Comparator[(Int, K)] { + override def compare(a: (Int, K), b: (Int, K)): Int = { + a._1 - b._1 + } + } + } + } + + // Information about a spilled file. Includes sizes in bytes of "batches" written by the + // serializer as we periodically reset its stream, as well as number of elements in each + // partition, used to efficiently keep track of partitions when merging. + private[this] case class SpilledFile( + file: File, + blockId: BlockId, + serializerBatchSizes: Array[Long], + elementsPerPartition: Array[Long]) + private val spills = new ArrayBuffer[SpilledFile] + + def write(records: Iterator[_ <: Product2[K, V]]): Unit = { + // TODO: stop combining if we find that the reduction factor isn't high + val shouldCombine = aggregator.isDefined + + if (shouldCombine) { + // Combine values in-memory first using our AppendOnlyMap + val mergeValue = aggregator.get.mergeValue + val createCombiner = aggregator.get.createCombiner + var kv: Product2[K, V] = null + val update = (hadValue: Boolean, oldValue: C) => { + if (hadValue) mergeValue(oldValue, kv._2) else createCombiner(kv._2) + } + while (records.hasNext) { + elementsRead += 1 + kv = records.next() + map.changeValue((getPartition(kv._1), kv._1), update) + maybeSpill(usingMap = true) + } + } else { + // Stick values into our buffer + while (records.hasNext) { + elementsRead += 1 + val kv = records.next() + buffer.insert((getPartition(kv._1), kv._1), kv._2.asInstanceOf[C]) + maybeSpill(usingMap = false) + } + } + } + + /** + * Spill the current in-memory collection to disk if needed. + * + * @param usingMap whether we're using a map or buffer as our current in-memory collection + */ + private def maybeSpill(usingMap: Boolean): Unit = { + if (!spillingEnabled) { + return + } + + val collection: SizeTrackingPairCollection[(Int, K), C] = if (usingMap) map else buffer + + // TODO: factor this out of both here and ExternalAppendOnlyMap + if (elementsRead > trackMemoryThreshold && elementsRead % 32 == 0 && + collection.estimateSize() >= myMemoryThreshold) + { + // TODO: This logic doesn't work if there are two external collections being used in the same + // task (e.g. to read shuffle output and write it out into another shuffle) [SPARK-2711] + + val currentSize = collection.estimateSize() + var shouldSpill = false + val shuffleMemoryMap = SparkEnv.get.shuffleMemoryMap + + // Atomically check whether there is sufficient memory in the global pool for + // us to double our threshold + shuffleMemoryMap.synchronized { + val threadId = Thread.currentThread().getId + val previouslyClaimedMemory = shuffleMemoryMap.get(threadId) + val availableMemory = maxMemoryThreshold - + (shuffleMemoryMap.values.sum - previouslyClaimedMemory.getOrElse(0L)) + + // Try to allocate at least 2x more memory, otherwise spill + shouldSpill = availableMemory < currentSize * 2 + if (!shouldSpill) { + shuffleMemoryMap(threadId) = currentSize * 2 + myMemoryThreshold = currentSize * 2 + } + } + // Do not hold lock during spills + if (shouldSpill) { + spill(currentSize, usingMap) + } + } + } + + /** + * Spill the current in-memory collection to disk, adding a new file to spills, and clear it. + * + * @param usingMap whether we're using a map or buffer as our current in-memory collection + */ + private def spill(memorySize: Long, usingMap: Boolean): Unit = { + val collection: SizeTrackingPairCollection[(Int, K), C] = if (usingMap) map else buffer + val memorySize = collection.estimateSize() + + spillCount += 1 + val threadId = Thread.currentThread().getId + logInfo("Thread %d spilling in-memory batch of %d MB to disk (%d spill%s so far)" + .format(threadId, memorySize / (1024 * 1024), spillCount, if (spillCount > 1) "s" else "")) + val (blockId, file) = diskBlockManager.createTempBlock() + var writer = blockManager.getDiskWriter(blockId, file, ser, fileBufferSize) + var objectsWritten = 0 // Objects written since the last flush + + // List of batch sizes (bytes) in the order they are written to disk + val batchSizes = new ArrayBuffer[Long] + + // How many elements we have in each partition + val elementsPerPartition = new Array[Long](numPartitions) + + // Flush the disk writer's contents to disk, and update relevant variables + def flush() = { + writer.commit() + val bytesWritten = writer.bytesWritten + batchSizes.append(bytesWritten) + _diskBytesSpilled += bytesWritten + objectsWritten = 0 + } + + try { + val it = collection.destructiveSortedIterator(partitionKeyComparator) + while (it.hasNext) { + val elem = it.next() + val partitionId = elem._1._1 + val key = elem._1._2 + val value = elem._2 + writer.write(key) + writer.write(value) + elementsPerPartition(partitionId) += 1 + objectsWritten += 1 + + if (objectsWritten == serializerBatchSize) { + flush() + writer.close() + writer = blockManager.getDiskWriter(blockId, file, ser, fileBufferSize) + } + } + if (objectsWritten > 0) { + flush() + } + writer.close() + } catch { + case e: Exception => + writer.close() + file.delete() + throw e + } + + if (usingMap) { + map = new SizeTrackingAppendOnlyMap[(Int, K), C] + } else { + buffer = new SizeTrackingPairBuffer[(Int, K), C] + } + + // Reset the amount of shuffle memory used by this map in the global pool + val shuffleMemoryMap = SparkEnv.get.shuffleMemoryMap + shuffleMemoryMap.synchronized { + shuffleMemoryMap(Thread.currentThread().getId) = 0 + } + myMemoryThreshold = 0 + + spills.append(SpilledFile(file, blockId, batchSizes.toArray, elementsPerPartition)) + _memoryBytesSpilled += memorySize + } + + /** + * Merge a sequence of sorted files, giving an iterator over partitions and then over elements + * inside each partition. This can be used to either write out a new file or return data to + * the user. + * + * Returns an iterator over all the data written to this object, grouped by partition. For each + * partition we then have an iterator over its contents, and these are expected to be accessed + * in order (you can't "skip ahead" to one partition without reading the previous one). + * Guaranteed to return a key-value pair for each partition, in order of partition ID. + */ + private def merge(spills: Seq[SpilledFile], inMemory: Iterator[((Int, K), C)]) + : Iterator[(Int, Iterator[Product2[K, C]])] = { + val readers = spills.map(new SpillReader(_)) + val inMemBuffered = inMemory.buffered + (0 until numPartitions).iterator.map { p => + val inMemIterator = new IteratorForPartition(p, inMemBuffered) + val iterators = readers.map(_.readNextPartition()) ++ Seq(inMemIterator) + if (aggregator.isDefined) { + // Perform partial aggregation across partitions + (p, mergeWithAggregation( + iterators, aggregator.get.mergeCombiners, keyComparator, ordering.isDefined)) + } else if (ordering.isDefined) { + // No aggregator given, but we have an ordering (e.g. used by reduce tasks in sortByKey); + // sort the elements without trying to merge them + (p, mergeSort(iterators, ordering.get)) + } else { + (p, iterators.iterator.flatten) + } + } + } + + /** + * Merge-sort a sequence of (K, C) iterators using a given a comparator for the keys. + */ + private def mergeSort(iterators: Seq[Iterator[Product2[K, C]]], comparator: Comparator[K]) + : Iterator[Product2[K, C]] = + { + val bufferedIters = iterators.filter(_.hasNext).map(_.buffered) + type Iter = BufferedIterator[Product2[K, C]] + val heap = new mutable.PriorityQueue[Iter]()(new Ordering[Iter] { + // Use the reverse of comparator.compare because PriorityQueue dequeues the max + override def compare(x: Iter, y: Iter): Int = -comparator.compare(x.head._1, y.head._1) + }) + heap.enqueue(bufferedIters: _*) // Will contain only the iterators with hasNext = true + new Iterator[Product2[K, C]] { + override def hasNext: Boolean = !heap.isEmpty + + override def next(): Product2[K, C] = { + if (!hasNext) { + throw new NoSuchElementException + } + val firstBuf = heap.dequeue() + val firstPair = firstBuf.next() + if (firstBuf.hasNext) { + heap.enqueue(firstBuf) + } + firstPair + } + } + } + + /** + * Merge a sequence of (K, C) iterators by aggregating values for each key, assuming that each + * iterator is sorted by key with a given comparator. If the comparator is not a total ordering + * (e.g. when we sort objects by hash code and different keys may compare as equal although + * they're not), we still merge them by doing equality tests for all keys that compare as equal. + */ + private def mergeWithAggregation( + iterators: Seq[Iterator[Product2[K, C]]], + mergeCombiners: (C, C) => C, + comparator: Comparator[K], + totalOrder: Boolean) + : Iterator[Product2[K, C]] = + { + if (!totalOrder) { + // We only have a partial ordering, e.g. comparing the keys by hash code, which means that + // multiple distinct keys might be treated as equal by the ordering. To deal with this, we + // need to read all keys considered equal by the ordering at once and compare them. + new Iterator[Iterator[Product2[K, C]]] { + val sorted = mergeSort(iterators, comparator).buffered + + // Buffers reused across elements to decrease memory allocation + val keys = new ArrayBuffer[K] + val combiners = new ArrayBuffer[C] + + override def hasNext: Boolean = sorted.hasNext + + override def next(): Iterator[Product2[K, C]] = { + if (!hasNext) { + throw new NoSuchElementException + } + keys.clear() + combiners.clear() + val firstPair = sorted.next() + keys += firstPair._1 + combiners += firstPair._2 + val key = firstPair._1 + while (sorted.hasNext && comparator.compare(sorted.head._1, key) == 0) { + val pair = sorted.next() + var i = 0 + var foundKey = false + while (i < keys.size && !foundKey) { + if (keys(i) == pair._1) { + combiners(i) = mergeCombiners(combiners(i), pair._2) + foundKey = true + } + i += 1 + } + if (!foundKey) { + keys += pair._1 + combiners += pair._2 + } + } + + // Note that we return an iterator of elements since we could've had many keys marked + // equal by the partial order; we flatten this below to get a flat iterator of (K, C). + keys.iterator.zip(combiners.iterator) + } + }.flatMap(i => i) + } else { + // We have a total ordering, so the objects with the same key are sequential. + new Iterator[Product2[K, C]] { + val sorted = mergeSort(iterators, comparator).buffered + + override def hasNext: Boolean = sorted.hasNext + + override def next(): Product2[K, C] = { + if (!hasNext) { + throw new NoSuchElementException + } + val elem = sorted.next() + val k = elem._1 + var c = elem._2 + while (sorted.hasNext && sorted.head._1 == k) { + c = mergeCombiners(c, sorted.head._2) + } + (k, c) + } + } + } + } + + /** + * An internal class for reading a spilled file partition by partition. Expects all the + * partitions to be requested in order. + */ + private[this] class SpillReader(spill: SpilledFile) { + val fileStream = new FileInputStream(spill.file) + val bufferedStream = new BufferedInputStream(fileStream, fileBufferSize) + + // Track which partition and which batch stream we're in. These will be the indices of + // the next element we will read. We'll also store the last partition read so that + // readNextPartition() can figure out what partition that was from. + var partitionId = 0 + var indexInPartition = 0L + var batchStreamsRead = 0 + var indexInBatch = 0 + var lastPartitionId = 0 + + skipToNextPartition() + + // An intermediate stream that reads from exactly one batch + // This guards against pre-fetching and other arbitrary behavior of higher level streams + var batchStream = nextBatchStream() + var compressedStream = blockManager.wrapForCompression(spill.blockId, batchStream) + var deserStream = serInstance.deserializeStream(compressedStream) + var nextItem: (K, C) = null + var finished = false + + /** Construct a stream that only reads from the next batch */ + def nextBatchStream(): InputStream = { + if (batchStreamsRead < spill.serializerBatchSizes.length) { + batchStreamsRead += 1 + ByteStreams.limit(bufferedStream, spill.serializerBatchSizes(batchStreamsRead - 1)) + } else { + // No more batches left; give an empty stream + bufferedStream + } + } + + /** + * Update partitionId if we have reached the end of our current partition, possibly skipping + * empty partitions on the way. + */ + private def skipToNextPartition() { + while (partitionId < numPartitions && + indexInPartition == spill.elementsPerPartition(partitionId)) { + partitionId += 1 + indexInPartition = 0L + } + } + + /** + * Return the next (K, C) pair from the deserialization stream and update partitionId, + * indexInPartition, indexInBatch and such to match its location. + * + * If the current batch is drained, construct a stream for the next batch and read from it. + * If no more pairs are left, return null. + */ + private def readNextItem(): (K, C) = { + if (finished) { + return null + } + val k = deserStream.readObject().asInstanceOf[K] + val c = deserStream.readObject().asInstanceOf[C] + lastPartitionId = partitionId + // Start reading the next batch if we're done with this one + indexInBatch += 1 + if (indexInBatch == serializerBatchSize) { + batchStream = nextBatchStream() + compressedStream = blockManager.wrapForCompression(spill.blockId, batchStream) + deserStream = serInstance.deserializeStream(compressedStream) + indexInBatch = 0 + } + // Update the partition location of the element we're reading + indexInPartition += 1 + skipToNextPartition() + // If we've finished reading the last partition, remember that we're done + if (partitionId == numPartitions) { + finished = true + deserStream.close() + } + (k, c) + } + + var nextPartitionToRead = 0 + + def readNextPartition(): Iterator[Product2[K, C]] = new Iterator[Product2[K, C]] { + val myPartition = nextPartitionToRead + nextPartitionToRead += 1 + + override def hasNext: Boolean = { + if (nextItem == null) { + nextItem = readNextItem() + if (nextItem == null) { + return false + } + } + assert(lastPartitionId >= myPartition) + // Check that we're still in the right partition; note that readNextItem will have returned + // null at EOF above so we would've returned false there + lastPartitionId == myPartition + } + + override def next(): Product2[K, C] = { + if (!hasNext) { + throw new NoSuchElementException + } + val item = nextItem + nextItem = null + item + } + } + } + + /** + * Return an iterator over all the data written to this object, grouped by partition and + * aggregated by the requested aggregator. For each partition we then have an iterator over its + * contents, and these are expected to be accessed in order (you can't "skip ahead" to one + * partition without reading the previous one). Guaranteed to return a key-value pair for each + * partition, in order of partition ID. + * + * For now, we just merge all the spilled files in once pass, but this can be modified to + * support hierarchical merging. + */ + def partitionedIterator: Iterator[(Int, Iterator[Product2[K, C]])] = { + val usingMap = aggregator.isDefined + val collection: SizeTrackingPairCollection[(Int, K), C] = if (usingMap) map else buffer + if (spills.isEmpty) { + // Special case: if we have only in-memory data, we don't need to merge streams, and perhaps + // we don't even need to sort by anything other than partition ID + if (!ordering.isDefined) { + // The user isn't requested sorted keys, so only sort by partition ID, not key + val partitionComparator = new Comparator[(Int, K)] { + override def compare(a: (Int, K), b: (Int, K)): Int = { + a._1 - b._1 + } + } + groupByPartition(collection.destructiveSortedIterator(partitionComparator)) + } else { + // We do need to sort by both partition ID and key + groupByPartition(collection.destructiveSortedIterator(partitionKeyComparator)) + } + } else { + // General case: merge spilled and in-memory data + merge(spills, collection.destructiveSortedIterator(partitionKeyComparator)) + } + } + + /** + * Return an iterator over all the data written to this object, aggregated by our aggregator. + */ + def iterator: Iterator[Product2[K, C]] = partitionedIterator.flatMap(pair => pair._2) + + def stop(): Unit = { + spills.foreach(s => s.file.delete()) + spills.clear() + } + + def memoryBytesSpilled: Long = _memoryBytesSpilled + + def diskBytesSpilled: Long = _diskBytesSpilled + + /** + * Given a stream of ((partition, key), combiner) pairs *assumed to be sorted by partition ID*, + * group together the pairs for each partition into a sub-iterator. + * + * @param data an iterator of elements, assumed to already be sorted by partition ID + */ + private def groupByPartition(data: Iterator[((Int, K), C)]) + : Iterator[(Int, Iterator[Product2[K, C]])] = + { + val buffered = data.buffered + (0 until numPartitions).iterator.map(p => (p, new IteratorForPartition(p, buffered))) + } + + /** + * An iterator that reads only the elements for a given partition ID from an underlying buffered + * stream, assuming this partition is the next one to be read. Used to make it easier to return + * partitioned iterators from our in-memory collection. + */ + private[this] class IteratorForPartition(partitionId: Int, data: BufferedIterator[((Int, K), C)]) + extends Iterator[Product2[K, C]] + { + override def hasNext: Boolean = data.hasNext && data.head._1._1 == partitionId + + override def next(): Product2[K, C] = { + if (!hasNext) { + throw new NoSuchElementException + } + val elem = data.next() + (elem._1._2, elem._2) + } + } +} diff --git a/core/src/main/scala/org/apache/spark/util/collection/SizeTrackingAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/SizeTrackingAppendOnlyMap.scala index de61e1d17fe10..eb4de413867a0 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/SizeTrackingAppendOnlyMap.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/SizeTrackingAppendOnlyMap.scala @@ -20,8 +20,9 @@ package org.apache.spark.util.collection /** * An append-only map that keeps track of its estimated size in bytes. */ -private[spark] class SizeTrackingAppendOnlyMap[K, V] extends AppendOnlyMap[K, V] with SizeTracker { - +private[spark] class SizeTrackingAppendOnlyMap[K, V] + extends AppendOnlyMap[K, V] with SizeTracker with SizeTrackingPairCollection[K, V] +{ override def update(key: K, value: V): Unit = { super.update(key, value) super.afterUpdate() diff --git a/core/src/main/scala/org/apache/spark/util/collection/SizeTrackingPairBuffer.scala b/core/src/main/scala/org/apache/spark/util/collection/SizeTrackingPairBuffer.scala new file mode 100644 index 0000000000000..9e9c16c5a2962 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/collection/SizeTrackingPairBuffer.scala @@ -0,0 +1,86 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.collection + +import java.util.Comparator + +/** + * Append-only buffer of key-value pairs that keeps track of its estimated size in bytes. + */ +private[spark] class SizeTrackingPairBuffer[K, V](initialCapacity: Int = 64) + extends SizeTracker with SizeTrackingPairCollection[K, V] +{ + require(initialCapacity <= (1 << 29), "Can't make capacity bigger than 2^29 elements") + require(initialCapacity >= 1, "Invalid initial capacity") + + // Basic growable array data structure. We use a single array of AnyRef to hold both the keys + // and the values, so that we can sort them efficiently with KVArraySortDataFormat. + private var capacity = initialCapacity + private var curSize = 0 + private var data = new Array[AnyRef](2 * initialCapacity) + + /** Add an element into the buffer */ + def insert(key: K, value: V): Unit = { + if (curSize == capacity) { + growArray() + } + data(2 * curSize) = key.asInstanceOf[AnyRef] + data(2 * curSize + 1) = value.asInstanceOf[AnyRef] + curSize += 1 + afterUpdate() + } + + /** Total number of elements in buffer */ + override def size: Int = curSize + + /** Iterate over the elements of the buffer */ + override def iterator: Iterator[(K, V)] = new Iterator[(K, V)] { + var pos = 0 + + override def hasNext: Boolean = pos < curSize + + override def next(): (K, V) = { + if (!hasNext) { + throw new NoSuchElementException + } + val pair = (data(2 * pos).asInstanceOf[K], data(2 * pos + 1).asInstanceOf[V]) + pos += 1 + pair + } + } + + /** Double the size of the array because we've reached capacity */ + private def growArray(): Unit = { + if (capacity == (1 << 29)) { + // Doubling the capacity would create an array bigger than Int.MaxValue, so don't + throw new Exception("Can't grow buffer beyond 2^29 elements") + } + val newCapacity = capacity * 2 + val newArray = new Array[AnyRef](2 * newCapacity) + System.arraycopy(data, 0, newArray, 0, 2 * capacity) + data = newArray + capacity = newCapacity + resetSamples() + } + + /** Iterate through the data in a given order. For this class this is not really destructive. */ + override def destructiveSortedIterator(keyComparator: Comparator[K]): Iterator[(K, V)] = { + new Sorter(new KVArraySortDataFormat[K, AnyRef]).sort(data, 0, curSize, keyComparator) + iterator + } +} diff --git a/core/src/main/scala/org/apache/spark/util/collection/SizeTrackingPairCollection.scala b/core/src/main/scala/org/apache/spark/util/collection/SizeTrackingPairCollection.scala new file mode 100644 index 0000000000000..faa4e2b12ddb6 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/collection/SizeTrackingPairCollection.scala @@ -0,0 +1,34 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.collection + +import java.util.Comparator + +/** + * A common interface for our size-tracking collections of key-value pairs, which are used in + * external operations. These all support estimating the size and obtaining a memory-efficient + * sorted iterator. + */ +// TODO: should extend Iterable[Product2[K, V]] instead of (K, V) +private[spark] trait SizeTrackingPairCollection[K, V] extends Iterable[(K, V)] { + /** Estimate the collection's current memory usage in bytes. */ + def estimateSize(): Long + + /** Iterate through the data in a given key order. This may destroy the underlying collection. */ + def destructiveSortedIterator(keyComparator: Comparator[K]): Iterator[(K, V)] +} diff --git a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala index d1cb2d9d3a53b..a41914a1a9d0c 100644 --- a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala +++ b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala @@ -99,7 +99,7 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging { test("ShuffledRDD") { testRDD(rdd => { // Creating ShuffledRDD directly as PairRDDFunctions.combineByKey produces a MapPartitionedRDD - new ShuffledRDD[Int, Int, Int, (Int, Int)](rdd.map(x => (x % 2, 1)), partitioner) + new ShuffledRDD[Int, Int, Int](rdd.map(x => (x % 2, 1)), partitioner) }) } diff --git a/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala index ad20f9b937ac1..4bc4346c0a288 100644 --- a/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala @@ -19,9 +19,6 @@ package org.apache.spark import java.lang.ref.WeakReference -import org.apache.spark.broadcast.Broadcast - -import scala.collection.mutable import scala.collection.mutable.{HashSet, SynchronizedSet} import scala.language.existentials import scala.language.postfixOps @@ -34,15 +31,28 @@ import org.scalatest.time.SpanSugar._ import org.apache.spark.SparkContext._ import org.apache.spark.rdd.RDD -import org.apache.spark.storage.{BlockId, BroadcastBlockId, RDDBlockId, ShuffleBlockId} - -class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkContext { - +import org.apache.spark.storage._ +import org.apache.spark.shuffle.hash.HashShuffleManager +import org.apache.spark.shuffle.sort.SortShuffleManager +import org.apache.spark.storage.BroadcastBlockId +import org.apache.spark.storage.RDDBlockId +import org.apache.spark.storage.ShuffleBlockId +import org.apache.spark.storage.ShuffleIndexBlockId + +/** + * An abstract base class for context cleaner tests, which sets up a context with a config + * suitable for cleaner tests and provides some utility functions. Subclasses can use different + * config options, in particular, a different shuffle manager class + */ +abstract class ContextCleanerSuiteBase(val shuffleManager: Class[_] = classOf[HashShuffleManager]) + extends FunSuite with BeforeAndAfter with LocalSparkContext +{ implicit val defaultTimeout = timeout(10000 millis) val conf = new SparkConf() .setMaster("local[2]") .setAppName("ContextCleanerSuite") .set("spark.cleaner.referenceTracking.blocking", "true") + .set("spark.shuffle.manager", shuffleManager.getName) before { sc = new SparkContext(conf) @@ -55,6 +65,59 @@ class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkCo } } + //------ Helper functions ------ + + protected def newRDD() = sc.makeRDD(1 to 10) + protected def newPairRDD() = newRDD().map(_ -> 1) + protected def newShuffleRDD() = newPairRDD().reduceByKey(_ + _) + protected def newBroadcast() = sc.broadcast(1 to 100) + + protected def newRDDWithShuffleDependencies(): (RDD[_], Seq[ShuffleDependency[_, _, _]]) = { + def getAllDependencies(rdd: RDD[_]): Seq[Dependency[_]] = { + rdd.dependencies ++ rdd.dependencies.flatMap { dep => + getAllDependencies(dep.rdd) + } + } + val rdd = newShuffleRDD() + + // Get all the shuffle dependencies + val shuffleDeps = getAllDependencies(rdd) + .filter(_.isInstanceOf[ShuffleDependency[_, _, _]]) + .map(_.asInstanceOf[ShuffleDependency[_, _, _]]) + (rdd, shuffleDeps) + } + + protected def randomRdd() = { + val rdd: RDD[_] = Random.nextInt(3) match { + case 0 => newRDD() + case 1 => newShuffleRDD() + case 2 => newPairRDD.join(newPairRDD()) + } + if (Random.nextBoolean()) rdd.persist() + rdd.count() + rdd + } + + /** Run GC and make sure it actually has run */ + protected def runGC() { + val weakRef = new WeakReference(new Object()) + val startTime = System.currentTimeMillis + System.gc() // Make a best effort to run the garbage collection. It *usually* runs GC. + // Wait until a weak reference object has been GCed + while (System.currentTimeMillis - startTime < 10000 && weakRef.get != null) { + System.gc() + Thread.sleep(200) + } + } + + protected def cleaner = sc.cleaner.get +} + + +/** + * Basic ContextCleanerSuite, which uses sort-based shuffle + */ +class ContextCleanerSuite extends ContextCleanerSuiteBase { test("cleanup RDD") { val rdd = newRDD().persist() val collected = rdd.collect().toList @@ -147,7 +210,7 @@ class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkCo val numRdds = 100 val numBroadcasts = 4 // Broadcasts are more costly val rddBuffer = (1 to numRdds).map(i => randomRdd()).toBuffer - val broadcastBuffer = (1 to numBroadcasts).map(i => randomBroadcast()).toBuffer + val broadcastBuffer = (1 to numBroadcasts).map(i => newBroadcast()).toBuffer val rddIds = sc.persistentRdds.keys.toSeq val shuffleIds = 0 until sc.newShuffleId val broadcastIds = broadcastBuffer.map(_.id) @@ -180,12 +243,13 @@ class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkCo .setMaster("local-cluster[2, 1, 512]") .setAppName("ContextCleanerSuite") .set("spark.cleaner.referenceTracking.blocking", "true") + .set("spark.shuffle.manager", shuffleManager.getName) sc = new SparkContext(conf2) val numRdds = 10 val numBroadcasts = 4 // Broadcasts are more costly val rddBuffer = (1 to numRdds).map(i => randomRdd()).toBuffer - val broadcastBuffer = (1 to numBroadcasts).map(i => randomBroadcast()).toBuffer + val broadcastBuffer = (1 to numBroadcasts).map(i => newBroadcast()).toBuffer val rddIds = sc.persistentRdds.keys.toSeq val shuffleIds = 0 until sc.newShuffleId val broadcastIds = broadcastBuffer.map(_.id) @@ -210,57 +274,82 @@ class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkCo case _ => false }, askSlaves = true).isEmpty) } +} - //------ Helper functions ------ - private def newRDD() = sc.makeRDD(1 to 10) - private def newPairRDD() = newRDD().map(_ -> 1) - private def newShuffleRDD() = newPairRDD().reduceByKey(_ + _) - private def newBroadcast() = sc.broadcast(1 to 100) +/** + * A copy of the shuffle tests for sort-based shuffle + */ +class SortShuffleContextCleanerSuite extends ContextCleanerSuiteBase(classOf[SortShuffleManager]) { + test("cleanup shuffle") { + val (rdd, shuffleDeps) = newRDDWithShuffleDependencies() + val collected = rdd.collect().toList + val tester = new CleanerTester(sc, shuffleIds = shuffleDeps.map(_.shuffleId)) - private def newRDDWithShuffleDependencies(): (RDD[_], Seq[ShuffleDependency[_, _, _]]) = { - def getAllDependencies(rdd: RDD[_]): Seq[Dependency[_]] = { - rdd.dependencies ++ rdd.dependencies.flatMap { dep => - getAllDependencies(dep.rdd) - } - } - val rdd = newShuffleRDD() + // Explicit cleanup + shuffleDeps.foreach(s => cleaner.doCleanupShuffle(s.shuffleId, blocking = true)) + tester.assertCleanup() - // Get all the shuffle dependencies - val shuffleDeps = getAllDependencies(rdd) - .filter(_.isInstanceOf[ShuffleDependency[_, _, _]]) - .map(_.asInstanceOf[ShuffleDependency[_, _, _]]) - (rdd, shuffleDeps) + // Verify that shuffles can be re-executed after cleaning up + assert(rdd.collect().toList.equals(collected)) } - private def randomRdd() = { - val rdd: RDD[_] = Random.nextInt(3) match { - case 0 => newRDD() - case 1 => newShuffleRDD() - case 2 => newPairRDD.join(newPairRDD()) - } - if (Random.nextBoolean()) rdd.persist() + test("automatically cleanup shuffle") { + var rdd = newShuffleRDD() rdd.count() - rdd - } - private def randomBroadcast() = { - sc.broadcast(Random.nextInt(Int.MaxValue)) + // Test that GC does not cause shuffle cleanup due to a strong reference + val preGCTester = new CleanerTester(sc, shuffleIds = Seq(0)) + runGC() + intercept[Exception] { + preGCTester.assertCleanup()(timeout(1000 millis)) + } + + // Test that GC causes shuffle cleanup after dereferencing the RDD + val postGCTester = new CleanerTester(sc, shuffleIds = Seq(0)) + rdd = null // Make RDD out of scope, so that corresponding shuffle goes out of scope + runGC() + postGCTester.assertCleanup() } - /** Run GC and make sure it actually has run */ - private def runGC() { - val weakRef = new WeakReference(new Object()) - val startTime = System.currentTimeMillis - System.gc() // Make a best effort to run the garbage collection. It *usually* runs GC. - // Wait until a weak reference object has been GCed - while (System.currentTimeMillis - startTime < 10000 && weakRef.get != null) { - System.gc() - Thread.sleep(200) + test("automatically cleanup RDD + shuffle + broadcast in distributed mode") { + sc.stop() + + val conf2 = new SparkConf() + .setMaster("local-cluster[2, 1, 512]") + .setAppName("ContextCleanerSuite") + .set("spark.cleaner.referenceTracking.blocking", "true") + .set("spark.shuffle.manager", shuffleManager.getName) + sc = new SparkContext(conf2) + + val numRdds = 10 + val numBroadcasts = 4 // Broadcasts are more costly + val rddBuffer = (1 to numRdds).map(i => randomRdd).toBuffer + val broadcastBuffer = (1 to numBroadcasts).map(i => newBroadcast).toBuffer + val rddIds = sc.persistentRdds.keys.toSeq + val shuffleIds = 0 until sc.newShuffleId() + val broadcastIds = broadcastBuffer.map(_.id) + + val preGCTester = new CleanerTester(sc, rddIds, shuffleIds, broadcastIds) + runGC() + intercept[Exception] { + preGCTester.assertCleanup()(timeout(1000 millis)) } - } - private def cleaner = sc.cleaner.get + // Test that GC triggers the cleanup of all variables after the dereferencing them + val postGCTester = new CleanerTester(sc, rddIds, shuffleIds, broadcastIds) + broadcastBuffer.clear() + rddBuffer.clear() + runGC() + postGCTester.assertCleanup() + + // Make sure the broadcasted task closure no longer exists after GC. + val taskClosureBroadcastId = broadcastIds.max + 1 + assert(sc.env.blockManager.master.getMatchingBlockIds({ + case BroadcastBlockId(`taskClosureBroadcastId`, _) => true + case _ => false + }, askSlaves = true).isEmpty) + } } @@ -418,6 +507,7 @@ class CleanerTester( private def getShuffleBlocks(shuffleId: Int): Seq[BlockId] = { blockManager.master.getMatchingBlockIds( _ match { case ShuffleBlockId(`shuffleId`, _, _) => true + case ShuffleIndexBlockId(`shuffleId`, _, _) => true case _ => false }, askSlaves = true) } diff --git a/core/src/test/scala/org/apache/spark/ShuffleNettySuite.scala b/core/src/test/scala/org/apache/spark/ShuffleNettySuite.scala index 47df00050c1e2..d7b2d2e1e330f 100644 --- a/core/src/test/scala/org/apache/spark/ShuffleNettySuite.scala +++ b/core/src/test/scala/org/apache/spark/ShuffleNettySuite.scala @@ -28,6 +28,6 @@ class ShuffleNettySuite extends ShuffleSuite with BeforeAndAfterAll { } override def afterAll() { - System.setProperty("spark.shuffle.use.netty", "false") + System.clearProperty("spark.shuffle.use.netty") } } diff --git a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala index eae67c7747e82..b13ddf96bc77c 100644 --- a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala +++ b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala @@ -58,8 +58,7 @@ class ShuffleSuite extends FunSuite with Matchers with LocalSparkContext { // default Java serializer cannot handle the non serializable class. val c = new ShuffledRDD[Int, NonJavaSerializableClass, - NonJavaSerializableClass, - (Int, NonJavaSerializableClass)](b, new HashPartitioner(NUM_BLOCKS)) + NonJavaSerializableClass](b, new HashPartitioner(NUM_BLOCKS)) c.setSerializer(new KryoSerializer(conf)) val shuffleId = c.dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]].shuffleId @@ -83,8 +82,7 @@ class ShuffleSuite extends FunSuite with Matchers with LocalSparkContext { // default Java serializer cannot handle the non serializable class. val c = new ShuffledRDD[Int, NonJavaSerializableClass, - NonJavaSerializableClass, - (Int, NonJavaSerializableClass)](b, new HashPartitioner(3)) + NonJavaSerializableClass](b, new HashPartitioner(3)) c.setSerializer(new KryoSerializer(conf)) assert(c.count === 10) } @@ -100,7 +98,7 @@ class ShuffleSuite extends FunSuite with Matchers with LocalSparkContext { // NOTE: The default Java serializer doesn't create zero-sized blocks. // So, use Kryo - val c = new ShuffledRDD[Int, Int, Int, (Int, Int)](b, new HashPartitioner(10)) + val c = new ShuffledRDD[Int, Int, Int](b, new HashPartitioner(10)) .setSerializer(new KryoSerializer(conf)) val shuffleId = c.dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]].shuffleId @@ -126,7 +124,7 @@ class ShuffleSuite extends FunSuite with Matchers with LocalSparkContext { val b = a.map(x => (x, x*2)) // NOTE: The default Java serializer should create zero-sized blocks - val c = new ShuffledRDD[Int, Int, Int, (Int, Int)](b, new HashPartitioner(10)) + val c = new ShuffledRDD[Int, Int, Int](b, new HashPartitioner(10)) val shuffleId = c.dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]].shuffleId assert(c.count === 4) @@ -141,19 +139,19 @@ class ShuffleSuite extends FunSuite with Matchers with LocalSparkContext { assert(nonEmptyBlocks.size <= 4) } - test("shuffle using mutable pairs") { + test("shuffle on mutable pairs") { // Use a local cluster with 2 processes to make sure there are both local and remote blocks sc = new SparkContext("local-cluster[2,1,512]", "test") def p[T1, T2](_1: T1, _2: T2) = MutablePair(_1, _2) val data = Array(p(1, 1), p(1, 2), p(1, 3), p(2, 1)) val pairs: RDD[MutablePair[Int, Int]] = sc.parallelize(data, 2) - val results = new ShuffledRDD[Int, Int, Int, MutablePair[Int, Int]](pairs, + val results = new ShuffledRDD[Int, Int, Int](pairs, new HashPartitioner(2)).collect() - data.foreach { pair => results should contain (pair) } + data.foreach { pair => results should contain ((pair._1, pair._2)) } } - test("sorting using mutable pairs") { + test("sorting on mutable pairs") { // This is not in SortingSuite because of the local cluster setup. // Use a local cluster with 2 processes to make sure there are both local and remote blocks sc = new SparkContext("local-cluster[2,1,512]", "test") @@ -162,10 +160,10 @@ class ShuffleSuite extends FunSuite with Matchers with LocalSparkContext { val pairs: RDD[MutablePair[Int, Int]] = sc.parallelize(data, 2) val results = new OrderedRDDFunctions[Int, Int, MutablePair[Int, Int]](pairs) .sortByKey().collect() - results(0) should be (p(1, 11)) - results(1) should be (p(2, 22)) - results(2) should be (p(3, 33)) - results(3) should be (p(100, 100)) + results(0) should be ((1, 11)) + results(1) should be ((2, 22)) + results(2) should be ((3, 33)) + results(3) should be ((100, 100)) } test("cogroup using mutable pairs") { diff --git a/core/src/test/scala/org/apache/spark/SortShuffleSuite.scala b/core/src/test/scala/org/apache/spark/SortShuffleSuite.scala new file mode 100644 index 0000000000000..5c02c00586ef4 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/SortShuffleSuite.scala @@ -0,0 +1,34 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +import org.scalatest.BeforeAndAfterAll + +class SortShuffleSuite extends ShuffleSuite with BeforeAndAfterAll { + + // This test suite should run all tests in ShuffleSuite with sort-based shuffle. + + override def beforeAll() { + System.setProperty("spark.shuffle.manager", + "org.apache.spark.shuffle.sort.SortShuffleManager") + } + + override def afterAll() { + System.clearProperty("spark.shuffle.manager") + } +} diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala index 4953d565ae83a..8966eedd80ebc 100644 --- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala @@ -270,7 +270,7 @@ class RDDSuite extends FunSuite with SharedSparkContext { // we can optionally shuffle to keep the upstream parallel val coalesced5 = data.coalesce(1, shuffle = true) val isEquals = coalesced5.dependencies.head.rdd.dependencies.head.rdd. - asInstanceOf[ShuffledRDD[_, _, _, _]] != null + asInstanceOf[ShuffledRDD[_, _, _]] != null assert(isEquals) // when shuffling, we can increase the number of partitions @@ -730,9 +730,9 @@ class RDDSuite extends FunSuite with SharedSparkContext { // Any ancestors before the shuffle are not considered assert(ancestors4.size === 0) - assert(ancestors4.count(_.isInstanceOf[ShuffledRDD[_, _, _, _]]) === 0) + assert(ancestors4.count(_.isInstanceOf[ShuffledRDD[_, _, _]]) === 0) assert(ancestors5.size === 3) - assert(ancestors5.count(_.isInstanceOf[ShuffledRDD[_, _, _, _]]) === 1) + assert(ancestors5.count(_.isInstanceOf[ShuffledRDD[_, _, _]]) === 1) assert(ancestors5.count(_.isInstanceOf[MapPartitionsRDD[_, _]]) === 0) assert(ancestors5.count(_.isInstanceOf[MappedValuesRDD[_, _, _]]) === 2) } diff --git a/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala index 0b7ad184a46d2..7de5df6e1c8bd 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala @@ -208,11 +208,8 @@ class ExternalAppendOnlyMapSuite extends FunSuite with LocalSparkContext { val resultA = rddA.reduceByKey(math.max).collect() assert(resultA.length == 50000) resultA.foreach { case(k, v) => - k match { - case 0 => assert(v == 1) - case 25000 => assert(v == 50001) - case 49999 => assert(v == 99999) - case _ => + if (v != k * 2 + 1) { + fail(s"Value for ${k} was wrong: expected ${k * 2 + 1}, got ${v}") } } @@ -221,11 +218,9 @@ class ExternalAppendOnlyMapSuite extends FunSuite with LocalSparkContext { val resultB = rddB.groupByKey().collect() assert(resultB.length == 25000) resultB.foreach { case(i, seq) => - i match { - case 0 => assert(seq.toSet == Set[Int](0, 1, 2, 3)) - case 12500 => assert(seq.toSet == Set[Int](50000, 50001, 50002, 50003)) - case 24999 => assert(seq.toSet == Set[Int](99996, 99997, 99998, 99999)) - case _ => + val expected = Set(i * 4, i * 4 + 1, i * 4 + 2, i * 4 + 3) + if (seq.toSet != expected) { + fail(s"Value for ${i} was wrong: expected ${expected}, got ${seq.toSet}") } } @@ -239,6 +234,9 @@ class ExternalAppendOnlyMapSuite extends FunSuite with LocalSparkContext { case 0 => assert(seq1.toSet == Set[Int](0)) assert(seq2.toSet == Set[Int](0, 1000, 2000, 3000, 4000, 5000, 6000, 7000, 8000, 9000)) + case 1 => + assert(seq1.toSet == Set[Int](1)) + assert(seq2.toSet == Set[Int](1, 1001, 2001, 3001, 4001, 5001, 6001, 7001, 8001, 9001)) case 5000 => assert(seq1.toSet == Set[Int](5000)) assert(seq2.toSet == Set[Int]()) @@ -369,10 +367,3 @@ class ExternalAppendOnlyMapSuite extends FunSuite with LocalSparkContext { } } - -/** - * A dummy class that always returns the same hash code, to easily test hash collisions - */ -case class FixedHashObject(v: Int, h: Int) extends Serializable { - override def hashCode(): Int = h -} diff --git a/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala new file mode 100644 index 0000000000000..ddb5df40360e9 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala @@ -0,0 +1,566 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.collection + +import scala.collection.mutable.ArrayBuffer + +import org.scalatest.FunSuite + +import org.apache.spark._ +import org.apache.spark.SparkContext._ + +class ExternalSorterSuite extends FunSuite with LocalSparkContext { + test("empty data stream") { + val conf = new SparkConf(false) + conf.set("spark.shuffle.memoryFraction", "0.001") + conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager") + sc = new SparkContext("local", "test", conf) + + val agg = new Aggregator[Int, Int, Int](i => i, (i, j) => i + j, (i, j) => i + j) + val ord = implicitly[Ordering[Int]] + + // Both aggregator and ordering + val sorter = new ExternalSorter[Int, Int, Int]( + Some(agg), Some(new HashPartitioner(3)), Some(ord), None) + assert(sorter.iterator.toSeq === Seq()) + sorter.stop() + + // Only aggregator + val sorter2 = new ExternalSorter[Int, Int, Int]( + Some(agg), Some(new HashPartitioner(3)), None, None) + assert(sorter2.iterator.toSeq === Seq()) + sorter2.stop() + + // Only ordering + val sorter3 = new ExternalSorter[Int, Int, Int]( + None, Some(new HashPartitioner(3)), Some(ord), None) + assert(sorter3.iterator.toSeq === Seq()) + sorter3.stop() + + // Neither aggregator nor ordering + val sorter4 = new ExternalSorter[Int, Int, Int]( + None, Some(new HashPartitioner(3)), None, None) + assert(sorter4.iterator.toSeq === Seq()) + sorter4.stop() + } + + test("few elements per partition") { + val conf = new SparkConf(false) + conf.set("spark.shuffle.memoryFraction", "0.001") + conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager") + sc = new SparkContext("local", "test", conf) + + val agg = new Aggregator[Int, Int, Int](i => i, (i, j) => i + j, (i, j) => i + j) + val ord = implicitly[Ordering[Int]] + val elements = Set((1, 1), (2, 2), (5, 5)) + val expected = Set( + (0, Set()), (1, Set((1, 1))), (2, Set((2, 2))), (3, Set()), (4, Set()), + (5, Set((5, 5))), (6, Set())) + + // Both aggregator and ordering + val sorter = new ExternalSorter[Int, Int, Int]( + Some(agg), Some(new HashPartitioner(7)), Some(ord), None) + sorter.write(elements.iterator) + assert(sorter.partitionedIterator.map(p => (p._1, p._2.toSet)).toSet === expected) + sorter.stop() + + // Only aggregator + val sorter2 = new ExternalSorter[Int, Int, Int]( + Some(agg), Some(new HashPartitioner(7)), None, None) + sorter2.write(elements.iterator) + assert(sorter2.partitionedIterator.map(p => (p._1, p._2.toSet)).toSet === expected) + sorter2.stop() + + // Only ordering + val sorter3 = new ExternalSorter[Int, Int, Int]( + None, Some(new HashPartitioner(7)), Some(ord), None) + sorter3.write(elements.iterator) + assert(sorter3.partitionedIterator.map(p => (p._1, p._2.toSet)).toSet === expected) + sorter3.stop() + + // Neither aggregator nor ordering + val sorter4 = new ExternalSorter[Int, Int, Int]( + None, Some(new HashPartitioner(7)), None, None) + sorter4.write(elements.iterator) + assert(sorter4.partitionedIterator.map(p => (p._1, p._2.toSet)).toSet === expected) + sorter4.stop() + } + + test("empty partitions with spilling") { + val conf = new SparkConf(false) + conf.set("spark.shuffle.memoryFraction", "0.001") + conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager") + sc = new SparkContext("local", "test", conf) + + val agg = new Aggregator[Int, Int, Int](i => i, (i, j) => i + j, (i, j) => i + j) + val ord = implicitly[Ordering[Int]] + val elements = Iterator((1, 1), (5, 5)) ++ (0 until 100000).iterator.map(x => (2, 2)) + + val sorter = new ExternalSorter[Int, Int, Int]( + None, Some(new HashPartitioner(7)), None, None) + sorter.write(elements) + assert(sc.env.blockManager.diskBlockManager.getAllFiles().length > 0) // Make sure it spilled + val iter = sorter.partitionedIterator.map(p => (p._1, p._2.toList)) + assert(iter.next() === (0, Nil)) + assert(iter.next() === (1, List((1, 1)))) + assert(iter.next() === (2, (0 until 100000).map(x => (2, 2)).toList)) + assert(iter.next() === (3, Nil)) + assert(iter.next() === (4, Nil)) + assert(iter.next() === (5, List((5, 5)))) + assert(iter.next() === (6, Nil)) + sorter.stop() + } + + test("spilling in local cluster") { + val conf = new SparkConf(true) // Load defaults, otherwise SPARK_HOME is not found + conf.set("spark.shuffle.memoryFraction", "0.001") + conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager") + sc = new SparkContext("local-cluster[1,1,512]", "test", conf) + + // reduceByKey - should spill ~8 times + val rddA = sc.parallelize(0 until 100000).map(i => (i/2, i)) + val resultA = rddA.reduceByKey(math.max).collect() + assert(resultA.length == 50000) + resultA.foreach { case(k, v) => + if (v != k * 2 + 1) { + fail(s"Value for ${k} was wrong: expected ${k * 2 + 1}, got ${v}") + } + } + + // groupByKey - should spill ~17 times + val rddB = sc.parallelize(0 until 100000).map(i => (i/4, i)) + val resultB = rddB.groupByKey().collect() + assert(resultB.length == 25000) + resultB.foreach { case(i, seq) => + val expected = Set(i * 4, i * 4 + 1, i * 4 + 2, i * 4 + 3) + if (seq.toSet != expected) { + fail(s"Value for ${i} was wrong: expected ${expected}, got ${seq.toSet}") + } + } + + // cogroup - should spill ~7 times + val rddC1 = sc.parallelize(0 until 10000).map(i => (i, i)) + val rddC2 = sc.parallelize(0 until 10000).map(i => (i%1000, i)) + val resultC = rddC1.cogroup(rddC2).collect() + assert(resultC.length == 10000) + resultC.foreach { case(i, (seq1, seq2)) => + i match { + case 0 => + assert(seq1.toSet == Set[Int](0)) + assert(seq2.toSet == Set[Int](0, 1000, 2000, 3000, 4000, 5000, 6000, 7000, 8000, 9000)) + case 1 => + assert(seq1.toSet == Set[Int](1)) + assert(seq2.toSet == Set[Int](1, 1001, 2001, 3001, 4001, 5001, 6001, 7001, 8001, 9001)) + case 5000 => + assert(seq1.toSet == Set[Int](5000)) + assert(seq2.toSet == Set[Int]()) + case 9999 => + assert(seq1.toSet == Set[Int](9999)) + assert(seq2.toSet == Set[Int]()) + case _ => + } + } + + // larger cogroup - should spill ~7 times + val rddD1 = sc.parallelize(0 until 10000).map(i => (i/2, i)) + val rddD2 = sc.parallelize(0 until 10000).map(i => (i/2, i)) + val resultD = rddD1.cogroup(rddD2).collect() + assert(resultD.length == 5000) + resultD.foreach { case(i, (seq1, seq2)) => + val expected = Set(i * 2, i * 2 + 1) + if (seq1.toSet != expected) { + fail(s"Value 1 for ${i} was wrong: expected ${expected}, got ${seq1.toSet}") + } + if (seq2.toSet != expected) { + fail(s"Value 2 for ${i} was wrong: expected ${expected}, got ${seq2.toSet}") + } + } + } + + test("spilling in local cluster with many reduce tasks") { + val conf = new SparkConf(true) // Load defaults, otherwise SPARK_HOME is not found + conf.set("spark.shuffle.memoryFraction", "0.001") + conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager") + sc = new SparkContext("local-cluster[2,1,512]", "test", conf) + + // reduceByKey - should spill ~4 times per executor + val rddA = sc.parallelize(0 until 100000).map(i => (i/2, i)) + val resultA = rddA.reduceByKey(math.max _, 100).collect() + assert(resultA.length == 50000) + resultA.foreach { case(k, v) => + if (v != k * 2 + 1) { + fail(s"Value for ${k} was wrong: expected ${k * 2 + 1}, got ${v}") + } + } + + // groupByKey - should spill ~8 times per executor + val rddB = sc.parallelize(0 until 100000).map(i => (i/4, i)) + val resultB = rddB.groupByKey(100).collect() + assert(resultB.length == 25000) + resultB.foreach { case(i, seq) => + val expected = Set(i * 4, i * 4 + 1, i * 4 + 2, i * 4 + 3) + if (seq.toSet != expected) { + fail(s"Value for ${i} was wrong: expected ${expected}, got ${seq.toSet}") + } + } + + // cogroup - should spill ~4 times per executor + val rddC1 = sc.parallelize(0 until 10000).map(i => (i, i)) + val rddC2 = sc.parallelize(0 until 10000).map(i => (i%1000, i)) + val resultC = rddC1.cogroup(rddC2, 100).collect() + assert(resultC.length == 10000) + resultC.foreach { case(i, (seq1, seq2)) => + i match { + case 0 => + assert(seq1.toSet == Set[Int](0)) + assert(seq2.toSet == Set[Int](0, 1000, 2000, 3000, 4000, 5000, 6000, 7000, 8000, 9000)) + case 1 => + assert(seq1.toSet == Set[Int](1)) + assert(seq2.toSet == Set[Int](1, 1001, 2001, 3001, 4001, 5001, 6001, 7001, 8001, 9001)) + case 5000 => + assert(seq1.toSet == Set[Int](5000)) + assert(seq2.toSet == Set[Int]()) + case 9999 => + assert(seq1.toSet == Set[Int](9999)) + assert(seq2.toSet == Set[Int]()) + case _ => + } + } + + // larger cogroup - should spill ~4 times per executor + val rddD1 = sc.parallelize(0 until 10000).map(i => (i/2, i)) + val rddD2 = sc.parallelize(0 until 10000).map(i => (i/2, i)) + val resultD = rddD1.cogroup(rddD2).collect() + assert(resultD.length == 5000) + resultD.foreach { case(i, (seq1, seq2)) => + val expected = Set(i * 2, i * 2 + 1) + if (seq1.toSet != expected) { + fail(s"Value 1 for ${i} was wrong: expected ${expected}, got ${seq1.toSet}") + } + if (seq2.toSet != expected) { + fail(s"Value 2 for ${i} was wrong: expected ${expected}, got ${seq2.toSet}") + } + } + } + + test("cleanup of intermediate files in sorter") { + val conf = new SparkConf(true) // Load defaults, otherwise SPARK_HOME is not found + conf.set("spark.shuffle.memoryFraction", "0.001") + conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager") + sc = new SparkContext("local", "test", conf) + val diskBlockManager = SparkEnv.get.blockManager.diskBlockManager + + val sorter = new ExternalSorter[Int, Int, Int](None, Some(new HashPartitioner(3)), None, None) + sorter.write((0 until 100000).iterator.map(i => (i, i))) + assert(diskBlockManager.getAllFiles().length > 0) + sorter.stop() + assert(diskBlockManager.getAllBlocks().length === 0) + + val sorter2 = new ExternalSorter[Int, Int, Int](None, Some(new HashPartitioner(3)), None, None) + sorter2.write((0 until 100000).iterator.map(i => (i, i))) + assert(diskBlockManager.getAllFiles().length > 0) + assert(sorter2.iterator.toSet === (0 until 100000).map(i => (i, i)).toSet) + sorter2.stop() + assert(diskBlockManager.getAllBlocks().length === 0) + } + + test("cleanup of intermediate files in sorter if there are errors") { + val conf = new SparkConf(true) // Load defaults, otherwise SPARK_HOME is not found + conf.set("spark.shuffle.memoryFraction", "0.001") + conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager") + sc = new SparkContext("local", "test", conf) + val diskBlockManager = SparkEnv.get.blockManager.diskBlockManager + + val sorter = new ExternalSorter[Int, Int, Int](None, Some(new HashPartitioner(3)), None, None) + intercept[SparkException] { + sorter.write((0 until 100000).iterator.map(i => { + if (i == 99990) { + throw new SparkException("Intentional failure") + } + (i, i) + })) + } + assert(diskBlockManager.getAllFiles().length > 0) + sorter.stop() + assert(diskBlockManager.getAllBlocks().length === 0) + } + + test("cleanup of intermediate files in shuffle") { + val conf = new SparkConf(false) + conf.set("spark.shuffle.memoryFraction", "0.001") + conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager") + sc = new SparkContext("local", "test", conf) + val diskBlockManager = SparkEnv.get.blockManager.diskBlockManager + + val data = sc.parallelize(0 until 100000, 2).map(i => (i, i)) + assert(data.reduceByKey(_ + _).count() === 100000) + + // After the shuffle, there should be only 4 files on disk: our two map output files and + // their index files. All other intermediate files should've been deleted. + assert(diskBlockManager.getAllFiles().length === 4) + } + + test("cleanup of intermediate files in shuffle with errors") { + val conf = new SparkConf(false) + conf.set("spark.shuffle.memoryFraction", "0.001") + conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager") + sc = new SparkContext("local", "test", conf) + val diskBlockManager = SparkEnv.get.blockManager.diskBlockManager + + val data = sc.parallelize(0 until 100000, 2).map(i => { + if (i == 99990) { + throw new Exception("Intentional failure") + } + (i, i) + }) + intercept[SparkException] { + data.reduceByKey(_ + _).count() + } + + // After the shuffle, there should be only 2 files on disk: the output of task 1 and its index. + // All other files (map 2's output and intermediate merge files) should've been deleted. + assert(diskBlockManager.getAllFiles().length === 2) + } + + test("no partial aggregation or sorting") { + val conf = new SparkConf(false) + conf.set("spark.shuffle.memoryFraction", "0.001") + conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager") + sc = new SparkContext("local", "test", conf) + + val sorter = new ExternalSorter[Int, Int, Int](None, Some(new HashPartitioner(3)), None, None) + sorter.write((0 until 100000).iterator.map(i => (i / 4, i))) + val results = sorter.partitionedIterator.map{case (p, vs) => (p, vs.toSet)}.toSet + val expected = (0 until 3).map(p => { + (p, (0 until 100000).map(i => (i / 4, i)).filter(_._1 % 3 == p).toSet) + }).toSet + assert(results === expected) + } + + test("partial aggregation without spill") { + val conf = new SparkConf(false) + conf.set("spark.shuffle.memoryFraction", "0.001") + conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager") + sc = new SparkContext("local", "test", conf) + + val agg = new Aggregator[Int, Int, Int](i => i, (i, j) => i + j, (i, j) => i + j) + val sorter = new ExternalSorter(Some(agg), Some(new HashPartitioner(3)), None, None) + sorter.write((0 until 100).iterator.map(i => (i / 2, i))) + val results = sorter.partitionedIterator.map{case (p, vs) => (p, vs.toSet)}.toSet + val expected = (0 until 3).map(p => { + (p, (0 until 50).map(i => (i, i * 4 + 1)).filter(_._1 % 3 == p).toSet) + }).toSet + assert(results === expected) + } + + test("partial aggregation with spill, no ordering") { + val conf = new SparkConf(false) + conf.set("spark.shuffle.memoryFraction", "0.001") + conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager") + sc = new SparkContext("local", "test", conf) + + val agg = new Aggregator[Int, Int, Int](i => i, (i, j) => i + j, (i, j) => i + j) + val sorter = new ExternalSorter(Some(agg), Some(new HashPartitioner(3)), None, None) + sorter.write((0 until 100000).iterator.map(i => (i / 2, i))) + val results = sorter.partitionedIterator.map{case (p, vs) => (p, vs.toSet)}.toSet + val expected = (0 until 3).map(p => { + (p, (0 until 50000).map(i => (i, i * 4 + 1)).filter(_._1 % 3 == p).toSet) + }).toSet + assert(results === expected) + } + + test("partial aggregation with spill, with ordering") { + val conf = new SparkConf(false) + conf.set("spark.shuffle.memoryFraction", "0.001") + conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager") + sc = new SparkContext("local", "test", conf) + + val agg = new Aggregator[Int, Int, Int](i => i, (i, j) => i + j, (i, j) => i + j) + val ord = implicitly[Ordering[Int]] + val sorter = new ExternalSorter(Some(agg), Some(new HashPartitioner(3)), Some(ord), None) + sorter.write((0 until 100000).iterator.map(i => (i / 2, i))) + val results = sorter.partitionedIterator.map{case (p, vs) => (p, vs.toSet)}.toSet + val expected = (0 until 3).map(p => { + (p, (0 until 50000).map(i => (i, i * 4 + 1)).filter(_._1 % 3 == p).toSet) + }).toSet + assert(results === expected) + } + + test("sorting without aggregation, no spill") { + val conf = new SparkConf(false) + conf.set("spark.shuffle.memoryFraction", "0.001") + conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager") + sc = new SparkContext("local", "test", conf) + + val ord = implicitly[Ordering[Int]] + val sorter = new ExternalSorter[Int, Int, Int]( + None, Some(new HashPartitioner(3)), Some(ord), None) + sorter.write((0 until 100).iterator.map(i => (i, i))) + val results = sorter.partitionedIterator.map{case (p, vs) => (p, vs.toSeq)}.toSeq + val expected = (0 until 3).map(p => { + (p, (0 until 100).map(i => (i, i)).filter(_._1 % 3 == p).toSeq) + }).toSeq + assert(results === expected) + } + + test("sorting without aggregation, with spill") { + val conf = new SparkConf(false) + conf.set("spark.shuffle.memoryFraction", "0.001") + conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager") + sc = new SparkContext("local", "test", conf) + + val ord = implicitly[Ordering[Int]] + val sorter = new ExternalSorter[Int, Int, Int]( + None, Some(new HashPartitioner(3)), Some(ord), None) + sorter.write((0 until 100000).iterator.map(i => (i, i))) + val results = sorter.partitionedIterator.map{case (p, vs) => (p, vs.toSeq)}.toSeq + val expected = (0 until 3).map(p => { + (p, (0 until 100000).map(i => (i, i)).filter(_._1 % 3 == p).toSeq) + }).toSeq + assert(results === expected) + } + + test("spilling with hash collisions") { + val conf = new SparkConf(true) + conf.set("spark.shuffle.memoryFraction", "0.001") + sc = new SparkContext("local-cluster[1,1,512]", "test", conf) + + def createCombiner(i: String) = ArrayBuffer[String](i) + def mergeValue(buffer: ArrayBuffer[String], i: String) = buffer += i + def mergeCombiners(buffer1: ArrayBuffer[String], buffer2: ArrayBuffer[String]) = + buffer1 ++= buffer2 + + val agg = new Aggregator[String, String, ArrayBuffer[String]]( + createCombiner _, mergeValue _, mergeCombiners _) + + val sorter = new ExternalSorter[String, String, ArrayBuffer[String]]( + Some(agg), None, None, None) + + val collisionPairs = Seq( + ("Aa", "BB"), // 2112 + ("to", "v1"), // 3707 + ("variants", "gelato"), // -1249574770 + ("Teheran", "Siblings"), // 231609873 + ("misused", "horsemints"), // 1069518484 + ("isohel", "epistolaries"), // -1179291542 + ("righto", "buzzards"), // -931102253 + ("hierarch", "crinolines"), // -1732884796 + ("inwork", "hypercatalexes"), // -1183663690 + ("wainages", "presentencing"), // 240183619 + ("trichothecenes", "locular"), // 339006536 + ("pomatoes", "eructation") // 568647356 + ) + + collisionPairs.foreach { case (w1, w2) => + // String.hashCode is documented to use a specific algorithm, but check just in case + assert(w1.hashCode === w2.hashCode) + } + + val toInsert = (1 to 100000).iterator.map(_.toString).map(s => (s, s)) ++ + collisionPairs.iterator ++ collisionPairs.iterator.map(_.swap) + + sorter.write(toInsert) + + // A map of collision pairs in both directions + val collisionPairsMap = (collisionPairs ++ collisionPairs.map(_.swap)).toMap + + // Avoid map.size or map.iterator.length because this destructively sorts the underlying map + var count = 0 + + val it = sorter.iterator + while (it.hasNext) { + val kv = it.next() + val expectedValue = ArrayBuffer[String](collisionPairsMap.getOrElse(kv._1, kv._1)) + assert(kv._2.equals(expectedValue)) + count += 1 + } + assert(count === 100000 + collisionPairs.size * 2) + } + + test("spilling with many hash collisions") { + val conf = new SparkConf(true) + conf.set("spark.shuffle.memoryFraction", "0.0001") + sc = new SparkContext("local-cluster[1,1,512]", "test", conf) + + val agg = new Aggregator[FixedHashObject, Int, Int](_ => 1, _ + _, _ + _) + val sorter = new ExternalSorter[FixedHashObject, Int, Int](Some(agg), None, None, None) + + // Insert 10 copies each of lots of objects whose hash codes are either 0 or 1. This causes + // problems if the map fails to group together the objects with the same code (SPARK-2043). + val toInsert = for (i <- 1 to 10; j <- 1 to 10000) yield (FixedHashObject(j, j % 2), 1) + sorter.write(toInsert.iterator) + + val it = sorter.iterator + var count = 0 + while (it.hasNext) { + val kv = it.next() + assert(kv._2 === 10) + count += 1 + } + assert(count === 10000) + } + + test("spilling with hash collisions using the Int.MaxValue key") { + val conf = new SparkConf(true) + conf.set("spark.shuffle.memoryFraction", "0.001") + sc = new SparkContext("local-cluster[1,1,512]", "test", conf) + + def createCombiner(i: Int) = ArrayBuffer[Int](i) + def mergeValue(buffer: ArrayBuffer[Int], i: Int) = buffer += i + def mergeCombiners(buf1: ArrayBuffer[Int], buf2: ArrayBuffer[Int]) = buf1 ++= buf2 + + val agg = new Aggregator[Int, Int, ArrayBuffer[Int]](createCombiner, mergeValue, mergeCombiners) + val sorter = new ExternalSorter[Int, Int, ArrayBuffer[Int]](Some(agg), None, None, None) + + sorter.write((1 to 100000).iterator.map(i => (i, i)) ++ Iterator((Int.MaxValue, Int.MaxValue))) + + val it = sorter.iterator + while (it.hasNext) { + // Should not throw NoSuchElementException + it.next() + } + } + + test("spilling with null keys and values") { + val conf = new SparkConf(true) + conf.set("spark.shuffle.memoryFraction", "0.001") + sc = new SparkContext("local-cluster[1,1,512]", "test", conf) + + def createCombiner(i: String) = ArrayBuffer[String](i) + def mergeValue(buffer: ArrayBuffer[String], i: String) = buffer += i + def mergeCombiners(buf1: ArrayBuffer[String], buf2: ArrayBuffer[String]) = buf1 ++= buf2 + + val agg = new Aggregator[String, String, ArrayBuffer[String]]( + createCombiner, mergeValue, mergeCombiners) + + val sorter = new ExternalSorter[String, String, ArrayBuffer[String]]( + Some(agg), None, None, None) + + sorter.write((1 to 100000).iterator.map(i => (i.toString, i.toString)) ++ Iterator( + (null.asInstanceOf[String], "1"), + ("1", null.asInstanceOf[String]), + (null.asInstanceOf[String], null.asInstanceOf[String]) + )) + + val it = sorter.iterator + while (it.hasNext) { + // Should not throw NullPointerException + it.next() + } + } +} diff --git a/core/src/test/scala/org/apache/spark/util/collection/FixedHashObject.scala b/core/src/test/scala/org/apache/spark/util/collection/FixedHashObject.scala new file mode 100644 index 0000000000000..c787b5f066e00 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/util/collection/FixedHashObject.scala @@ -0,0 +1,25 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.collection + +/** + * A dummy class that always returns the same hash code, to easily test hash collisions + */ +case class FixedHashObject(v: Int, h: Int) extends Serializable { + override def hashCode(): Int = h +} diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/MessageToPartition.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/MessageToPartition.scala index 5318b8da6412a..714f3b81c9dad 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/MessageToPartition.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/MessageToPartition.scala @@ -28,7 +28,7 @@ import org.apache.spark.rdd.{ShuffledRDD, RDD} private[graphx] class VertexRDDFunctions[VD: ClassTag](self: RDD[(VertexId, VD)]) { def copartitionWithVertices(partitioner: Partitioner): RDD[(VertexId, VD)] = { - val rdd = new ShuffledRDD[VertexId, VD, VD, (VertexId, VD)](self, partitioner) + val rdd = new ShuffledRDD[VertexId, VD, VD](self, partitioner) // Set a custom serializer if the data is of int or double type. if (classTag[VD] == ClassTag.Int) { diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/RoutingTablePartition.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/RoutingTablePartition.scala index a565d3b28bf52..b27485953f719 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/RoutingTablePartition.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/RoutingTablePartition.scala @@ -33,7 +33,7 @@ private[graphx] class RoutingTableMessageRDDFunctions(self: RDD[RoutingTableMessage]) { /** Copartition an `RDD[RoutingTableMessage]` with the vertex RDD with the given `partitioner`. */ def copartitionWithVertices(partitioner: Partitioner): RDD[RoutingTableMessage] = { - new ShuffledRDD[VertexId, Int, Int, RoutingTableMessage]( + new ShuffledRDD[VertexId, Int, Int]( self, partitioner).setSerializer(new RoutingTableMessageSerializer) } } diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 672343fbbed2e..a8bbd55861954 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -295,6 +295,7 @@ object Unidoc { .map(_.filterNot(_.getCanonicalPath.contains("akka"))) .map(_.filterNot(_.getCanonicalPath.contains("deploy"))) .map(_.filterNot(_.getCanonicalPath.contains("network"))) + .map(_.filterNot(_.getCanonicalPath.contains("shuffle"))) .map(_.filterNot(_.getCanonicalPath.contains("executor"))) .map(_.filterNot(_.getCanonicalPath.contains("python"))) .map(_.filterNot(_.getCanonicalPath.contains("collection"))) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala index 392a7f3be3904..30712f03cab4c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala @@ -49,7 +49,7 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una iter.map(r => mutablePair.update(hashExpressions(r), r)) } val part = new HashPartitioner(numPartitions) - val shuffled = new ShuffledRDD[Row, Row, Row, MutablePair[Row, Row]](rdd, part) + val shuffled = new ShuffledRDD[Row, Row, Row](rdd, part) shuffled.setSerializer(new SparkSqlSerializer(new SparkConf(false))) shuffled.map(_._2) @@ -62,7 +62,7 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una iter.map(row => mutablePair.update(row, null)) } val part = new RangePartitioner(numPartitions, rdd, ascending = true) - val shuffled = new ShuffledRDD[Row, Null, Null, MutablePair[Row, Null]](rdd, part) + val shuffled = new ShuffledRDD[Row, Null, Null](rdd, part) shuffled.setSerializer(new SparkSqlSerializer(new SparkConf(false))) shuffled.map(_._1) @@ -73,7 +73,7 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una iter.map(r => mutablePair.update(null, r)) } val partitioner = new HashPartitioner(1) - val shuffled = new ShuffledRDD[Null, Row, Row, MutablePair[Null, Row]](rdd, partitioner) + val shuffled = new ShuffledRDD[Null, Row, Row](rdd, partitioner) shuffled.setSerializer(new SparkSqlSerializer(new SparkConf(false))) shuffled.map(_._2) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index 174eda8f1a72c..0027f3cf1fc79 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -148,7 +148,7 @@ case class Limit(limit: Int, child: SparkPlan) iter.take(limit).map(row => mutablePair.update(false, row)) } val part = new HashPartitioner(1) - val shuffled = new ShuffledRDD[Boolean, Row, Row, MutablePair[Boolean, Row]](rdd, part) + val shuffled = new ShuffledRDD[Boolean, Row, Row](rdd, part) shuffled.setSerializer(new SparkSqlSerializer(new SparkConf(false))) shuffled.mapPartitions(_.take(limit).map(_._2)) } From 894d48ffb8c91e347ab60c58de983e1aaf181188 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 30 Jul 2014 21:30:13 -0700 Subject: [PATCH 048/170] [SPARK-2758] UnionRDD's UnionPartition should not reference parent RDDs Author: Reynold Xin Closes #1675 from rxin/unionrdd and squashes the following commits: 941d316 [Reynold Xin] Clear RDDs for checkpointing. c9f05f2 [Reynold Xin] [SPARK-2758] UnionRDD's UnionPartition should not reference parent RDDs --- .../scala/org/apache/spark/rdd/UnionRDD.scala | 41 ++++++++++++++----- .../scala/org/apache/spark/rdd/RDDSuite.scala | 12 ++++++ 2 files changed, 42 insertions(+), 11 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala b/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala index 21c6e07d69f90..197167ecad0bd 100644 --- a/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala @@ -25,21 +25,32 @@ import scala.reflect.ClassTag import org.apache.spark.{Dependency, Partition, RangeDependency, SparkContext, TaskContext} import org.apache.spark.annotation.DeveloperApi -private[spark] class UnionPartition[T: ClassTag](idx: Int, rdd: RDD[T], splitIndex: Int) +/** + * Partition for UnionRDD. + * + * @param idx index of the partition + * @param rdd the parent RDD this partition refers to + * @param parentRddIndex index of the parent RDD this partition refers to + * @param parentRddPartitionIndex index of the partition within the parent RDD + * this partition refers to + */ +private[spark] class UnionPartition[T: ClassTag]( + idx: Int, + @transient rdd: RDD[T], + val parentRddIndex: Int, + @transient parentRddPartitionIndex: Int) extends Partition { - var split: Partition = rdd.partitions(splitIndex) - - def iterator(context: TaskContext) = rdd.iterator(split, context) + var parentPartition: Partition = rdd.partitions(parentRddPartitionIndex) - def preferredLocations() = rdd.preferredLocations(split) + def preferredLocations() = rdd.preferredLocations(parentPartition) override val index: Int = idx @throws(classOf[IOException]) private def writeObject(oos: ObjectOutputStream) { // Update the reference to parent split at the time of task serialization - split = rdd.partitions(splitIndex) + parentPartition = rdd.partitions(parentRddPartitionIndex) oos.defaultWriteObject() } } @@ -47,14 +58,14 @@ private[spark] class UnionPartition[T: ClassTag](idx: Int, rdd: RDD[T], splitInd @DeveloperApi class UnionRDD[T: ClassTag]( sc: SparkContext, - @transient var rdds: Seq[RDD[T]]) + var rdds: Seq[RDD[T]]) extends RDD[T](sc, Nil) { // Nil since we implement getDependencies override def getPartitions: Array[Partition] = { val array = new Array[Partition](rdds.map(_.partitions.size).sum) var pos = 0 - for (rdd <- rdds; split <- rdd.partitions) { - array(pos) = new UnionPartition(pos, rdd, split.index) + for ((rdd, rddIndex) <- rdds.zipWithIndex; split <- rdd.partitions) { + array(pos) = new UnionPartition(pos, rdd, rddIndex, split.index) pos += 1 } array @@ -70,9 +81,17 @@ class UnionRDD[T: ClassTag]( deps } - override def compute(s: Partition, context: TaskContext): Iterator[T] = - s.asInstanceOf[UnionPartition[T]].iterator(context) + override def compute(s: Partition, context: TaskContext): Iterator[T] = { + val part = s.asInstanceOf[UnionPartition[T]] + val parentRdd = dependencies(part.parentRddIndex).rdd.asInstanceOf[RDD[T]] + parentRdd.iterator(part.parentPartition, context) + } override def getPreferredLocations(s: Partition): Seq[String] = s.asInstanceOf[UnionPartition[T]].preferredLocations() + + override def clearDependencies() { + super.clearDependencies() + rdds = null + } } diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala index 8966eedd80ebc..ae6e52587584f 100644 --- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala @@ -121,6 +121,18 @@ class RDDSuite extends FunSuite with SharedSparkContext { assert(union.partitioner === nums1.partitioner) } + test("UnionRDD partition serialized size should be small") { + val largeVariable = new Array[Byte](1000 * 1000) + val rdd1 = sc.parallelize(1 to 10, 2).map(i => largeVariable.length) + val rdd2 = sc.parallelize(1 to 10, 3) + + val ser = SparkEnv.get.closureSerializer.newInstance() + val union = rdd1.union(rdd2) + // The UnionRDD itself should be large, but each individual partition should be small. + assert(ser.serialize(union).limit() > 2000) + assert(ser.serialize(union.partitions.head).limit() < 2000) + } + test("aggregate") { val pairs = sc.makeRDD(Array(("a", 1), ("b", 2), ("a", 2), ("c", 5), ("a", 3))) type StringMap = HashMap[String, Int] From 118c1c422d3dfbfb2277995062678f0a808af6c3 Mon Sep 17 00:00:00 2001 From: derek ma Date: Wed, 30 Jul 2014 21:37:59 -0700 Subject: [PATCH 049/170] Required AM memory is "amMem", not "args.amMemory" "ERROR yarn.Client: Required AM memory (1024) is above the max threshold (1048) of this cluster" appears if this code is not changed. obviously, 1024 is less than 1048, so change this Author: derek ma Closes #1494 from maji2014/master and squashes the following commits: b0f6640 [derek ma] Required AM memory is "amMem", not "args.amMemory" --- .../main/scala/org/apache/spark/deploy/yarn/ClientBase.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala index a1298e8f30b5c..b7e8636e02eb2 100644 --- a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala +++ b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala @@ -109,7 +109,7 @@ trait ClientBase extends Logging { if (amMem > maxMem) { val errorMessage = "Required AM memory (%d) is above the max threshold (%d) of this cluster." - .format(args.amMemory, maxMem) + .format(amMem, maxMem) logError(errorMessage) throw new IllegalArgumentException(errorMessage) } From a7c305b86b3b83645ae5ff5d3dfeafc20c443204 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Wed, 30 Jul 2014 21:57:32 -0700 Subject: [PATCH 050/170] [SPARK-2340] Resolve event logging and History Server paths properly We resolve relative paths to the local `file:/` system for `--jars` and `--files` in spark submit (#853). We should do the same for the history server. Author: Andrew Or Closes #1280 from andrewor14/hist-serv-fix and squashes the following commits: 13ff406 [Andrew Or] Merge branch 'master' of github.com:apache/spark into hist-serv-fix b393e17 [Andrew Or] Strip trailing "/" from logging directory 622a471 [Andrew Or] Fix test in EventLoggingListenerSuite 0e20f71 [Andrew Or] Shift responsibility of resolving paths up one level b037c0c [Andrew Or] Use resolved paths for everything in history server c7e36ee [Andrew Or] Resolve paths for event logging too 40e3933 [Andrew Or] Resolve history server file paths --- .../deploy/history/FsHistoryProvider.scala | 34 ++++++++++--------- .../spark/deploy/history/HistoryPage.scala | 2 +- .../spark/deploy/history/HistoryServer.scala | 6 ++-- .../history/HistoryServerArguments.scala | 5 +-- .../scheduler/EventLoggingListener.scala | 6 ++-- .../org/apache/spark/util/FileLogger.scala | 2 +- .../scheduler/EventLoggingListenerSuite.scala | 2 +- 7 files changed, 28 insertions(+), 29 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala index 01e7065c17b69..6d2d4cef1ee46 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala @@ -36,11 +36,11 @@ private[history] class FsHistoryProvider(conf: SparkConf) extends ApplicationHis conf.getInt("spark.history.updateInterval", 10)) * 1000 private val logDir = conf.get("spark.history.fs.logDirectory", null) - if (logDir == null) { - throw new IllegalArgumentException("Logging directory must be specified.") - } + private val resolvedLogDir = Option(logDir) + .map { d => Utils.resolveURI(d) } + .getOrElse { throw new IllegalArgumentException("Logging directory must be specified.") } - private val fs = Utils.getHadoopFileSystem(logDir) + private val fs = Utils.getHadoopFileSystem(resolvedLogDir) // A timestamp of when the disk was last accessed to check for log updates private var lastLogCheckTimeMs = -1L @@ -76,14 +76,14 @@ private[history] class FsHistoryProvider(conf: SparkConf) extends ApplicationHis private def initialize() { // Validate the log directory. - val path = new Path(logDir) + val path = new Path(resolvedLogDir) if (!fs.exists(path)) { throw new IllegalArgumentException( - "Logging directory specified does not exist: %s".format(logDir)) + "Logging directory specified does not exist: %s".format(resolvedLogDir)) } if (!fs.getFileStatus(path).isDir) { throw new IllegalArgumentException( - "Logging directory specified is not a directory: %s".format(logDir)) + "Logging directory specified is not a directory: %s".format(resolvedLogDir)) } checkForLogs() @@ -95,15 +95,16 @@ private[history] class FsHistoryProvider(conf: SparkConf) extends ApplicationHis override def getAppUI(appId: String): SparkUI = { try { - val appLogDir = fs.getFileStatus(new Path(logDir, appId)) - loadAppInfo(appLogDir, true)._2 + val appLogDir = fs.getFileStatus(new Path(resolvedLogDir.toString, appId)) + val (_, ui) = loadAppInfo(appLogDir, renderUI = true) + ui } catch { case e: FileNotFoundException => null } } override def getConfig(): Map[String, String] = - Map(("Event Log Location" -> logDir)) + Map("Event Log Location" -> resolvedLogDir.toString) /** * Builds the application list based on the current contents of the log directory. @@ -114,14 +115,14 @@ private[history] class FsHistoryProvider(conf: SparkConf) extends ApplicationHis lastLogCheckTimeMs = getMonotonicTimeMs() logDebug("Checking for logs. Time is now %d.".format(lastLogCheckTimeMs)) try { - val logStatus = fs.listStatus(new Path(logDir)) + val logStatus = fs.listStatus(new Path(resolvedLogDir)) val logDirs = if (logStatus != null) logStatus.filter(_.isDir).toSeq else Seq[FileStatus]() - val logInfos = logDirs.filter { - dir => fs.isFile(new Path(dir.getPath(), EventLoggingListener.APPLICATION_COMPLETE)) + val logInfos = logDirs.filter { dir => + fs.isFile(new Path(dir.getPath, EventLoggingListener.APPLICATION_COMPLETE)) } val currentApps = Map[String, ApplicationHistoryInfo]( - appList.map(app => (app.id -> app)):_*) + appList.map(app => app.id -> app):_*) // For any application that either (i) is not listed or (ii) has changed since the last time // the listing was created (defined by the log dir's modification time), load the app's info. @@ -131,7 +132,8 @@ private[history] class FsHistoryProvider(conf: SparkConf) extends ApplicationHis val curr = currentApps.getOrElse(dir.getPath().getName(), null) if (curr == null || curr.lastUpdated < getModificationTime(dir)) { try { - newApps += loadAppInfo(dir, false)._1 + val (app, _) = loadAppInfo(dir, renderUI = false) + newApps += app } catch { case e: Exception => logError(s"Failed to load app info from directory $dir.") } @@ -159,9 +161,9 @@ private[history] class FsHistoryProvider(conf: SparkConf) extends ApplicationHis * @return A 2-tuple `(app info, ui)`. `ui` will be null if `renderUI` is false. */ private def loadAppInfo(logDir: FileStatus, renderUI: Boolean) = { - val elogInfo = EventLoggingListener.parseLoggingInfo(logDir.getPath(), fs) val path = logDir.getPath val appId = path.getName + val elogInfo = EventLoggingListener.parseLoggingInfo(path, fs) val replayBus = new ReplayListenerBus(elogInfo.logPaths, fs, elogInfo.compressionCodec) val appListener = new ApplicationEventListener replayBus.addListener(appListener) diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala index d7a3e3f120e67..c4ef8b63b0071 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala @@ -45,7 +45,7 @@ private[spark] class HistoryPage(parent: HistoryServer) extends WebUIPage("") {
    - { providerConfig.map(e =>
  • {e._1}: {e._2}
  • ) } + {providerConfig.map { case (k, v) =>
  • {k}: {v}
  • }}
{ if (allApps.size > 0) { diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala index cacb9da8c947b..d1a64c1912cb8 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala @@ -25,9 +25,9 @@ import org.eclipse.jetty.servlet.{ServletContextHandler, ServletHolder} import org.apache.spark.{Logging, SecurityManager, SparkConf} import org.apache.spark.deploy.SparkHadoopUtil -import org.apache.spark.ui.{WebUI, SparkUI, UIUtils} +import org.apache.spark.ui.{SparkUI, UIUtils, WebUI} import org.apache.spark.ui.JettyUtils._ -import org.apache.spark.util.{SignalLogger, Utils} +import org.apache.spark.util.SignalLogger /** * A web server that renders SparkUIs of completed applications. @@ -177,7 +177,7 @@ object HistoryServer extends Logging { def main(argStrings: Array[String]) { SignalLogger.register(log) initSecurity() - val args = new HistoryServerArguments(conf, argStrings) + new HistoryServerArguments(conf, argStrings) val securityManager = new SecurityManager(conf) val providerName = conf.getOption("spark.history.provider") diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerArguments.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerArguments.scala index be9361b754fc3..25fc76c23e0fb 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerArguments.scala @@ -18,7 +18,6 @@ package org.apache.spark.deploy.history import org.apache.spark.SparkConf -import org.apache.spark.util.Utils /** * Command-line parser for the master. @@ -32,6 +31,7 @@ private[spark] class HistoryServerArguments(conf: SparkConf, args: Array[String] args match { case ("--dir" | "-d") :: value :: tail => logDir = value + conf.set("spark.history.fs.logDirectory", value) parse(tail) case ("--help" | "-h") :: tail => @@ -42,9 +42,6 @@ private[spark] class HistoryServerArguments(conf: SparkConf, args: Array[String] case _ => printUsageAndExit(1) } - if (logDir != null) { - conf.set("spark.history.fs.logDirectory", logDir) - } } private def printUsageAndExit(exitCode: Int) { diff --git a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala index ae6ca9f4e7bf5..406147f167bf3 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala @@ -29,7 +29,7 @@ import org.json4s.jackson.JsonMethods._ import org.apache.spark.{Logging, SparkConf, SparkContext} import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.io.CompressionCodec -import org.apache.spark.util.{FileLogger, JsonProtocol} +import org.apache.spark.util.{FileLogger, JsonProtocol, Utils} /** * A SparkListener that logs events to persistent storage. @@ -55,7 +55,7 @@ private[spark] class EventLoggingListener( private val outputBufferSize = sparkConf.getInt("spark.eventLog.buffer.kb", 100) * 1024 private val logBaseDir = sparkConf.get("spark.eventLog.dir", DEFAULT_LOG_DIR).stripSuffix("/") private val name = appName.replaceAll("[ :/]", "-").toLowerCase + "-" + System.currentTimeMillis - val logDir = logBaseDir + "/" + name + val logDir = Utils.resolveURI(logBaseDir) + "/" + name.stripSuffix("/") protected val logger = new FileLogger(logDir, sparkConf, hadoopConf, outputBufferSize, shouldCompress, shouldOverwrite, Some(LOG_FILE_PERMISSIONS)) @@ -215,7 +215,7 @@ private[spark] object EventLoggingListener extends Logging { } catch { case e: Exception => logError("Exception in parsing logging info from directory %s".format(logDir), e) - EventLoggingInfo.empty + EventLoggingInfo.empty } } diff --git a/core/src/main/scala/org/apache/spark/util/FileLogger.scala b/core/src/main/scala/org/apache/spark/util/FileLogger.scala index 9dcdafdd6350e..2e8fbf5a91ee7 100644 --- a/core/src/main/scala/org/apache/spark/util/FileLogger.scala +++ b/core/src/main/scala/org/apache/spark/util/FileLogger.scala @@ -52,7 +52,7 @@ private[spark] class FileLogger( override def initialValue(): SimpleDateFormat = new SimpleDateFormat("yyyy/MM/dd HH:mm:ss") } - private val fileSystem = Utils.getHadoopFileSystem(new URI(logDir)) + private val fileSystem = Utils.getHadoopFileSystem(logDir) var fileIndex = 0 // Only used if compression is enabled diff --git a/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala index 21e3db34b8b7a..10d8b299317ea 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala @@ -259,7 +259,7 @@ class EventLoggingListenerSuite extends FunSuite with BeforeAndAfter { assert(sc.eventLogger.isDefined) val eventLogger = sc.eventLogger.get val expectedLogDir = logDirPath.toString - assert(eventLogger.logDir.startsWith(expectedLogDir)) + assert(eventLogger.logDir.contains(expectedLogDir)) // Begin listening for events that trigger asserts val eventExistenceListener = new EventExistenceListener(eventLogger) From 4fb259353f616822c32537e3f031944a6d2a09a8 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Wed, 30 Jul 2014 22:40:57 -0700 Subject: [PATCH 051/170] [SPARK-2737] Add retag() method for changing RDDs' ClassTags. The Java API's use of fake ClassTags doesn't seem to cause any problems for Java users, but it can lead to issues when passing JavaRDDs' underlying RDDs to Scala code (e.g. in the MLlib Java API wrapper code). If we call collect() on a Scala RDD with an incorrect ClassTag, this causes ClassCastExceptions when we try to allocate an array of the wrong type (for example, see SPARK-2197). There are a few possible fixes here. An API-breaking fix would be to completely remove the fake ClassTags and require Java API users to pass java.lang.Class instances to all parallelize() calls and add returnClass fields to all Function implementations. This would be extremely verbose. Instead, this patch adds internal APIs to "repair" a Scala RDD with an incorrect ClassTag by wrapping it and overriding its ClassTag. This should be okay for cases where the Scala code that calls collect() knows what type of array should be allocated, which is the case in the MLlib wrappers. Author: Josh Rosen Closes #1639 from JoshRosen/SPARK-2737 and squashes the following commits: 572b4c8 [Josh Rosen] Replace newRDD[T] with mapPartitions(). 469d941 [Josh Rosen] Preserve partitioner in retag(). af78816 [Josh Rosen] Allow retag() to get classTag implicitly. d1d54e6 [Josh Rosen] [SPARK-2737] Add retag() method for changing RDDs' ClassTags. --- .../main/scala/org/apache/spark/rdd/RDD.scala | 17 +++++++++++++++++ .../java/org/apache/spark/JavaAPISuite.java | 17 +++++++++++++++++ .../scala/org/apache/spark/rdd/RDDSuite.scala | 8 ++++++++ 3 files changed, 42 insertions(+) diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index 74ac97091fd0b..e1c49e35abecd 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -1236,6 +1236,23 @@ abstract class RDD[T: ClassTag]( /** The [[org.apache.spark.SparkContext]] that this RDD was created on. */ def context = sc + /** + * Private API for changing an RDD's ClassTag. + * Used for internal Java <-> Scala API compatibility. + */ + private[spark] def retag(cls: Class[T]): RDD[T] = { + val classTag: ClassTag[T] = ClassTag.apply(cls) + this.retag(classTag) + } + + /** + * Private API for changing an RDD's ClassTag. + * Used for internal Java <-> Scala API compatibility. + */ + private[spark] def retag(implicit classTag: ClassTag[T]): RDD[T] = { + this.mapPartitions(identity, preservesPartitioning = true)(classTag) + } + // Avoid handling doCheckpoint multiple times to prevent excessive recursion @transient private var doCheckpointCalled = false diff --git a/core/src/test/java/org/apache/spark/JavaAPISuite.java b/core/src/test/java/org/apache/spark/JavaAPISuite.java index e8bd65f8e4507..fab64a54e2479 100644 --- a/core/src/test/java/org/apache/spark/JavaAPISuite.java +++ b/core/src/test/java/org/apache/spark/JavaAPISuite.java @@ -1245,4 +1245,21 @@ public Tuple2 call(Integer i) { Assert.assertTrue(worExactCounts.get(0) == 2); Assert.assertTrue(worExactCounts.get(1) == 4); } + + private static class SomeCustomClass implements Serializable { + public SomeCustomClass() { + // Intentionally left blank + } + } + + @Test + public void collectUnderlyingScalaRDD() { + List data = new ArrayList(); + for (int i = 0; i < 100; i++) { + data.add(new SomeCustomClass()); + } + JavaRDD rdd = sc.parallelize(data); + SomeCustomClass[] collected = (SomeCustomClass[]) rdd.rdd().retag(SomeCustomClass.class).collect(); + Assert.assertEquals(data.size(), collected.length); + } } diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala index ae6e52587584f..b31e3a09e5b9c 100644 --- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.rdd import scala.collection.mutable.{ArrayBuffer, HashMap} +import scala.collection.JavaConverters._ import scala.reflect.ClassTag import org.scalatest.FunSuite @@ -26,6 +27,7 @@ import org.apache.spark._ import org.apache.spark.SparkContext._ import org.apache.spark.util.Utils +import org.apache.spark.api.java.{JavaRDD, JavaSparkContext} import org.apache.spark.rdd.RDDSuiteUtils._ class RDDSuite extends FunSuite with SharedSparkContext { @@ -718,6 +720,12 @@ class RDDSuite extends FunSuite with SharedSparkContext { assert(ids.length === n) } + test("retag with implicit ClassTag") { + val jsc: JavaSparkContext = new JavaSparkContext(sc) + val jrdd: JavaRDD[String] = jsc.parallelize(Seq("A", "B", "C").asJava) + jrdd.rdd.retag.collect() + } + test("getNarrowAncestors") { val rdd1 = sc.parallelize(1 to 100, 4) val rdd2 = rdd1.filter(_ % 2 == 0).map(_ + 1) From 5a110da25f15694773d6f7c6ee63c5b08ada4eb0 Mon Sep 17 00:00:00 2001 From: Prashant Sharma Date: Wed, 30 Jul 2014 22:46:30 -0700 Subject: [PATCH 052/170] [SPARK-2497] Included checks for module symbols too. Author: Prashant Sharma Closes #1463 from ScrapCodes/SPARK-2497/mima-exclude-all and squashes the following commits: 72077b1 [Prashant Sharma] Check separately for module symbols. cd96192 [Prashant Sharma] SPARK-2497 Produce "member excludes" irrespective of the fact that class itself is excluded or not. --- .../spark/tools/GenerateMIMAIgnore.scala | 20 +++++++++---------- 1 file changed, 9 insertions(+), 11 deletions(-) diff --git a/tools/src/main/scala/org/apache/spark/tools/GenerateMIMAIgnore.scala b/tools/src/main/scala/org/apache/spark/tools/GenerateMIMAIgnore.scala index 566983675bff5..16ff89a8a9809 100644 --- a/tools/src/main/scala/org/apache/spark/tools/GenerateMIMAIgnore.scala +++ b/tools/src/main/scala/org/apache/spark/tools/GenerateMIMAIgnore.scala @@ -68,12 +68,11 @@ object GenerateMIMAIgnore { for (className <- classes) { try { val classSymbol = mirror.classSymbol(Class.forName(className, false, classLoader)) - val moduleSymbol = mirror.staticModule(className) // TODO: see if it is necessary. + val moduleSymbol = mirror.staticModule(className) val directlyPrivateSpark = isPackagePrivate(classSymbol) || isPackagePrivateModule(moduleSymbol) - val developerApi = isDeveloperApi(classSymbol) - val experimental = isExperimental(classSymbol) - + val developerApi = isDeveloperApi(classSymbol) || isDeveloperApi(moduleSymbol) + val experimental = isExperimental(classSymbol) || isExperimental(moduleSymbol) /* Inner classes defined within a private[spark] class or object are effectively invisible, so we account for them as package private. */ lazy val indirectlyPrivateSpark = { @@ -87,10 +86,9 @@ object GenerateMIMAIgnore { } if (directlyPrivateSpark || indirectlyPrivateSpark || developerApi || experimental) { ignoredClasses += className - } else { - // check if this class has package-private/annotated members. - ignoredMembers ++= getAnnotatedOrPackagePrivateMembers(classSymbol) } + // check if this class has package-private/annotated members. + ignoredMembers ++= getAnnotatedOrPackagePrivateMembers(classSymbol) } catch { case _: Throwable => println("Error instrumenting class:" + className) @@ -115,8 +113,9 @@ object GenerateMIMAIgnore { } private def getAnnotatedOrPackagePrivateMembers(classSymbol: unv.ClassSymbol) = { - classSymbol.typeSignature.members - .filter(x => isPackagePrivate(x) || isDeveloperApi(x) || isExperimental(x)).map(_.fullName) ++ + classSymbol.typeSignature.members.filterNot(x => + x.fullName.startsWith("java") || x.fullName.startsWith("scala")) + .filter(x => isPackagePrivate(x) || isDeveloperApi(x) || isExperimental(x)).map(_.fullName) ++ getInnerFunctions(classSymbol) } @@ -137,8 +136,7 @@ object GenerateMIMAIgnore { name.endsWith("$class") || name.contains("$sp") || name.contains("hive") || - name.contains("Hive") || - name.contains("repl") + name.contains("Hive") } /** From 669e3f05895d9dfa37abf60f60aecebb03988e50 Mon Sep 17 00:00:00 2001 From: CrazyJvm Date: Wed, 30 Jul 2014 23:37:25 -0700 Subject: [PATCH 053/170] automatically set master according to `spark.master` in `spark-defaults.... automatically set master according to `spark.master` in `spark-defaults.conf` Author: CrazyJvm Closes #1644 from CrazyJvm/standalone-guide and squashes the following commits: bb12b95 [CrazyJvm] automatically set master according to `spark.master` in `spark-defaults.conf` --- docs/spark-standalone.md | 3 --- 1 file changed, 3 deletions(-) diff --git a/docs/spark-standalone.md b/docs/spark-standalone.md index ad8b6c0e51a78..2fb30765f35e8 100644 --- a/docs/spark-standalone.md +++ b/docs/spark-standalone.md @@ -242,9 +242,6 @@ To run an interactive Spark shell against the cluster, run the following command ./bin/spark-shell --master spark://IP:PORT -Note that if you are running spark-shell from one of the spark cluster machines, the `bin/spark-shell` script will -automatically set MASTER from the `SPARK_MASTER_IP` and `SPARK_MASTER_PORT` variables in `conf/spark-env.sh`. - You can also pass an option `--cores ` to control the number of cores that spark-shell uses on the cluster. # Launching Compiled Spark Applications From 92ca910eb866701e01b987a4f5003564b4785959 Mon Sep 17 00:00:00 2001 From: Timothy Hunter Date: Thu, 31 Jul 2014 10:25:40 -0700 Subject: [PATCH 054/170] [SPARK-2762] SparkILoop leaks memory in multi-repl configurations This pull request is a small refactor so that a partial function (hence a closure) is not created. Instead, a regular function is used. The behavior of the code is not changed. Author: Timothy Hunter Closes #1674 from thunterdb/closure_issue and squashes the following commits: e1e664d [Timothy Hunter] simplify closure --- .../org/apache/spark/repl/SparkILoop.scala | 39 ++++++++++--------- 1 file changed, 20 insertions(+), 19 deletions(-) diff --git a/repl/src/main/scala/org/apache/spark/repl/SparkILoop.scala b/repl/src/main/scala/org/apache/spark/repl/SparkILoop.scala index e1db4d5395ab9..6f9fa0d9f2b25 100644 --- a/repl/src/main/scala/org/apache/spark/repl/SparkILoop.scala +++ b/repl/src/main/scala/org/apache/spark/repl/SparkILoop.scala @@ -557,29 +557,27 @@ class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter, if (isReplPower) powerCommands else Nil )*/ - val replayQuestionMessage = + private val replayQuestionMessage = """|That entry seems to have slain the compiler. Shall I replay |your session? I can re-run each line except the last one. |[y/n] """.trim.stripMargin - private val crashRecovery: PartialFunction[Throwable, Boolean] = { - case ex: Throwable => - echo(intp.global.throwableAsString(ex)) - - ex match { - case _: NoSuchMethodError | _: NoClassDefFoundError => - echo("\nUnrecoverable error.") - throw ex - case _ => - def fn(): Boolean = - try in.readYesOrNo(replayQuestionMessage, { echo("\nYou must enter y or n.") ; fn() }) - catch { case _: RuntimeException => false } - - if (fn()) replay() - else echo("\nAbandoning crashed session.") - } - true + private def crashRecovery(ex: Throwable): Boolean = { + echo(ex.toString) + ex match { + case _: NoSuchMethodError | _: NoClassDefFoundError => + echo("\nUnrecoverable error.") + throw ex + case _ => + def fn(): Boolean = + try in.readYesOrNo(replayQuestionMessage, { echo("\nYou must enter y or n.") ; fn() }) + catch { case _: RuntimeException => false } + + if (fn()) replay() + else echo("\nAbandoning crashed session.") + } + true } /** The main read-eval-print loop for the repl. It calls @@ -605,7 +603,10 @@ class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter, } } def innerLoop() { - if ( try processLine(readOneLine()) catch crashRecovery ) + val shouldContinue = try { + processLine(readOneLine()) + } catch {case t: Throwable => crashRecovery(t)} + if (shouldContinue) innerLoop() } innerLoop() From 3072b96026fa3e63e8eef780f2b04dd81f11ea27 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Thu, 31 Jul 2014 11:15:25 -0700 Subject: [PATCH 055/170] [SPARK-2743][SQL] Resolve original attributes in ParquetTableScan Author: Michael Armbrust Closes #1647 from marmbrus/parquetCase and squashes the following commits: a1799b7 [Michael Armbrust] move comment 2a2a68b [Michael Armbrust] Merge remote-tracking branch 'apache/master' into parquetCase bb35d5b [Michael Armbrust] Fix test case that produced an invalid plan. e6870bf [Michael Armbrust] Better error message. 539a2e1 [Michael Armbrust] Resolve original attributes in ParquetTableScan --- .../sql/parquet/ParquetTableOperations.scala | 14 ++++++++++---- .../spark/sql/parquet/ParquetQuerySuite.scala | 14 +------------- .../spark/sql/parquet/HiveParquetSuite.scala | 17 +++++++++++++++++ 3 files changed, 28 insertions(+), 17 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala index 912a9f002b7d1..759a2a586b926 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala @@ -51,13 +51,20 @@ import org.apache.spark.{Logging, SerializableWritable, TaskContext} * [[org.apache.spark.sql.parquet.ParquetRelation]] as a ``RDD[Row]``. */ case class ParquetTableScan( - // note: output cannot be transient, see - // https://issues.apache.org/jira/browse/SPARK-1367 - output: Seq[Attribute], + attributes: Seq[Attribute], relation: ParquetRelation, columnPruningPred: Seq[Expression]) extends LeafNode { + // The resolution of Parquet attributes is case sensitive, so we resolve the original attributes + // by exprId. note: output cannot be transient, see + // https://issues.apache.org/jira/browse/SPARK-1367 + val output = attributes.map { a => + relation.output + .find(o => o.exprId == a.exprId) + .getOrElse(sys.error(s"Invalid parquet attribute $a in ${relation.output.mkString(",")}")) + } + override def execute(): RDD[Row] = { val sc = sqlContext.sparkContext val job = new Job(sc.hadoopConfiguration) @@ -110,7 +117,6 @@ case class ParquetTableScan( ParquetTableScan(prunedAttributes, relation, columnPruningPred) } else { sys.error("Warning: Could not validate Parquet schema projection in pruneColumns") - this } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala index 561f5b4a49965..8955455ec98c7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala @@ -209,19 +209,7 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA } test("Projection of simple Parquet file") { - SparkPlan.currentContext.set(TestSQLContext) - val scanner = new ParquetTableScan( - ParquetTestData.testData.output, - ParquetTestData.testData, - Seq()) - val projected = scanner.pruneColumns(ParquetTypesConverter - .convertToAttributes(MessageTypeParser - .parseMessageType(ParquetTestData.subTestSchema))) - assert(projected.output.size === 2) - val result = projected - .execute() - .map(_.copy()) - .collect() + val result = ParquetTestData.testData.select('myboolean, 'mylong).collect() result.zipWithIndex.foreach { case (row, index) => { if (index % 3 == 0) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/parquet/HiveParquetSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/parquet/HiveParquetSuite.scala index 3bfe49a760be5..47526e3596e44 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/parquet/HiveParquetSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/parquet/HiveParquetSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.parquet +import java.io.File + import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach, FunSuite} import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Row} @@ -27,6 +29,8 @@ import org.apache.spark.util.Utils // Implicits import org.apache.spark.sql.hive.test.TestHive._ +case class Cases(lower: String, UPPER: String) + class HiveParquetSuite extends FunSuite with BeforeAndAfterAll with BeforeAndAfterEach { val dirname = Utils.createTempDir() @@ -55,6 +59,19 @@ class HiveParquetSuite extends FunSuite with BeforeAndAfterAll with BeforeAndAft Utils.deleteRecursively(dirname) } + test("Case insensitive attribute names") { + val tempFile = File.createTempFile("parquet", "") + tempFile.delete() + sparkContext.parallelize(1 to 10) + .map(_.toString) + .map(i => Cases(i, i)) + .saveAsParquetFile(tempFile.getCanonicalPath) + + parquetFile(tempFile.getCanonicalPath).registerAsTable("cases") + hql("SELECT upper FROM cases").collect().map(_.getString(0)) === (1 to 10).map(_.toString) + hql("SELECT LOWER FROM cases").collect().map(_.getString(0)) === (1 to 10).map(_.toString) + } + test("SELECT on Parquet table") { val rdd = hql("SELECT * FROM testsource").collect() assert(rdd != null) From 72cfb13987bab07461266905930f84619b3a0068 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Thu, 31 Jul 2014 11:26:43 -0700 Subject: [PATCH 056/170] [SPARK-2397][SQL] Deprecate LocalHiveContext LocalHiveContext is redundant with HiveContext. The only difference is it creates `./metastore` instead of `./metastore_db`. Author: Michael Armbrust Closes #1641 from marmbrus/localHiveContext and squashes the following commits: e5ec497 [Michael Armbrust] Add deprecation version 626e056 [Michael Armbrust] Don't remove from imports yet 905cc5f [Michael Armbrust] Merge remote-tracking branch 'apache/master' into localHiveContext 1c2727e [Michael Armbrust] Deprecate LocalHiveContext --- .../sbt_app_hive/src/main/scala/HiveApp.scala | 4 ++-- docs/sql-programming-guide.md | 6 +++--- .../spark/examples/sql/hive/HiveFromSpark.scala | 4 ++-- python/pyspark/sql.py | 6 ++++++ .../org/apache/spark/sql/hive/HiveContext.scala | 7 +++++-- .../org/apache/spark/sql/hive/TestHive.scala | 15 ++++++++++++--- 6 files changed, 30 insertions(+), 12 deletions(-) diff --git a/dev/audit-release/sbt_app_hive/src/main/scala/HiveApp.scala b/dev/audit-release/sbt_app_hive/src/main/scala/HiveApp.scala index 7257d17d10116..a21410f3b9813 100644 --- a/dev/audit-release/sbt_app_hive/src/main/scala/HiveApp.scala +++ b/dev/audit-release/sbt_app_hive/src/main/scala/HiveApp.scala @@ -22,7 +22,7 @@ import scala.collection.mutable.{ListBuffer, Queue} import org.apache.spark.SparkConf import org.apache.spark.SparkContext import org.apache.spark.rdd.RDD -import org.apache.spark.sql.hive.LocalHiveContext +import org.apache.spark.sql.hive.HiveContext case class Person(name: String, age: Int) @@ -34,7 +34,7 @@ object SparkSqlExample { case None => new SparkConf().setAppName("Simple Sql App") } val sc = new SparkContext(conf) - val hiveContext = new LocalHiveContext(sc) + val hiveContext = new HiveContext(sc) import hiveContext._ hql("DROP TABLE IF EXISTS src") diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 156e0aebdebe6..a047d32b6ee6c 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -487,9 +487,9 @@ Configuration of Hive is done by placing your `hive-site.xml` file in `conf/`. When working with Hive one must construct a `HiveContext`, which inherits from `SQLContext`, and adds support for finding tables in in the MetaStore and writing queries using HiveQL. Users who do -not have an existing Hive deployment can also experiment with the `LocalHiveContext`, -which is similar to `HiveContext`, but creates a local copy of the `metastore` and `warehouse` -automatically. +not have an existing Hive deployment can still create a HiveContext. When not configured by the +hive-site.xml, the context automatically creates `metastore_db` and `warehouse` in the current +directory. {% highlight scala %} // sc is an existing SparkContext. diff --git a/examples/src/main/scala/org/apache/spark/examples/sql/hive/HiveFromSpark.scala b/examples/src/main/scala/org/apache/spark/examples/sql/hive/HiveFromSpark.scala index 66a23fac39999..dc5290fb4f10e 100644 --- a/examples/src/main/scala/org/apache/spark/examples/sql/hive/HiveFromSpark.scala +++ b/examples/src/main/scala/org/apache/spark/examples/sql/hive/HiveFromSpark.scala @@ -19,7 +19,7 @@ package org.apache.spark.examples.sql.hive import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.sql._ -import org.apache.spark.sql.hive.LocalHiveContext +import org.apache.spark.sql.hive.HiveContext object HiveFromSpark { case class Record(key: Int, value: String) @@ -31,7 +31,7 @@ object HiveFromSpark { // A local hive context creates an instance of the Hive Metastore in process, storing the // the warehouse data in the current directory. This location can be overridden by // specifying a second parameter to the constructor. - val hiveContext = new LocalHiveContext(sc) + val hiveContext = new HiveContext(sc) import hiveContext._ hql("CREATE TABLE IF NOT EXISTS src (key INT, value STRING)") diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py index 13f0ed4e35490..9388ead5eaad3 100644 --- a/python/pyspark/sql.py +++ b/python/pyspark/sql.py @@ -15,6 +15,8 @@ # limitations under the License. # +import warnings + from pyspark.rdd import RDD, PipelinedRDD from pyspark.serializers import BatchedSerializer, PickleSerializer @@ -813,6 +815,10 @@ class LocalHiveContext(HiveContext): 130091 """ + def __init__(self, sparkContext, sqlContext=None): + HiveContext.__init__(self, sparkContext, sqlContext) + warnings.warn("LocalHiveContext is deprecated. Use HiveContext instead.", DeprecationWarning) + def _get_hive_ctx(self): return self._jvm.LocalHiveContext(self._jsc.sc()) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index b413373345eea..27b444daba2d4 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -42,9 +42,12 @@ import org.apache.spark.sql.execution.{Command => PhysicalCommand} import org.apache.spark.sql.hive.execution.DescribeHiveTableCommand /** - * Starts up an instance of hive where metadata is stored locally. An in-process metadata data is - * created with data stored in ./metadata. Warehouse data is stored in in ./warehouse. + * DEPRECATED: Use HiveContext instead. */ +@deprecated(""" + Use HiveContext instead. It will still create a local metastore if one is not specified. + However, note that the default directory is ./metastore_db, not ./metastore + """, "1.1") class LocalHiveContext(sc: SparkContext) extends HiveContext(sc) { lazy val metastorePath = new File("metastore").getCanonicalPath diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala index 9386008d02d51..c50e8c4b5c5d3 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala @@ -53,15 +53,24 @@ object TestHive * hive metastore seems to lead to weird non-deterministic failures. Therefore, the execution of * test cases that rely on TestHive must be serialized. */ -class TestHiveContext(sc: SparkContext) extends LocalHiveContext(sc) { +class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { self => // By clearing the port we force Spark to pick a new one. This allows us to rerun tests // without restarting the JVM. System.clearProperty("spark.hostPort") - override lazy val warehousePath = getTempFilePath("sparkHiveWarehouse").getCanonicalPath - override lazy val metastorePath = getTempFilePath("sparkHiveMetastore").getCanonicalPath + lazy val warehousePath = getTempFilePath("sparkHiveWarehouse").getCanonicalPath + lazy val metastorePath = getTempFilePath("sparkHiveMetastore").getCanonicalPath + + /** Sets up the system initially or after a RESET command */ + protected def configure() { + set("javax.jdo.option.ConnectionURL", + s"jdbc:derby:;databaseName=$metastorePath;create=true") + set("hive.metastore.warehouse.dir", warehousePath) + } + + configure() // Must be called before initializing the catalog below. /** The location of the compiled hive distribution */ lazy val hiveHome = envVarToFile("HIVE_HOME") From f1933123525e7c806f5fc0b0a46a78a7546f8b61 Mon Sep 17 00:00:00 2001 From: Aaron Davidson Date: Thu, 31 Jul 2014 11:35:38 -0700 Subject: [PATCH 057/170] SPARK-2028: Expose mapPartitionsWithInputSplit in HadoopRDD This allows users to gain access to the InputSplit which backs each partition. An alternative solution would have been to have a .withInputSplit() method which returns a new RDD[(InputSplit, (K, V))], but this is confusing because you could not cache this RDD or shuffle it, as InputSplit is not inherently serializable. Author: Aaron Davidson Closes #973 from aarondav/hadoop and squashes the following commits: 9c9112b [Aaron Davidson] Add JavaAPISuite test 9942cd7 [Aaron Davidson] Add Java API 1284a3a [Aaron Davidson] SPARK-2028: Expose mapPartitionsWithInputSplit in HadoopRDD --- .../apache/spark/api/java/JavaHadoopRDD.scala | 43 +++++++++++++++++++ .../spark/api/java/JavaNewHadoopRDD.scala | 43 +++++++++++++++++++ .../spark/api/java/JavaSparkContext.scala | 21 +++++---- .../org/apache/spark/rdd/HadoopRDD.scala | 32 ++++++++++++++ .../org/apache/spark/rdd/NewHadoopRDD.scala | 34 +++++++++++++++ .../java/org/apache/spark/JavaAPISuite.java | 26 ++++++++++- .../scala/org/apache/spark/FileSuite.scala | 34 ++++++++++++++- 7 files changed, 222 insertions(+), 11 deletions(-) create mode 100644 core/src/main/scala/org/apache/spark/api/java/JavaHadoopRDD.scala create mode 100644 core/src/main/scala/org/apache/spark/api/java/JavaNewHadoopRDD.scala diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaHadoopRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaHadoopRDD.scala new file mode 100644 index 0000000000000..0ae0b4ec042e2 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/api/java/JavaHadoopRDD.scala @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.java + +import scala.collection.JavaConversions._ +import scala.reflect.ClassTag + +import org.apache.hadoop.mapred.InputSplit + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.api.java.JavaSparkContext._ +import org.apache.spark.api.java.function.{Function2 => JFunction2} +import org.apache.spark.rdd.HadoopRDD + +@DeveloperApi +class JavaHadoopRDD[K, V](rdd: HadoopRDD[K, V]) + (implicit override val kClassTag: ClassTag[K], implicit override val vClassTag: ClassTag[V]) + extends JavaPairRDD[K, V](rdd) { + + /** Maps over a partition, providing the InputSplit that was used as the base of the partition. */ + @DeveloperApi + def mapPartitionsWithInputSplit[R]( + f: JFunction2[InputSplit, java.util.Iterator[(K, V)], java.util.Iterator[R]], + preservesPartitioning: Boolean = false): JavaRDD[R] = { + new JavaRDD(rdd.mapPartitionsWithInputSplit((a, b) => f.call(a, asJavaIterator(b)), + preservesPartitioning)(fakeClassTag))(fakeClassTag) + } +} diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaNewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaNewHadoopRDD.scala new file mode 100644 index 0000000000000..ec4f3964d75e0 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/api/java/JavaNewHadoopRDD.scala @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.java + +import scala.collection.JavaConversions._ +import scala.reflect.ClassTag + +import org.apache.hadoop.mapreduce.InputSplit + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.api.java.JavaSparkContext._ +import org.apache.spark.api.java.function.{Function2 => JFunction2} +import org.apache.spark.rdd.NewHadoopRDD + +@DeveloperApi +class JavaNewHadoopRDD[K, V](rdd: NewHadoopRDD[K, V]) + (implicit override val kClassTag: ClassTag[K], implicit override val vClassTag: ClassTag[V]) + extends JavaPairRDD[K, V](rdd) { + + /** Maps over a partition, providing the InputSplit that was used as the base of the partition. */ + @DeveloperApi + def mapPartitionsWithInputSplit[R]( + f: JFunction2[InputSplit, java.util.Iterator[(K, V)], java.util.Iterator[R]], + preservesPartitioning: Boolean = false): JavaRDD[R] = { + new JavaRDD(rdd.mapPartitionsWithInputSplit((a, b) => f.call(a, asJavaIterator(b)), + preservesPartitioning)(fakeClassTag))(fakeClassTag) + } +} diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala index 8a5f8088a05ca..d9d1c5955ca99 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala @@ -34,7 +34,7 @@ import org.apache.spark._ import org.apache.spark.SparkContext.{DoubleAccumulatorParam, IntAccumulatorParam} import org.apache.spark.api.java.JavaSparkContext.fakeClassTag import org.apache.spark.broadcast.Broadcast -import org.apache.spark.rdd.{EmptyRDD, RDD} +import org.apache.spark.rdd.{EmptyRDD, HadoopRDD, NewHadoopRDD, RDD} /** * A Java-friendly version of [[org.apache.spark.SparkContext]] that returns @@ -294,7 +294,8 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork ): JavaPairRDD[K, V] = { implicit val ctagK: ClassTag[K] = ClassTag(keyClass) implicit val ctagV: ClassTag[V] = ClassTag(valueClass) - new JavaPairRDD(sc.hadoopRDD(conf, inputFormatClass, keyClass, valueClass, minPartitions)) + val rdd = sc.hadoopRDD(conf, inputFormatClass, keyClass, valueClass, minPartitions) + new JavaHadoopRDD(rdd.asInstanceOf[HadoopRDD[K, V]]) } /** @@ -314,7 +315,8 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork ): JavaPairRDD[K, V] = { implicit val ctagK: ClassTag[K] = ClassTag(keyClass) implicit val ctagV: ClassTag[V] = ClassTag(valueClass) - new JavaPairRDD(sc.hadoopRDD(conf, inputFormatClass, keyClass, valueClass)) + val rdd = sc.hadoopRDD(conf, inputFormatClass, keyClass, valueClass) + new JavaHadoopRDD(rdd.asInstanceOf[HadoopRDD[K, V]]) } /** Get an RDD for a Hadoop file with an arbitrary InputFormat. @@ -333,7 +335,8 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork ): JavaPairRDD[K, V] = { implicit val ctagK: ClassTag[K] = ClassTag(keyClass) implicit val ctagV: ClassTag[V] = ClassTag(valueClass) - new JavaPairRDD(sc.hadoopFile(path, inputFormatClass, keyClass, valueClass, minPartitions)) + val rdd = sc.hadoopFile(path, inputFormatClass, keyClass, valueClass, minPartitions) + new JavaHadoopRDD(rdd.asInstanceOf[HadoopRDD[K, V]]) } /** Get an RDD for a Hadoop file with an arbitrary InputFormat @@ -351,8 +354,8 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork ): JavaPairRDD[K, V] = { implicit val ctagK: ClassTag[K] = ClassTag(keyClass) implicit val ctagV: ClassTag[V] = ClassTag(valueClass) - new JavaPairRDD(sc.hadoopFile(path, - inputFormatClass, keyClass, valueClass)) + val rdd = sc.hadoopFile(path, inputFormatClass, keyClass, valueClass) + new JavaHadoopRDD(rdd.asInstanceOf[HadoopRDD[K, V]]) } /** @@ -372,7 +375,8 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork conf: Configuration): JavaPairRDD[K, V] = { implicit val ctagK: ClassTag[K] = ClassTag(kClass) implicit val ctagV: ClassTag[V] = ClassTag(vClass) - new JavaPairRDD(sc.newAPIHadoopFile(path, fClass, kClass, vClass, conf)) + val rdd = sc.newAPIHadoopFile(path, fClass, kClass, vClass, conf) + new JavaNewHadoopRDD(rdd.asInstanceOf[NewHadoopRDD[K, V]]) } /** @@ -391,7 +395,8 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork vClass: Class[V]): JavaPairRDD[K, V] = { implicit val ctagK: ClassTag[K] = ClassTag(kClass) implicit val ctagV: ClassTag[V] = ClassTag(vClass) - new JavaPairRDD(sc.newAPIHadoopRDD(conf, fClass, kClass, vClass)) + val rdd = sc.newAPIHadoopRDD(conf, fClass, kClass, vClass) + new JavaNewHadoopRDD(rdd.asInstanceOf[NewHadoopRDD[K, V]]) } /** Build the union of two or more RDDs. */ diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala index e521612ffc27c..8d92ea01d9a3f 100644 --- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala @@ -20,7 +20,9 @@ package org.apache.spark.rdd import java.text.SimpleDateFormat import java.util.Date import java.io.EOFException + import scala.collection.immutable.Map +import scala.reflect.ClassTag import org.apache.hadoop.conf.{Configurable, Configuration} import org.apache.hadoop.mapred.FileSplit @@ -39,6 +41,7 @@ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.broadcast.Broadcast import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.executor.{DataReadMethod, InputMetrics} +import org.apache.spark.rdd.HadoopRDD.HadoopMapPartitionsWithSplitRDD import org.apache.spark.util.NextIterator /** @@ -232,6 +235,14 @@ class HadoopRDD[K, V]( new InterruptibleIterator[(K, V)](context, iter) } + /** Maps over a partition, providing the InputSplit that was used as the base of the partition. */ + @DeveloperApi + def mapPartitionsWithInputSplit[U: ClassTag]( + f: (InputSplit, Iterator[(K, V)]) => Iterator[U], + preservesPartitioning: Boolean = false): RDD[U] = { + new HadoopMapPartitionsWithSplitRDD(this, f, preservesPartitioning) + } + override def getPreferredLocations(split: Partition): Seq[String] = { // TODO: Filtering out "localhost" in case of file:// URLs val hadoopSplit = split.asInstanceOf[HadoopPartition] @@ -272,4 +283,25 @@ private[spark] object HadoopRDD { conf.setInt("mapred.task.partition", splitId) conf.set("mapred.job.id", jobID.toString) } + + /** + * Analogous to [[org.apache.spark.rdd.MapPartitionsRDD]], but passes in an InputSplit to + * the given function rather than the index of the partition. + */ + private[spark] class HadoopMapPartitionsWithSplitRDD[U: ClassTag, T: ClassTag]( + prev: RDD[T], + f: (InputSplit, Iterator[T]) => Iterator[U], + preservesPartitioning: Boolean = false) + extends RDD[U](prev) { + + override val partitioner = if (preservesPartitioning) firstParent[T].partitioner else None + + override def getPartitions: Array[Partition] = firstParent[T].partitions + + override def compute(split: Partition, context: TaskContext) = { + val partition = split.asInstanceOf[HadoopPartition] + val inputSplit = partition.inputSplit.value + f(inputSplit, firstParent[T].iterator(split, context)) + } + } } diff --git a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala index f2b3a64bf1345..7dfec9a18ec67 100644 --- a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala @@ -20,6 +20,8 @@ package org.apache.spark.rdd import java.text.SimpleDateFormat import java.util.Date +import scala.reflect.ClassTag + import org.apache.hadoop.conf.{Configurable, Configuration} import org.apache.hadoop.io.Writable import org.apache.hadoop.mapreduce._ @@ -32,6 +34,7 @@ import org.apache.spark.Partition import org.apache.spark.SerializableWritable import org.apache.spark.{SparkContext, TaskContext} import org.apache.spark.executor.{DataReadMethod, InputMetrics} +import org.apache.spark.rdd.NewHadoopRDD.NewHadoopMapPartitionsWithSplitRDD private[spark] class NewHadoopPartition( rddId: Int, @@ -157,6 +160,14 @@ class NewHadoopRDD[K, V]( new InterruptibleIterator(context, iter) } + /** Maps over a partition, providing the InputSplit that was used as the base of the partition. */ + @DeveloperApi + def mapPartitionsWithInputSplit[U: ClassTag]( + f: (InputSplit, Iterator[(K, V)]) => Iterator[U], + preservesPartitioning: Boolean = false): RDD[U] = { + new NewHadoopMapPartitionsWithSplitRDD(this, f, preservesPartitioning) + } + override def getPreferredLocations(split: Partition): Seq[String] = { val theSplit = split.asInstanceOf[NewHadoopPartition] theSplit.serializableHadoopSplit.value.getLocations.filter(_ != "localhost") @@ -165,6 +176,29 @@ class NewHadoopRDD[K, V]( def getConf: Configuration = confBroadcast.value.value } +private[spark] object NewHadoopRDD { + /** + * Analogous to [[org.apache.spark.rdd.MapPartitionsRDD]], but passes in an InputSplit to + * the given function rather than the index of the partition. + */ + private[spark] class NewHadoopMapPartitionsWithSplitRDD[U: ClassTag, T: ClassTag]( + prev: RDD[T], + f: (InputSplit, Iterator[T]) => Iterator[U], + preservesPartitioning: Boolean = false) + extends RDD[U](prev) { + + override val partitioner = if (preservesPartitioning) firstParent[T].partitioner else None + + override def getPartitions: Array[Partition] = firstParent[T].partitions + + override def compute(split: Partition, context: TaskContext) = { + val partition = split.asInstanceOf[NewHadoopPartition] + val inputSplit = partition.serializableHadoopSplit.value + f(inputSplit, firstParent[T].iterator(split, context)) + } + } +} + private[spark] class WholeTextFileRDD( sc : SparkContext, inputFormatClass: Class[_ <: WholeTextFileInputFormat], diff --git a/core/src/test/java/org/apache/spark/JavaAPISuite.java b/core/src/test/java/org/apache/spark/JavaAPISuite.java index fab64a54e2479..56150caa5d6ba 100644 --- a/core/src/test/java/org/apache/spark/JavaAPISuite.java +++ b/core/src/test/java/org/apache/spark/JavaAPISuite.java @@ -25,19 +25,23 @@ import scala.Tuple3; import scala.Tuple4; - import com.google.common.collect.Iterables; import com.google.common.collect.Iterators; import com.google.common.collect.Lists; import com.google.common.collect.Maps; +import com.google.common.collect.Sets; import com.google.common.base.Optional; import com.google.common.base.Charsets; import com.google.common.io.Files; import org.apache.hadoop.io.IntWritable; +import org.apache.hadoop.io.LongWritable; import org.apache.hadoop.io.Text; import org.apache.hadoop.io.compress.DefaultCodec; +import org.apache.hadoop.mapred.FileSplit; +import org.apache.hadoop.mapred.InputSplit; import org.apache.hadoop.mapred.SequenceFileInputFormat; import org.apache.hadoop.mapred.SequenceFileOutputFormat; +import org.apache.hadoop.mapred.TextInputFormat; import org.apache.hadoop.mapreduce.Job; import org.junit.After; import org.junit.Assert; @@ -45,6 +49,7 @@ import org.junit.Test; import org.apache.spark.api.java.JavaDoubleRDD; +import org.apache.spark.api.java.JavaHadoopRDD; import org.apache.spark.api.java.JavaPairRDD; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; @@ -1262,4 +1267,23 @@ public void collectUnderlyingScalaRDD() { SomeCustomClass[] collected = (SomeCustomClass[]) rdd.rdd().retag(SomeCustomClass.class).collect(); Assert.assertEquals(data.size(), collected.length); } + + public void getHadoopInputSplits() { + String outDir = new File(tempDir, "output").getAbsolutePath(); + sc.parallelize(Arrays.asList(1, 2, 3, 4, 5), 2).saveAsTextFile(outDir); + + JavaHadoopRDD hadoopRDD = (JavaHadoopRDD) + sc.hadoopFile(outDir, TextInputFormat.class, LongWritable.class, Text.class); + List inputPaths = hadoopRDD.mapPartitionsWithInputSplit( + new Function2>, Iterator>() { + @Override + public Iterator call(InputSplit split, Iterator> it) + throws Exception { + FileSplit fileSplit = (FileSplit) split; + return Lists.newArrayList(fileSplit.getPath().toUri().getPath()).iterator(); + } + }, true).collect(); + Assert.assertEquals(Sets.newHashSet(inputPaths), + Sets.newHashSet(outDir + "/part-00000", outDir + "/part-00001")); + } } diff --git a/core/src/test/scala/org/apache/spark/FileSuite.scala b/core/src/test/scala/org/apache/spark/FileSuite.scala index c70e22cf09433..4a53d25012ad9 100644 --- a/core/src/test/scala/org/apache/spark/FileSuite.scala +++ b/core/src/test/scala/org/apache/spark/FileSuite.scala @@ -24,12 +24,14 @@ import scala.io.Source import com.google.common.io.Files import org.apache.hadoop.io._ import org.apache.hadoop.io.compress.DefaultCodec -import org.apache.hadoop.mapred.{JobConf, FileAlreadyExistsException, TextOutputFormat} -import org.apache.hadoop.mapreduce.lib.output.{TextOutputFormat => NewTextOutputFormat} +import org.apache.hadoop.mapred.{JobConf, FileAlreadyExistsException, FileSplit, TextInputFormat, TextOutputFormat} import org.apache.hadoop.mapreduce.Job +import org.apache.hadoop.mapreduce.lib.input.{FileSplit => NewFileSplit, TextInputFormat => NewTextInputFormat} +import org.apache.hadoop.mapreduce.lib.output.{TextOutputFormat => NewTextOutputFormat} import org.scalatest.FunSuite import org.apache.spark.SparkContext._ +import org.apache.spark.rdd.{NewHadoopRDD, HadoopRDD} import org.apache.spark.util.Utils class FileSuite extends FunSuite with LocalSparkContext { @@ -318,4 +320,32 @@ class FileSuite extends FunSuite with LocalSparkContext { randomRDD.saveAsNewAPIHadoopDataset(job.getConfiguration) assert(new File(tempDir.getPath + "/outputDataset_new/part-r-00000").exists() === true) } + + test("Get input files via old Hadoop API") { + sc = new SparkContext("local", "test") + val outDir = new File(tempDir, "output").getAbsolutePath + sc.makeRDD(1 to 4, 2).saveAsTextFile(outDir) + + val inputPaths = + sc.hadoopFile(outDir, classOf[TextInputFormat], classOf[LongWritable], classOf[Text]) + .asInstanceOf[HadoopRDD[_, _]] + .mapPartitionsWithInputSplit { (split, part) => + Iterator(split.asInstanceOf[FileSplit].getPath.toUri.getPath) + }.collect() + assert(inputPaths.toSet === Set(s"$outDir/part-00000", s"$outDir/part-00001")) + } + + test("Get input files via new Hadoop API") { + sc = new SparkContext("local", "test") + val outDir = new File(tempDir, "output").getAbsolutePath + sc.makeRDD(1 to 4, 2).saveAsTextFile(outDir) + + val inputPaths = + sc.newAPIHadoopFile(outDir, classOf[NewTextInputFormat], classOf[LongWritable], classOf[Text]) + .asInstanceOf[NewHadoopRDD[_, _]] + .mapPartitionsWithInputSplit { (split, part) => + Iterator(split.asInstanceOf[NewFileSplit].getPath.toUri.getPath) + }.collect() + assert(inputPaths.toSet === Set(s"$outDir/part-00000", s"$outDir/part-00001")) + } } From f68105df52902a1c65207d4f51bfdeb55cccf767 Mon Sep 17 00:00:00 2001 From: Sandy Ryza Date: Thu, 31 Jul 2014 11:51:20 -0700 Subject: [PATCH 058/170] SPARK-2664. Deal with `--conf` options in spark-submit that relate to fl... ...ags Author: Sandy Ryza Closes #1665 from sryza/sandy-spark-2664 and squashes the following commits: 0518c63 [Sandy Ryza] SPARK-2664. Deal with `--conf` options in spark-submit that relate to flags --- .../org/apache/spark/deploy/SparkSubmit.scala | 11 +++++--- .../spark/deploy/SparkSubmitArguments.scala | 26 +++++++++++-------- .../spark/deploy/SparkSubmitSuite.scala | 16 ++++++++++++ 3 files changed, 38 insertions(+), 15 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index 3df811c4ac5df..318509a67a36f 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -184,7 +184,7 @@ object SparkSubmit { OptionAssigner(args.archives, YARN, CLIENT, sysProp = "spark.yarn.dist.archives"), // Yarn cluster only - OptionAssigner(args.name, YARN, CLUSTER, clOption = "--name", sysProp = "spark.app.name"), + OptionAssigner(args.name, YARN, CLUSTER, clOption = "--name"), OptionAssigner(args.driverMemory, YARN, CLUSTER, clOption = "--driver-memory"), OptionAssigner(args.queue, YARN, CLUSTER, clOption = "--queue"), OptionAssigner(args.numExecutors, YARN, CLUSTER, clOption = "--num-executors"), @@ -268,14 +268,17 @@ object SparkSubmit { } } + // Properties given with --conf are superceded by other options, but take precedence over + // properties in the defaults file. + for ((k, v) <- args.sparkProperties) { + sysProps.getOrElseUpdate(k, v) + } + // Read from default spark properties, if any for ((k, v) <- args.getDefaultSparkProperties) { sysProps.getOrElseUpdate(k, v) } - // Spark properties included on command line take precedence - sysProps ++= args.sparkProperties - (childArgs, childClasspath, sysProps, childMainClass) } diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala index 01d0ae541a66b..dd044e6298760 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala @@ -58,7 +58,7 @@ private[spark] class SparkSubmitArguments(args: Seq[String]) { val sparkProperties: HashMap[String, String] = new HashMap[String, String]() parseOpts(args.toList) - loadDefaults() + mergeSparkProperties() checkRequiredArguments() /** Return default present in the currently defined defaults file. */ @@ -79,9 +79,11 @@ private[spark] class SparkSubmitArguments(args: Seq[String]) { defaultProperties } - /** Fill in any undefined values based on the current properties file or built-in defaults. */ - private def loadDefaults(): Unit = { - + /** + * Fill in any undefined values based on the default properties file or options passed in through + * the '--conf' flag. + */ + private def mergeSparkProperties(): Unit = { // Use common defaults file, if not specified by user if (propertiesFile == null) { sys.env.get("SPARK_HOME").foreach { sparkHome => @@ -94,18 +96,20 @@ private[spark] class SparkSubmitArguments(args: Seq[String]) { } } - val defaultProperties = getDefaultSparkProperties + val properties = getDefaultSparkProperties + properties.putAll(sparkProperties) + // Use properties file as fallback for values which have a direct analog to // arguments in this script. - master = Option(master).getOrElse(defaultProperties.get("spark.master").orNull) + master = Option(master).getOrElse(properties.get("spark.master").orNull) executorMemory = Option(executorMemory) - .getOrElse(defaultProperties.get("spark.executor.memory").orNull) + .getOrElse(properties.get("spark.executor.memory").orNull) executorCores = Option(executorCores) - .getOrElse(defaultProperties.get("spark.executor.cores").orNull) + .getOrElse(properties.get("spark.executor.cores").orNull) totalExecutorCores = Option(totalExecutorCores) - .getOrElse(defaultProperties.get("spark.cores.max").orNull) - name = Option(name).getOrElse(defaultProperties.get("spark.app.name").orNull) - jars = Option(jars).getOrElse(defaultProperties.get("spark.jars").orNull) + .getOrElse(properties.get("spark.cores.max").orNull) + name = Option(name).getOrElse(properties.get("spark.app.name").orNull) + jars = Option(jars).getOrElse(properties.get("spark.jars").orNull) // This supports env vars in older versions of Spark master = Option(master).getOrElse(System.getenv("MASTER")) diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala index a301cbd48a0c3..9190b05e2dba2 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala @@ -253,6 +253,22 @@ class SparkSubmitSuite extends FunSuite with Matchers { sysProps("spark.shuffle.spill") should be ("false") } + test("handles confs with flag equivalents") { + val clArgs = Seq( + "--deploy-mode", "cluster", + "--executor-memory", "5g", + "--class", "org.SomeClass", + "--conf", "spark.executor.memory=4g", + "--conf", "spark.master=yarn", + "thejar.jar", + "arg1", "arg2") + val appArgs = new SparkSubmitArguments(clArgs) + val (_, _, sysProps, mainClass) = createLaunchEnv(appArgs) + sysProps("spark.executor.memory") should be ("5g") + sysProps("spark.master") should be ("yarn-cluster") + mainClass should be ("org.apache.spark.deploy.yarn.Client") + } + test("launch simple application with spark-submit") { val unusedJar = TestUtils.createJarWithClasses(Seq.empty) val args = Seq( From 4dbabb39a7bf248ac4f9b7f5eb2fe69e5047dcb3 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Thu, 31 Jul 2014 12:18:40 -0700 Subject: [PATCH 059/170] SPARK-2749 [BUILD] Part 2. Fix a follow-on scalastyle error The test compile error is fixed, but the build still fails because of one scalastyle error. https://amplab.cs.berkeley.edu/jenkins/view/Spark/job/Spark-Master-Maven-pre-YARN/lastFailedBuild/hadoop.version=1.0.4,label=centos/console Author: Sean Owen Closes #1690 from srowen/SPARK-2749 and squashes the following commits: 1c9e7a6 [Sean Owen] Also: fix scalastyle error by wrapping a long line --- .../scala/org/apache/spark/tools/GenerateMIMAIgnore.scala | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tools/src/main/scala/org/apache/spark/tools/GenerateMIMAIgnore.scala b/tools/src/main/scala/org/apache/spark/tools/GenerateMIMAIgnore.scala index 16ff89a8a9809..bcf6d43ab34eb 100644 --- a/tools/src/main/scala/org/apache/spark/tools/GenerateMIMAIgnore.scala +++ b/tools/src/main/scala/org/apache/spark/tools/GenerateMIMAIgnore.scala @@ -114,9 +114,10 @@ object GenerateMIMAIgnore { private def getAnnotatedOrPackagePrivateMembers(classSymbol: unv.ClassSymbol) = { classSymbol.typeSignature.members.filterNot(x => - x.fullName.startsWith("java") || x.fullName.startsWith("scala")) - .filter(x => isPackagePrivate(x) || isDeveloperApi(x) || isExperimental(x)).map(_.fullName) ++ - getInnerFunctions(classSymbol) + x.fullName.startsWith("java") || x.fullName.startsWith("scala") + ).filter(x => + isPackagePrivate(x) || isDeveloperApi(x) || isExperimental(x) + ).map(_.fullName) ++ getInnerFunctions(classSymbol) } def main(args: Array[String]) { From e5749a1342327263dc6b94ba470e392fbea703fa Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Thu, 31 Jul 2014 12:26:36 -0700 Subject: [PATCH 060/170] SPARK-2646. log4j initialization not quite compatible with log4j 2.x The logging code that handles log4j initialization leads to an stack overflow error when used with log4j 2.x, which has just been released. This occurs even a downstream project has correctly adjusted SLF4J bindings, and that is the right thing to do for log4j 2.x, since it is effectively a separate project from 1.x. Here is the relevant bit of Logging.scala: ``` private def initializeLogging() { // If Log4j is being used, but is not initialized, load a default properties file val binder = StaticLoggerBinder.getSingleton val usingLog4j = binder.getLoggerFactoryClassStr.endsWith("Log4jLoggerFactory") val log4jInitialized = LogManager.getRootLogger.getAllAppenders.hasMoreElements if (!log4jInitialized && usingLog4j) { val defaultLogProps = "org/apache/spark/log4j-defaults.properties" Option(Utils.getSparkClassLoader.getResource(defaultLogProps)) match { case Some(url) => PropertyConfigurator.configure(url) log.info(s"Using Spark's default log4j profile: $defaultLogProps") case None => System.err.println(s"Spark was unable to load $defaultLogProps") } } Logging.initialized = true // Force a call into slf4j to initialize it. Avoids this happening from mutliple threads // and triggering this: http://mailman.qos.ch/pipermail/slf4j-dev/2010-April/002956.html log } ``` The first minor issue is that there is a call to a logger inside this method, which is initializing logging. In this situation, it ends up causing the initialization to be called recursively until the stack overflow. It would be slightly tidier to log this only after Logging.initialized = true. Or not at all. But it's not the root problem, or else, it would not work at all now. The calls to log4j classes here always reference log4j 1.2 no matter what. For example, there is not getAllAppenders in log4j 2.x. That's fine. Really, "usingLog4j" means "using log4j 1.2" and "log4jInitialized" means "log4j 1.2 is initialized". usingLog4j should be false for log4j 2.x, because the initialization only matters for log4j 1.2. But, it's true, and that's the real issue. And log4jInitialized is always false, since calls to the log4j 1.2 API are stubs and no-ops in this setup, where the caller has swapped in log4j 2.x. Hence the loop. This is fixed, I believe, if "usingLog4j" can be false for log4j 2.x. The SLF4J static binding class has the same name for both versions, unfortunately, which causes the issue. However they're in different packages. For example, if the test included "... and begins with org.slf4j", it should work, as the SLF4J binding for log4j 2.x is provided by log4j 2.x at the moment, and is in package org.apache.logging.slf4j. Of course, I assume that SLF4J will eventually offer its own binding. I hope to goodness they at least name the binding class differently, or else this will again not work. But then some other check can probably be made. Author: Sean Owen Closes #1547 from srowen/SPARK-2646 and squashes the following commits: 92a9898 [Sean Owen] System.out -> System.err 94be4c7 [Sean Owen] Add back log message as System.out, with informational comment a7f8876 [Sean Owen] Updates from review 6f3c1d3 [Sean Owen] Remove log statement in logging initialization, and distinguish log4j 1.2 from 2.0, to avoid stack overflow in initialization --- .../main/scala/org/apache/spark/Logging.scala | 23 ++++++++++--------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/Logging.scala b/core/src/main/scala/org/apache/spark/Logging.scala index 50d8e93e1f0d7..807ef3e9c9d60 100644 --- a/core/src/main/scala/org/apache/spark/Logging.scala +++ b/core/src/main/scala/org/apache/spark/Logging.scala @@ -45,10 +45,7 @@ trait Logging { initializeIfNecessary() var className = this.getClass.getName // Ignore trailing $'s in the class names for Scala objects - if (className.endsWith("$")) { - className = className.substring(0, className.length - 1) - } - log_ = LoggerFactory.getLogger(className) + log_ = LoggerFactory.getLogger(className.stripSuffix("$")) } log_ } @@ -110,23 +107,27 @@ trait Logging { } private def initializeLogging() { - // If Log4j is being used, but is not initialized, load a default properties file - val binder = StaticLoggerBinder.getSingleton - val usingLog4j = binder.getLoggerFactoryClassStr.endsWith("Log4jLoggerFactory") - val log4jInitialized = LogManager.getRootLogger.getAllAppenders.hasMoreElements - if (!log4jInitialized && usingLog4j) { + // Don't use a logger in here, as this is itself occurring during initialization of a logger + // If Log4j 1.2 is being used, but is not initialized, load a default properties file + val binderClass = StaticLoggerBinder.getSingleton.getLoggerFactoryClassStr + // This distinguishes the log4j 1.2 binding, currently + // org.slf4j.impl.Log4jLoggerFactory, from the log4j 2.0 binding, currently + // org.apache.logging.slf4j.Log4jLoggerFactory + val usingLog4j12 = "org.slf4j.impl.Log4jLoggerFactory".equals(binderClass) + val log4j12Initialized = LogManager.getRootLogger.getAllAppenders.hasMoreElements + if (!log4j12Initialized && usingLog4j12) { val defaultLogProps = "org/apache/spark/log4j-defaults.properties" Option(Utils.getSparkClassLoader.getResource(defaultLogProps)) match { case Some(url) => PropertyConfigurator.configure(url) - log.info(s"Using Spark's default log4j profile: $defaultLogProps") + System.err.println(s"Using Spark's default log4j profile: $defaultLogProps") case None => System.err.println(s"Spark was unable to load $defaultLogProps") } } Logging.initialized = true - // Force a call into slf4j to initialize it. Avoids this happening from mutliple threads + // Force a call into slf4j to initialize it. Avoids this happening from multiple threads // and triggering this: http://mailman.qos.ch/pipermail/slf4j-dev/2010-April/002956.html log } From dc0865bc7e119fe507061c27069c17523b87dfea Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Thu, 31 Jul 2014 12:55:00 -0700 Subject: [PATCH 061/170] [SPARK-2511][MLLIB] add HashingTF and IDF This is roughly the TF-IDF implementation used in the Databricks Cloud Demo: http://databricks.com/cloud/ . Both `HashingTF` and `IDF` are implemented as transformers, similar to scikit-learn. Author: Xiangrui Meng Closes #1671 from mengxr/tfidf and squashes the following commits: 7d65888 [Xiangrui Meng] use JavaConverters._ 5fe9ec4 [Xiangrui Meng] fix unit test 6e214ec [Xiangrui Meng] add apache header cfd9aed [Xiangrui Meng] add Java-friendly methods move classes to mllib.feature 3814440 [Xiangrui Meng] add HashingTF and IDF --- .../spark/mllib/feature/HashingTF.scala | 79 +++++++ .../org/apache/spark/mllib/feature/IDF.scala | 194 ++++++++++++++++++ .../spark/mllib/feature/JavaTfIdfSuite.java | 66 ++++++ .../spark/mllib/feature/HashingTFSuite.scala | 52 +++++ .../apache/spark/mllib/feature/IDFSuite.scala | 63 ++++++ 5 files changed, 454 insertions(+) create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/feature/HashingTF.scala create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/feature/IDF.scala create mode 100644 mllib/src/test/java/org/apache/spark/mllib/feature/JavaTfIdfSuite.java create mode 100644 mllib/src/test/scala/org/apache/spark/mllib/feature/HashingTFSuite.scala create mode 100644 mllib/src/test/scala/org/apache/spark/mllib/feature/IDFSuite.scala diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/HashingTF.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/HashingTF.scala new file mode 100644 index 0000000000000..0f6d5809e098f --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/HashingTF.scala @@ -0,0 +1,79 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.feature + +import java.lang.{Iterable => JavaIterable} + +import scala.collection.JavaConverters._ +import scala.collection.mutable + +import org.apache.spark.annotation.Experimental +import org.apache.spark.api.java.JavaRDD +import org.apache.spark.mllib.linalg.{Vector, Vectors} +import org.apache.spark.rdd.RDD +import org.apache.spark.util.Utils + +/** + * :: Experimental :: + * Maps a sequence of terms to their term frequencies using the hashing trick. + * + * @param numFeatures number of features (default: 1000000) + */ +@Experimental +class HashingTF(val numFeatures: Int) extends Serializable { + + def this() = this(1000000) + + /** + * Returns the index of the input term. + */ + def indexOf(term: Any): Int = Utils.nonNegativeMod(term.##, numFeatures) + + /** + * Transforms the input document into a sparse term frequency vector. + */ + def transform(document: Iterable[_]): Vector = { + val termFrequencies = mutable.HashMap.empty[Int, Double] + document.foreach { term => + val i = indexOf(term) + termFrequencies.put(i, termFrequencies.getOrElse(i, 0.0) + 1.0) + } + Vectors.sparse(numFeatures, termFrequencies.toSeq) + } + + /** + * Transforms the input document into a sparse term frequency vector (Java version). + */ + def transform(document: JavaIterable[_]): Vector = { + transform(document.asScala) + } + + /** + * Transforms the input document to term frequency vectors. + */ + def transform[D <: Iterable[_]](dataset: RDD[D]): RDD[Vector] = { + dataset.map(this.transform) + } + + /** + * Transforms the input document to term frequency vectors (Java version). + */ + def transform[D <: JavaIterable[_]](dataset: JavaRDD[D]): JavaRDD[Vector] = { + dataset.rdd.map(this.transform).toJavaRDD() + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/IDF.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/IDF.scala new file mode 100644 index 0000000000000..7ed611a857acc --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/IDF.scala @@ -0,0 +1,194 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.feature + +import breeze.linalg.{DenseVector => BDV} + +import org.apache.spark.annotation.Experimental +import org.apache.spark.api.java.JavaRDD +import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors} +import org.apache.spark.mllib.rdd.RDDFunctions._ +import org.apache.spark.rdd.RDD + +/** + * :: Experimental :: + * Inverse document frequency (IDF). + * The standard formulation is used: `idf = log((m + 1) / (d(t) + 1))`, where `m` is the total + * number of documents and `d(t)` is the number of documents that contain term `t`. + */ +@Experimental +class IDF { + + // TODO: Allow different IDF formulations. + + private var brzIdf: BDV[Double] = _ + + /** + * Computes the inverse document frequency. + * @param dataset an RDD of term frequency vectors + */ + def fit(dataset: RDD[Vector]): this.type = { + brzIdf = dataset.treeAggregate(new IDF.DocumentFrequencyAggregator)( + seqOp = (df, v) => df.add(v), + combOp = (df1, df2) => df1.merge(df2) + ).idf() + this + } + + /** + * Computes the inverse document frequency. + * @param dataset a JavaRDD of term frequency vectors + */ + def fit(dataset: JavaRDD[Vector]): this.type = { + fit(dataset.rdd) + } + + /** + * Transforms term frequency (TF) vectors to TF-IDF vectors. + * @param dataset an RDD of term frequency vectors + * @return an RDD of TF-IDF vectors + */ + def transform(dataset: RDD[Vector]): RDD[Vector] = { + if (!initialized) { + throw new IllegalStateException("Haven't learned IDF yet. Call fit first.") + } + val theIdf = brzIdf + val bcIdf = dataset.context.broadcast(theIdf) + dataset.mapPartitions { iter => + val thisIdf = bcIdf.value + iter.map { v => + val n = v.size + v match { + case sv: SparseVector => + val nnz = sv.indices.size + val newValues = new Array[Double](nnz) + var k = 0 + while (k < nnz) { + newValues(k) = sv.values(k) * thisIdf(sv.indices(k)) + k += 1 + } + Vectors.sparse(n, sv.indices, newValues) + case dv: DenseVector => + val newValues = new Array[Double](n) + var j = 0 + while (j < n) { + newValues(j) = dv.values(j) * thisIdf(j) + j += 1 + } + Vectors.dense(newValues) + case other => + throw new UnsupportedOperationException( + s"Only sparse and dense vectors are supported but got ${other.getClass}.") + } + } + } + } + + /** + * Transforms term frequency (TF) vectors to TF-IDF vectors (Java version). + * @param dataset a JavaRDD of term frequency vectors + * @return a JavaRDD of TF-IDF vectors + */ + def transform(dataset: JavaRDD[Vector]): JavaRDD[Vector] = { + transform(dataset.rdd).toJavaRDD() + } + + /** Returns the IDF vector. */ + def idf(): Vector = { + if (!initialized) { + throw new IllegalStateException("Haven't learned IDF yet. Call fit first.") + } + Vectors.fromBreeze(brzIdf) + } + + private def initialized: Boolean = brzIdf != null +} + +private object IDF { + + /** Document frequency aggregator. */ + class DocumentFrequencyAggregator extends Serializable { + + /** number of documents */ + private var m = 0L + /** document frequency vector */ + private var df: BDV[Long] = _ + + /** Adds a new document. */ + def add(doc: Vector): this.type = { + if (isEmpty) { + df = BDV.zeros(doc.size) + } + doc match { + case sv: SparseVector => + val nnz = sv.indices.size + var k = 0 + while (k < nnz) { + if (sv.values(k) > 0) { + df(sv.indices(k)) += 1L + } + k += 1 + } + case dv: DenseVector => + val n = dv.size + var j = 0 + while (j < n) { + if (dv.values(j) > 0.0) { + df(j) += 1L + } + j += 1 + } + case other => + throw new UnsupportedOperationException( + s"Only sparse and dense vectors are supported but got ${other.getClass}.") + } + m += 1L + this + } + + /** Merges another. */ + def merge(other: DocumentFrequencyAggregator): this.type = { + if (!other.isEmpty) { + m += other.m + if (df == null) { + df = other.df.copy + } else { + df += other.df + } + } + this + } + + private def isEmpty: Boolean = m == 0L + + /** Returns the current IDF vector. */ + def idf(): BDV[Double] = { + if (isEmpty) { + throw new IllegalStateException("Haven't seen any document yet.") + } + val n = df.length + val inv = BDV.zeros[Double](n) + var j = 0 + while (j < n) { + inv(j) = math.log((m + 1.0)/ (df(j) + 1.0)) + j += 1 + } + inv + } + } +} diff --git a/mllib/src/test/java/org/apache/spark/mllib/feature/JavaTfIdfSuite.java b/mllib/src/test/java/org/apache/spark/mllib/feature/JavaTfIdfSuite.java new file mode 100644 index 0000000000000..e8d99f4ae43ae --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/mllib/feature/JavaTfIdfSuite.java @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.feature; + +import java.io.Serializable; +import java.util.ArrayList; +import java.util.List; + +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; +import com.google.common.collect.Lists; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.mllib.linalg.Vector; + +public class JavaTfIdfSuite implements Serializable { + private transient JavaSparkContext sc; + + @Before + public void setUp() { + sc = new JavaSparkContext("local", "JavaTfIdfSuite"); + } + + @After + public void tearDown() { + sc.stop(); + sc = null; + } + + @Test + public void tfIdf() { + // The tests are to check Java compatibility. + HashingTF tf = new HashingTF(); + JavaRDD> documents = sc.parallelize(Lists.newArrayList( + Lists.newArrayList("this is a sentence".split(" ")), + Lists.newArrayList("this is another sentence".split(" ")), + Lists.newArrayList("this is still a sentence".split(" "))), 2); + JavaRDD termFreqs = tf.transform(documents); + termFreqs.collect(); + IDF idf = new IDF(); + JavaRDD tfIdfs = idf.fit(termFreqs).transform(termFreqs); + List localTfIdfs = tfIdfs.collect(); + int indexOfThis = tf.indexOf("this"); + for (Vector v: localTfIdfs) { + Assert.assertEquals(0.0, v.apply(indexOfThis), 1e-15); + } + } +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/HashingTFSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/HashingTFSuite.scala new file mode 100644 index 0000000000000..a599e0d938569 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/HashingTFSuite.scala @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.feature + +import org.scalatest.FunSuite + +import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.util.LocalSparkContext + +class HashingTFSuite extends FunSuite with LocalSparkContext { + + test("hashing tf on a single doc") { + val hashingTF = new HashingTF(1000) + val doc = "a a b b c d".split(" ") + val n = hashingTF.numFeatures + val termFreqs = Seq( + (hashingTF.indexOf("a"), 2.0), + (hashingTF.indexOf("b"), 2.0), + (hashingTF.indexOf("c"), 1.0), + (hashingTF.indexOf("d"), 1.0)) + assert(termFreqs.map(_._1).forall(i => i >= 0 && i < n), + "index must be in range [0, #features)") + assert(termFreqs.map(_._1).toSet.size === 4, "expecting perfect hashing") + val expected = Vectors.sparse(n, termFreqs) + assert(hashingTF.transform(doc) === expected) + } + + test("hashing tf on an RDD") { + val hashingTF = new HashingTF + val localDocs: Seq[Seq[String]] = Seq( + "a a b b b c d".split(" "), + "a b c d a b c".split(" "), + "c b a c b a a".split(" ")) + val docs = sc.parallelize(localDocs, 2) + assert(hashingTF.transform(docs).collect().toSet === localDocs.map(hashingTF.transform).toSet) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/IDFSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/IDFSuite.scala new file mode 100644 index 0000000000000..78a2804ff204b --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/IDFSuite.scala @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.feature + +import org.scalatest.FunSuite + +import org.apache.spark.SparkContext._ +import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vectors} +import org.apache.spark.mllib.util.LocalSparkContext +import org.apache.spark.mllib.util.TestingUtils._ + +class IDFSuite extends FunSuite with LocalSparkContext { + + test("idf") { + val n = 4 + val localTermFrequencies = Seq( + Vectors.sparse(n, Array(1, 3), Array(1.0, 2.0)), + Vectors.dense(0.0, 1.0, 2.0, 3.0), + Vectors.sparse(n, Array(1), Array(1.0)) + ) + val m = localTermFrequencies.size + val termFrequencies = sc.parallelize(localTermFrequencies, 2) + val idf = new IDF + intercept[IllegalStateException] { + idf.idf() + } + intercept[IllegalStateException] { + idf.transform(termFrequencies) + } + idf.fit(termFrequencies) + val expected = Vectors.dense(Array(0, 3, 1, 2).map { x => + math.log((m.toDouble + 1.0) / (x + 1.0)) + }) + assert(idf.idf() ~== expected absTol 1e-12) + val tfidf = idf.transform(termFrequencies).cache().zipWithIndex().map(_.swap).collectAsMap() + assert(tfidf.size === 3) + val tfidf0 = tfidf(0L).asInstanceOf[SparseVector] + assert(tfidf0.indices === Array(1, 3)) + assert(Vectors.dense(tfidf0.values) ~== + Vectors.dense(1.0 * expected(1), 2.0 * expected(3)) absTol 1e-12) + val tfidf1 = tfidf(1L).asInstanceOf[DenseVector] + assert(Vectors.dense(tfidf1.values) ~== + Vectors.dense(0.0, 1.0 * expected(1), 2.0 * expected(2), 3.0 * expected(3)) absTol 1e-12) + val tfidf2 = tfidf(2L).asInstanceOf[SparseVector] + assert(tfidf2.indices === Array(1)) + assert(tfidf2.values(0) ~== (1.0 * expected(1)) absTol 1e-12) + } +} From 49b361298b09d415de1857846367913495aecfa6 Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Thu, 31 Jul 2014 13:05:24 -0700 Subject: [PATCH 062/170] [SPARK-2523] [SQL] Hadoop table scan bug fixing (fix failing Jenkins maven test) This PR tries to resolve the broken Jenkins maven test issue introduced by #1439. Now, we create a single query test to run both the setup work and the test query. Author: Yin Huai Closes #1669 from yhuai/SPARK-2523-fixTest and squashes the following commits: 358af1a [Yin Huai] Make partition_based_table_scan_with_different_serde run atomically. --- ...t_serde-0-1436cccda63b78dd6e43a399da6cc474 | 0 ...t_serde-1-8d9bf54373f45bc35f8cb6e82771b154 | 0 ...t_serde-2-7816c17905012cf381abf93d230faa8d | 0 ...t_serde-3-90089a6db3c3d8ee5ff5ea6b9153b3cc | 0 ..._serde-4-8caed2a6e80250a6d38a59388679c298} | 0 .../hive/execution/HiveTableScanSuite.scala | 45 ++++++++----------- 6 files changed, 19 insertions(+), 26 deletions(-) create mode 100644 sql/hive/src/test/resources/golden/partition_based_table_scan_with_different_serde-0-1436cccda63b78dd6e43a399da6cc474 create mode 100644 sql/hive/src/test/resources/golden/partition_based_table_scan_with_different_serde-1-8d9bf54373f45bc35f8cb6e82771b154 create mode 100644 sql/hive/src/test/resources/golden/partition_based_table_scan_with_different_serde-2-7816c17905012cf381abf93d230faa8d create mode 100644 sql/hive/src/test/resources/golden/partition_based_table_scan_with_different_serde-3-90089a6db3c3d8ee5ff5ea6b9153b3cc rename sql/hive/src/test/resources/golden/{partition_based_table_scan_with_different_serde-0-8caed2a6e80250a6d38a59388679c298 => partition_based_table_scan_with_different_serde-4-8caed2a6e80250a6d38a59388679c298} (100%) diff --git a/sql/hive/src/test/resources/golden/partition_based_table_scan_with_different_serde-0-1436cccda63b78dd6e43a399da6cc474 b/sql/hive/src/test/resources/golden/partition_based_table_scan_with_different_serde-0-1436cccda63b78dd6e43a399da6cc474 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/partition_based_table_scan_with_different_serde-1-8d9bf54373f45bc35f8cb6e82771b154 b/sql/hive/src/test/resources/golden/partition_based_table_scan_with_different_serde-1-8d9bf54373f45bc35f8cb6e82771b154 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/partition_based_table_scan_with_different_serde-2-7816c17905012cf381abf93d230faa8d b/sql/hive/src/test/resources/golden/partition_based_table_scan_with_different_serde-2-7816c17905012cf381abf93d230faa8d new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/partition_based_table_scan_with_different_serde-3-90089a6db3c3d8ee5ff5ea6b9153b3cc b/sql/hive/src/test/resources/golden/partition_based_table_scan_with_different_serde-3-90089a6db3c3d8ee5ff5ea6b9153b3cc new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/partition_based_table_scan_with_different_serde-0-8caed2a6e80250a6d38a59388679c298 b/sql/hive/src/test/resources/golden/partition_based_table_scan_with_different_serde-4-8caed2a6e80250a6d38a59388679c298 similarity index 100% rename from sql/hive/src/test/resources/golden/partition_based_table_scan_with_different_serde-0-8caed2a6e80250a6d38a59388679c298 rename to sql/hive/src/test/resources/golden/partition_based_table_scan_with_different_serde-4-8caed2a6e80250a6d38a59388679c298 diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala index bcb00f871d185..c5736723b47c0 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala @@ -17,32 +17,25 @@ package org.apache.spark.sql.hive.execution -import org.scalatest.{BeforeAndAfterAll, FunSuite} - -import org.apache.spark.{SparkConf, SparkContext} -import org.apache.spark.sql.hive.test.TestHive - class HiveTableScanSuite extends HiveComparisonTest { - // MINOR HACK: You must run a query before calling reset the first time. - TestHive.hql("SHOW TABLES") - TestHive.reset() - - TestHive.hql("""CREATE TABLE part_scan_test (key STRING, value STRING) PARTITIONED BY (ds STRING) - | ROW FORMAT SERDE - | 'org.apache.hadoop.hive.serde2.columnar.LazyBinaryColumnarSerDe' - | STORED AS RCFILE - """.stripMargin) - TestHive.hql("""FROM src - | INSERT INTO TABLE part_scan_test PARTITION (ds='2010-01-01') - | SELECT 100,100 LIMIT 1 - """.stripMargin) - TestHive.hql("""ALTER TABLE part_scan_test SET SERDE - | 'org.apache.hadoop.hive.serde2.columnar.ColumnarSerDe' - """.stripMargin) - TestHive.hql("""FROM src INSERT INTO TABLE part_scan_test PARTITION (ds='2010-01-02') - | SELECT 200,200 LIMIT 1 - """.stripMargin) - createQueryTest("partition_based_table_scan_with_different_serde", - "SELECT * from part_scan_test", false) + createQueryTest("partition_based_table_scan_with_different_serde", + """ + |CREATE TABLE part_scan_test (key STRING, value STRING) PARTITIONED BY (ds STRING) + |ROW FORMAT SERDE + |'org.apache.hadoop.hive.serde2.columnar.LazyBinaryColumnarSerDe' + |STORED AS RCFILE; + | + |FROM src + |INSERT INTO TABLE part_scan_test PARTITION (ds='2010-01-01') + |SELECT 100,100 LIMIT 1; + | + |ALTER TABLE part_scan_test SET SERDE + |'org.apache.hadoop.hive.serde2.columnar.ColumnarSerDe'; + | + |FROM src INSERT INTO TABLE part_scan_test PARTITION (ds='2010-01-02') + |SELECT 200,200 LIMIT 1; + | + |SELECT * from part_scan_test; + """.stripMargin) } From e02136214a6c2635e88c36b1f530a97e975d83e3 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Thu, 31 Jul 2014 14:35:09 -0700 Subject: [PATCH 063/170] Improvements to merge_spark_pr.py This commit fixes a couple of issues in the merge_spark_pr.py developer script: - Allow recovery from failed cherry-picks. - Fix detection of pull requests that have already been merged. Both of these fixes are useful when backporting changes. Author: Josh Rosen Closes #1668 from JoshRosen/pr-script-improvements and squashes the following commits: ff4f33a [Josh Rosen] Default SPARK_HOME to cwd(); detect missing JIRA credentials. ed5bc57 [Josh Rosen] Improvements for backporting using merge_spark_pr: --- dev/merge_spark_pr.py | 53 +++++++++++++++++++++++++++---------------- 1 file changed, 34 insertions(+), 19 deletions(-) diff --git a/dev/merge_spark_pr.py b/dev/merge_spark_pr.py index c44320239bbbf..53df9b5a3f1d5 100755 --- a/dev/merge_spark_pr.py +++ b/dev/merge_spark_pr.py @@ -29,7 +29,6 @@ import re import subprocess import sys -import tempfile import urllib2 try: @@ -39,15 +38,15 @@ JIRA_IMPORTED = False # Location of your Spark git development area -SPARK_HOME = os.environ.get("SPARK_HOME", "/home/patrick/Documents/spark") +SPARK_HOME = os.environ.get("SPARK_HOME", os.getcwd()) # Remote name which points to the Gihub site PR_REMOTE_NAME = os.environ.get("PR_REMOTE_NAME", "apache-github") # Remote name which points to Apache git PUSH_REMOTE_NAME = os.environ.get("PUSH_REMOTE_NAME", "apache") # ASF JIRA username -JIRA_USERNAME = os.environ.get("JIRA_USERNAME", "pwendell") +JIRA_USERNAME = os.environ.get("JIRA_USERNAME", "") # ASF JIRA password -JIRA_PASSWORD = os.environ.get("JIRA_PASSWORD", "1234") +JIRA_PASSWORD = os.environ.get("JIRA_PASSWORD", "") GITHUB_BASE = "https://github.com/apache/spark/pull" GITHUB_API_BASE = "https://api.github.com/repos/apache/spark" @@ -129,7 +128,7 @@ def merge_pr(pr_num, target_ref): merge_message_flags = [] merge_message_flags += ["-m", title] - if body != None: + if body is not None: # We remove @ symbols from the body to avoid triggering e-mails # to people every time someone creates a public fork of Spark. merge_message_flags += ["-m", body.replace("@", "")] @@ -179,7 +178,14 @@ def cherry_pick(pr_num, merge_hash, default_branch): run_cmd("git fetch %s %s:%s" % (PUSH_REMOTE_NAME, pick_ref, pick_branch_name)) run_cmd("git checkout %s" % pick_branch_name) - run_cmd("git cherry-pick -sx %s" % merge_hash) + + try: + run_cmd("git cherry-pick -sx %s" % merge_hash) + except Exception as e: + msg = "Error cherry-picking: %s\nWould you like to manually fix-up this merge?" % e + continue_maybe(msg) + msg = "Okay, please fix any conflicts and finish the cherry-pick. Finished?" + continue_maybe(msg) continue_maybe("Pick complete (local ref %s). Push to %s?" % ( pick_branch_name, PUSH_REMOTE_NAME)) @@ -280,6 +286,7 @@ def get_version_json(version_str): pr_num = raw_input("Which pull request would you like to merge? (e.g. 34): ") pr = get_json("%s/pulls/%s" % (GITHUB_API_BASE, pr_num)) +pr_events = get_json("%s/issues/%s/events" % (GITHUB_API_BASE, pr_num)) url = pr["url"] title = pr["title"] @@ -289,19 +296,23 @@ def get_version_json(version_str): base_ref = pr["head"]["ref"] pr_repo_desc = "%s/%s" % (user_login, base_ref) -if pr["merged"] is True: +# Merged pull requests don't appear as merged in the GitHub API; +# Instead, they're closed by asfgit. +merge_commits = \ + [e for e in pr_events if e["actor"]["login"] == "asfgit" and e["event"] == "closed"] + +if merge_commits: + merge_hash = merge_commits[0]["commit_id"] + message = get_json("%s/commits/%s" % (GITHUB_API_BASE, merge_hash))["commit"]["message"] + print "Pull request %s has already been merged, assuming you want to backport" % pr_num - merge_commit_desc = run_cmd([ - 'git', 'log', '--merges', '--first-parent', - '--grep=pull request #%s' % pr_num, '--oneline']).split("\n")[0] - if merge_commit_desc == "": + commit_is_downloaded = run_cmd(['git', 'rev-parse', '--quiet', '--verify', + "%s^{commit}" % merge_hash]).strip() != "" + if not commit_is_downloaded: fail("Couldn't find any merge commit for #%s, you may need to update HEAD." % pr_num) - merge_hash = merge_commit_desc[:7] - message = merge_commit_desc[8:] - - print "Found: %s" % message - maybe_cherry_pick(pr_num, merge_hash, latest_branch) + print "Found commit %s:\n%s" % (merge_hash, message) + cherry_pick(pr_num, merge_hash, latest_branch) sys.exit(0) if not bool(pr["mergeable"]): @@ -323,9 +334,13 @@ def get_version_json(version_str): merged_refs = merged_refs + [cherry_pick(pr_num, merge_hash, latest_branch)] if JIRA_IMPORTED: - continue_maybe("Would you like to update an associated JIRA?") - jira_comment = "Issue resolved by pull request %s\n[%s/%s]" % (pr_num, GITHUB_BASE, pr_num) - resolve_jira(title, merged_refs, jira_comment) + if JIRA_USERNAME and JIRA_PASSWORD: + continue_maybe("Would you like to update an associated JIRA?") + jira_comment = "Issue resolved by pull request %s\n[%s/%s]" % (pr_num, GITHUB_BASE, pr_num) + resolve_jira(title, merged_refs, jira_comment) + else: + print "JIRA_USERNAME and JIRA_PASSWORD not set" + print "Exiting without trying to close the associated JIRA." else: print "Could not find jira-python library. Run 'sudo pip install jira-python' to install." print "Exiting without trying to close the associated JIRA." From cc820502fb08f71b03237103153c34487b2600b4 Mon Sep 17 00:00:00 2001 From: kballou Date: Thu, 31 Jul 2014 14:58:52 -0700 Subject: [PATCH 064/170] Docs: monitoring, streaming programming guide Fix several awkward wordings and grammatical issues in the following documents: * docs/monitoring.md * docs/streaming-programming-guide.md Author: kballou Closes #1662 from kennyballou/grammar_fixes and squashes the following commits: e1b8ad6 [kballou] Docs: monitoring, streaming programming guide --- docs/monitoring.md | 4 ++-- docs/streaming-programming-guide.md | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/monitoring.md b/docs/monitoring.md index 84073fe4d949a..d07ec4a57a2cc 100644 --- a/docs/monitoring.md +++ b/docs/monitoring.md @@ -33,7 +33,7 @@ application's UI after the application has finished. If Spark is run on Mesos or YARN, it is still possible to reconstruct the UI of a finished application through Spark's history server, provided that the application's event logs exist. -You can start a the history server by executing: +You can start the history server by executing: ./sbin/start-history-server.sh @@ -106,7 +106,7 @@ follows:
diff --git a/docs/streaming-programming-guide.md b/docs/streaming-programming-guide.md index 90a0eef60c200..7b8b7933434c4 100644 --- a/docs/streaming-programming-guide.md +++ b/docs/streaming-programming-guide.md @@ -939,7 +939,7 @@ Receiving multiple data streams can therefore be achieved by creating multiple i and configuring them to receive different partitions of the data stream from the source(s). For example, a single Kafka input stream receiving two topics of data can be split into two Kafka input streams, each receiving only one topic. This would run two receivers on two workers, -thus allowing data to received in parallel, and increasing overall throughput. +thus allowing data to be received in parallel, and increasing overall throughput. Another parameter that should be considered is the receiver's blocking interval. For most receivers, the received data is coalesced together into large blocks of data before storing inside Spark's memory. @@ -980,7 +980,7 @@ If the number of tasks launched per second is high (say, 50 or more per second), of sending out tasks to the slaves maybe significant and will make it hard to achieve sub-second latencies. The overhead can be reduced by the following changes: -* **Task Serialization**: Using Kryo serialization for serializing tasks can reduced the task +* **Task Serialization**: Using Kryo serialization for serializing tasks can reduce the task sizes, and therefore reduce the time taken to send them to the slaves. * **Execution mode**: Running Spark in Standalone mode or coarse-grained Mesos mode leads to From 492a195c5c4d68c85b8b1b48e3aa85165bbb5dc3 Mon Sep 17 00:00:00 2001 From: Rui Li Date: Thu, 31 Jul 2014 15:07:26 -0700 Subject: [PATCH 065/170] SPARK-2740: allow user to specify ascending and numPartitions for sortBy... It should be more convenient if user can specify ascending and numPartitions when calling sortByKey. Author: Rui Li Closes #1645 from lirui-intel/spark-2740 and squashes the following commits: fb5d52e [Rui Li] SPARK-2740: allow user to specify ascending and numPartitions for sortByKey --- .../scala/org/apache/spark/api/java/JavaPairRDD.scala | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala index 47708cb2e78bd..76d4193e96aea 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala @@ -783,6 +783,17 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) sortByKey(comp, ascending) } + /** + * Sort the RDD by key, so that each partition contains a sorted range of the elements. Calling + * `collect` or `save` on the resulting RDD will return or output an ordered list of records + * (in the `save` case, they will be written to multiple `part-X` files in the filesystem, in + * order of the keys). + */ + def sortByKey(ascending: Boolean, numPartitions: Int): JavaPairRDD[K, V] = { + val comp = com.google.common.collect.Ordering.natural().asInstanceOf[Comparator[K]] + sortByKey(comp, ascending, numPartitions) + } + /** * Sort the RDD by key, so that each partition contains a sorted range of the elements. Calling * `collect` or `save` on the resulting RDD will return or output an ordered list of records From ef4ff00f87a4e8d38866f163f01741c2673e41da Mon Sep 17 00:00:00 2001 From: Aaron Davidson Date: Thu, 31 Jul 2014 15:31:53 -0700 Subject: [PATCH 066/170] SPARK-2282: Reuse Socket for sending accumulator updates to Pyspark Prior to this change, every PySpark task completion opened a new socket to the accumulator server, passed its updates through, and then quit. I'm not entirely sure why PySpark always sends accumulator updates, but regardless this causes a very rapid buildup of ephemeral TCP connections that remain in the TCP_WAIT state for around a minute before being cleaned up. Rather than trying to allow these sockets to be cleaned up faster, this patch simply reuses the connection between tasks completions (since they're fed updates in a single-threaded manner by the DAGScheduler anyway). The only tricky part here was making sure that the AccumulatorServer was able to shutdown in a timely manner (i.e., stop polling for new data), and this was accomplished via minor feats of magic. I have confirmed that this patch eliminates the buildup of ephemeral sockets due to the accumulator updates. However, I did note that there were still significant sockets being created against the PySpark daemon port, but my machine was not able to create enough sockets fast enough to fail. This may not be the last time we've seen this issue, though. Author: Aaron Davidson Closes #1503 from aarondav/accum and squashes the following commits: b3e12f7 [Aaron Davidson] SPARK-2282: Reuse Socket for sending accumulator updates to Pyspark --- .../apache/spark/api/python/PythonRDD.scala | 20 ++++++++--- python/pyspark/accumulators.py | 34 +++++++++++++++---- 2 files changed, 42 insertions(+), 12 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index a9d758bf998c3..94d666aa92025 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -731,19 +731,30 @@ private class PythonAccumulatorParam(@transient serverHost: String, serverPort: val bufferSize = SparkEnv.get.conf.getInt("spark.buffer.size", 65536) + /** + * We try to reuse a single Socket to transfer accumulator updates, as they are all added + * by the DAGScheduler's single-threaded actor anyway. + */ + @transient var socket: Socket = _ + + def openSocket(): Socket = synchronized { + if (socket == null || socket.isClosed) { + socket = new Socket(serverHost, serverPort) + } + socket + } + override def zero(value: JList[Array[Byte]]): JList[Array[Byte]] = new JArrayList override def addInPlace(val1: JList[Array[Byte]], val2: JList[Array[Byte]]) - : JList[Array[Byte]] = { + : JList[Array[Byte]] = synchronized { if (serverHost == null) { // This happens on the worker node, where we just want to remember all the updates val1.addAll(val2) val1 } else { // This happens on the master, where we pass the updates to Python through a socket - val socket = new Socket(serverHost, serverPort) - // SPARK-2282: Immediately reuse closed sockets because we create one per task. - socket.setReuseAddress(true) + val socket = openSocket() val in = socket.getInputStream val out = new DataOutputStream(new BufferedOutputStream(socket.getOutputStream, bufferSize)) out.writeInt(val2.size) @@ -757,7 +768,6 @@ private class PythonAccumulatorParam(@transient serverHost: String, serverPort: if (byteRead == -1) { throw new SparkException("EOF reached before Python server acknowledged") } - socket.close() null } } diff --git a/python/pyspark/accumulators.py b/python/pyspark/accumulators.py index 2204e9c9ca701..45d36e5d0e764 100644 --- a/python/pyspark/accumulators.py +++ b/python/pyspark/accumulators.py @@ -86,6 +86,7 @@ Exception:... """ +import select import struct import SocketServer import threading @@ -209,19 +210,38 @@ def addInPlace(self, value1, value2): class _UpdateRequestHandler(SocketServer.StreamRequestHandler): + """ + This handler will keep polling updates from the same socket until the + server is shutdown. + """ + def handle(self): from pyspark.accumulators import _accumulatorRegistry - num_updates = read_int(self.rfile) - for _ in range(num_updates): - (aid, update) = pickleSer._read_with_length(self.rfile) - _accumulatorRegistry[aid] += update - # Write a byte in acknowledgement - self.wfile.write(struct.pack("!b", 1)) + while not self.server.server_shutdown: + # Poll every 1 second for new data -- don't block in case of shutdown. + r, _, _ = select.select([self.rfile], [], [], 1) + if self.rfile in r: + num_updates = read_int(self.rfile) + for _ in range(num_updates): + (aid, update) = pickleSer._read_with_length(self.rfile) + _accumulatorRegistry[aid] += update + # Write a byte in acknowledgement + self.wfile.write(struct.pack("!b", 1)) + +class AccumulatorServer(SocketServer.TCPServer): + """ + A simple TCP server that intercepts shutdown() in order to interrupt + our continuous polling on the handler. + """ + server_shutdown = False + def shutdown(self): + self.server_shutdown = True + SocketServer.TCPServer.shutdown(self) def _start_update_server(): """Start a TCP server to receive accumulator updates in a daemon thread, and returns it""" - server = SocketServer.TCPServer(("localhost", 0), _UpdateRequestHandler) + server = AccumulatorServer(("localhost", 0), _UpdateRequestHandler) thread = threading.Thread(target=server.serve_forever) thread.daemon = True thread.start() From 8f51491ea78d8e88fc664c2eac3b4ac14226d98f Mon Sep 17 00:00:00 2001 From: Zongheng Yang Date: Thu, 31 Jul 2014 19:32:16 -0700 Subject: [PATCH 067/170] [SPARK-2531 & SPARK-2436] [SQL] Optimize the BuildSide when planning BroadcastNestedLoopJoin. This PR resolves the following two tickets: - [SPARK-2531](https://issues.apache.org/jira/browse/SPARK-2531): BNLJ currently assumes the build side is the right relation. This patch refactors some of its logic to take into account a BuildSide properly. - [SPARK-2436](https://issues.apache.org/jira/browse/SPARK-2436): building on top of the above, we simply use the physical size statistics (if available) of both relations, and make the smaller relation the build side in the planner. Author: Zongheng Yang Closes #1448 from concretevitamin/bnlj-buildSide and squashes the following commits: 1780351 [Zongheng Yang] Use size estimation to decide optimal build side of BNLJ. 68e6c5b [Zongheng Yang] Consolidate two adjacent pattern matchings. 96d312a [Zongheng Yang] Use a while loop instead of collection methods chaining. 4bc525e [Zongheng Yang] Make BroadcastNestedLoopJoin take a BuildSide. --- .../spark/sql/execution/SparkStrategies.scala | 4 +- .../apache/spark/sql/execution/joins.scala | 79 ++++++++++++------- 2 files changed, 55 insertions(+), 28 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 5f1fe99f75c9d..d57b6eaf40b09 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -155,8 +155,10 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { object BroadcastNestedLoopJoin extends Strategy { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case logical.Join(left, right, joinType, condition) => + val buildSide = + if (right.statistics.sizeInBytes <= left.statistics.sizeInBytes) BuildRight else BuildLeft execution.BroadcastNestedLoopJoin( - planLater(left), planLater(right), joinType, condition) :: Nil + planLater(left), planLater(right), buildSide, joinType, condition) :: Nil case _ => Nil } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala index 2750ddbce896f..b068579db75cd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala @@ -314,10 +314,19 @@ case class CartesianProduct(left: SparkPlan, right: SparkPlan) extends BinaryNod */ @DeveloperApi case class BroadcastNestedLoopJoin( - streamed: SparkPlan, broadcast: SparkPlan, joinType: JoinType, condition: Option[Expression]) - extends BinaryNode { + left: SparkPlan, + right: SparkPlan, + buildSide: BuildSide, + joinType: JoinType, + condition: Option[Expression]) extends BinaryNode { // TODO: Override requiredChildDistribution. + /** BuildRight means the right relation <=> the broadcast relation. */ + val (streamed, broadcast) = buildSide match { + case BuildRight => (left, right) + case BuildLeft => (right, left) + } + override def outputPartitioning: Partitioning = streamed.outputPartitioning override def output = { @@ -333,11 +342,6 @@ case class BroadcastNestedLoopJoin( } } - /** The Streamed Relation */ - def left = streamed - /** The Broadcast relation */ - def right = broadcast - @transient lazy val boundCondition = InterpretedPredicate( condition @@ -348,57 +352,78 @@ case class BroadcastNestedLoopJoin( val broadcastedRelation = sparkContext.broadcast(broadcast.execute().map(_.copy()).collect().toIndexedSeq) - val streamedPlusMatches = streamed.execute().mapPartitions { streamedIter => + /** All rows that either match both-way, or rows from streamed joined with nulls. */ + val matchesOrStreamedRowsWithNulls = streamed.execute().mapPartitions { streamedIter => val matchedRows = new ArrayBuffer[Row] // TODO: Use Spark's BitSet. - val includedBroadcastTuples = new BitSet(broadcastedRelation.value.size) + val includedBroadcastTuples = + new scala.collection.mutable.BitSet(broadcastedRelation.value.size) val joinedRow = new JoinedRow + val leftNulls = new GenericMutableRow(left.output.size) val rightNulls = new GenericMutableRow(right.output.size) streamedIter.foreach { streamedRow => var i = 0 - var matched = false + var streamRowMatched = false while (i < broadcastedRelation.value.size) { // TODO: One bitset per partition instead of per row. val broadcastedRow = broadcastedRelation.value(i) - if (boundCondition(joinedRow(streamedRow, broadcastedRow))) { - matchedRows += joinedRow(streamedRow, broadcastedRow).copy() - matched = true - includedBroadcastTuples += i + buildSide match { + case BuildRight if boundCondition(joinedRow(streamedRow, broadcastedRow)) => + matchedRows += joinedRow(streamedRow, broadcastedRow).copy() + streamRowMatched = true + includedBroadcastTuples += i + case BuildLeft if boundCondition(joinedRow(broadcastedRow, streamedRow)) => + matchedRows += joinedRow(broadcastedRow, streamedRow).copy() + streamRowMatched = true + includedBroadcastTuples += i + case _ => } i += 1 } - if (!matched && (joinType == LeftOuter || joinType == FullOuter)) { - matchedRows += joinedRow(streamedRow, rightNulls).copy() + (streamRowMatched, joinType, buildSide) match { + case (false, LeftOuter | FullOuter, BuildRight) => + matchedRows += joinedRow(streamedRow, rightNulls).copy() + case (false, RightOuter | FullOuter, BuildLeft) => + matchedRows += joinedRow(leftNulls, streamedRow).copy() + case _ => } } Iterator((matchedRows, includedBroadcastTuples)) } - val includedBroadcastTuples = streamedPlusMatches.map(_._2) + val includedBroadcastTuples = matchesOrStreamedRowsWithNulls.map(_._2) val allIncludedBroadcastTuples = if (includedBroadcastTuples.count == 0) { new scala.collection.mutable.BitSet(broadcastedRelation.value.size) } else { - streamedPlusMatches.map(_._2).reduce(_ ++ _) + includedBroadcastTuples.reduce(_ ++ _) } val leftNulls = new GenericMutableRow(left.output.size) - val rightOuterMatches: Seq[Row] = - if (joinType == RightOuter || joinType == FullOuter) { - broadcastedRelation.value.zipWithIndex.filter { - case (row, i) => !allIncludedBroadcastTuples.contains(i) - }.map { - case (row, _) => new JoinedRow(leftNulls, row) + val rightNulls = new GenericMutableRow(right.output.size) + /** Rows from broadcasted joined with nulls. */ + val broadcastRowsWithNulls: Seq[Row] = { + val arrBuf: collection.mutable.ArrayBuffer[Row] = collection.mutable.ArrayBuffer() + var i = 0 + val rel = broadcastedRelation.value + while (i < rel.length) { + if (!allIncludedBroadcastTuples.contains(i)) { + (joinType, buildSide) match { + case (RightOuter | FullOuter, BuildRight) => arrBuf += new JoinedRow(leftNulls, rel(i)) + case (LeftOuter | FullOuter, BuildLeft) => arrBuf += new JoinedRow(rel(i), rightNulls) + case _ => + } } - } else { - Vector() + i += 1 } + arrBuf.toSeq + } // TODO: Breaks lineage. sparkContext.union( - streamedPlusMatches.flatMap(_._1), sparkContext.makeRDD(rightOuterMatches)) + matchesOrStreamedRowsWithNulls.flatMap(_._1), sparkContext.makeRDD(broadcastRowsWithNulls)) } } From d8430148ee1f6ba02569db0538eeae473a32c78e Mon Sep 17 00:00:00 2001 From: Doris Xin Date: Thu, 31 Jul 2014 20:32:57 -0700 Subject: [PATCH 068/170] [SPARK-2724] Python version of RandomRDDGenerators RandomRDDGenerators but without support for randomRDD and randomVectorRDD, which take in arbitrary DistributionGenerator. `randomRDD.py` is named to avoid collision with the built-in Python `random` package. Author: Doris Xin Closes #1628 from dorx/pythonRDD and squashes the following commits: 55c6de8 [Doris Xin] review comments. all python units passed. f831d9b [Doris Xin] moved default args logic into PythonMLLibAPI 2d73917 [Doris Xin] fix for linalg.py 8663e6a [Doris Xin] reverting back to a single python file for random f47c481 [Doris Xin] docs update 687aac0 [Doris Xin] add RandomRDDGenerators.py to run-tests 4338f40 [Doris Xin] renamed randomRDD to rand and import as random 29d205e [Doris Xin] created mllib.random package bd2df13 [Doris Xin] typos 07ddff2 [Doris Xin] units passed. 23b2ecd [Doris Xin] WIP --- .../mllib/api/python/PythonMLLibAPI.scala | 97 ++++++++++ .../mllib/random/RandomRDDGenerators.scala | 90 +++++---- python/pyspark/__init__.py | 10 + python/pyspark/mllib/linalg.py | 4 + python/pyspark/mllib/random.py | 182 ++++++++++++++++++ python/run-tests | 1 + 6 files changed, 348 insertions(+), 36 deletions(-) create mode 100644 python/pyspark/mllib/random.py diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala index 954621ee8b933..d2e8ccf208970 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala @@ -24,10 +24,12 @@ import org.apache.spark.api.java.{JavaSparkContext, JavaRDD} import org.apache.spark.mllib.classification._ import org.apache.spark.mllib.clustering._ import org.apache.spark.mllib.linalg.{SparseVector, Vector, Vectors} +import org.apache.spark.mllib.random.{RandomRDDGenerators => RG} import org.apache.spark.mllib.recommendation._ import org.apache.spark.mllib.regression._ import org.apache.spark.mllib.util.MLUtils import org.apache.spark.rdd.RDD +import org.apache.spark.util.Utils /** * :: DeveloperApi :: @@ -453,4 +455,99 @@ class PythonMLLibAPI extends Serializable { val ratings = ratingsBytesJRDD.rdd.map(unpackRating) ALS.trainImplicit(ratings, rank, iterations, lambda, blocks, alpha) } + + // Used by the *RDD methods to get default seed if not passed in from pyspark + private def getSeedOrDefault(seed: java.lang.Long): Long = { + if (seed == null) Utils.random.nextLong else seed + } + + // Used by *RDD methods to get default numPartitions if not passed in from pyspark + private def getNumPartitionsOrDefault(numPartitions: java.lang.Integer, + jsc: JavaSparkContext): Int = { + if (numPartitions == null) { + jsc.sc.defaultParallelism + } else { + numPartitions + } + } + + // Note: for the following methods, numPartitions and seed are boxed to allow nulls to be passed + // in for either argument from pyspark + + /** + * Java stub for Python mllib RandomRDDGenerators.uniformRDD() + */ + def uniformRDD(jsc: JavaSparkContext, + size: Long, + numPartitions: java.lang.Integer, + seed: java.lang.Long): JavaRDD[Array[Byte]] = { + val parts = getNumPartitionsOrDefault(numPartitions, jsc) + val s = getSeedOrDefault(seed) + RG.uniformRDD(jsc.sc, size, parts, s).map(serializeDouble) + } + + /** + * Java stub for Python mllib RandomRDDGenerators.normalRDD() + */ + def normalRDD(jsc: JavaSparkContext, + size: Long, + numPartitions: java.lang.Integer, + seed: java.lang.Long): JavaRDD[Array[Byte]] = { + val parts = getNumPartitionsOrDefault(numPartitions, jsc) + val s = getSeedOrDefault(seed) + RG.normalRDD(jsc.sc, size, parts, s).map(serializeDouble) + } + + /** + * Java stub for Python mllib RandomRDDGenerators.poissonRDD() + */ + def poissonRDD(jsc: JavaSparkContext, + mean: Double, + size: Long, + numPartitions: java.lang.Integer, + seed: java.lang.Long): JavaRDD[Array[Byte]] = { + val parts = getNumPartitionsOrDefault(numPartitions, jsc) + val s = getSeedOrDefault(seed) + RG.poissonRDD(jsc.sc, mean, size, parts, s).map(serializeDouble) + } + + /** + * Java stub for Python mllib RandomRDDGenerators.uniformVectorRDD() + */ + def uniformVectorRDD(jsc: JavaSparkContext, + numRows: Long, + numCols: Int, + numPartitions: java.lang.Integer, + seed: java.lang.Long): JavaRDD[Array[Byte]] = { + val parts = getNumPartitionsOrDefault(numPartitions, jsc) + val s = getSeedOrDefault(seed) + RG.uniformVectorRDD(jsc.sc, numRows, numCols, parts, s).map(serializeDoubleVector) + } + + /** + * Java stub for Python mllib RandomRDDGenerators.normalVectorRDD() + */ + def normalVectorRDD(jsc: JavaSparkContext, + numRows: Long, + numCols: Int, + numPartitions: java.lang.Integer, + seed: java.lang.Long): JavaRDD[Array[Byte]] = { + val parts = getNumPartitionsOrDefault(numPartitions, jsc) + val s = getSeedOrDefault(seed) + RG.normalVectorRDD(jsc.sc, numRows, numCols, parts, s).map(serializeDoubleVector) + } + + /** + * Java stub for Python mllib RandomRDDGenerators.poissonVectorRDD() + */ + def poissonVectorRDD(jsc: JavaSparkContext, + mean: Double, + numRows: Long, + numCols: Int, + numPartitions: java.lang.Integer, + seed: java.lang.Long): JavaRDD[Array[Byte]] = { + val parts = getNumPartitionsOrDefault(numPartitions, jsc) + val s = getSeedOrDefault(seed) + RG.poissonVectorRDD(jsc.sc, mean, numRows, numCols, parts, s).map(serializeDoubleVector) + } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDGenerators.scala b/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDGenerators.scala index d7ee2d3f46846..021d651d4dbaa 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDGenerators.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDGenerators.scala @@ -26,14 +26,17 @@ import org.apache.spark.util.Utils /** * :: Experimental :: - * Generator methods for creating RDDs comprised of i.i.d samples from some distribution. + * Generator methods for creating RDDs comprised of i.i.d. samples from some distribution. */ @Experimental object RandomRDDGenerators { /** * :: Experimental :: - * Generates an RDD comprised of i.i.d samples from the uniform distribution on [0.0, 1.0]. + * Generates an RDD comprised of i.i.d. samples from the uniform distribution on [0.0, 1.0]. + * + * To transform the distribution in the generated RDD from U[0.0, 1.0] to U[a, b], use + * `RandomRDDGenerators.uniformRDD(sc, n, p, seed).map(v => a + (b - a) * v)`. * * @param sc SparkContext used to create the RDD. * @param size Size of the RDD. @@ -49,7 +52,10 @@ object RandomRDDGenerators { /** * :: Experimental :: - * Generates an RDD comprised of i.i.d samples from the uniform distribution on [0.0, 1.0]. + * Generates an RDD comprised of i.i.d. samples from the uniform distribution on [0.0, 1.0]. + * + * To transform the distribution in the generated RDD from U[0.0, 1.0] to U[a, b], use + * `RandomRDDGenerators.uniformRDD(sc, n, p).map(v => a + (b - a) * v)`. * * @param sc SparkContext used to create the RDD. * @param size Size of the RDD. @@ -63,9 +69,12 @@ object RandomRDDGenerators { /** * :: Experimental :: - * Generates an RDD comprised of i.i.d samples from the uniform distribution on [0.0, 1.0]. + * Generates an RDD comprised of i.i.d. samples from the uniform distribution on [0.0, 1.0]. * sc.defaultParallelism used for the number of partitions in the RDD. * + * To transform the distribution in the generated RDD from U[0.0, 1.0] to U[a, b], use + * `RandomRDDGenerators.uniformRDD(sc, n).map(v => a + (b - a) * v)`. + * * @param sc SparkContext used to create the RDD. * @param size Size of the RDD. * @return RDD[Double] comprised of i.i.d. samples ~ U[0.0, 1.0]. @@ -77,7 +86,10 @@ object RandomRDDGenerators { /** * :: Experimental :: - * Generates an RDD comprised of i.i.d samples from the standard normal distribution. + * Generates an RDD comprised of i.i.d. samples from the standard normal distribution. + * + * To transform the distribution in the generated RDD from standard normal to some other normal + * N(mean, sigma), use `RandomRDDGenerators.normalRDD(sc, n, p, seed).map(v => mean + sigma * v)`. * * @param sc SparkContext used to create the RDD. * @param size Size of the RDD. @@ -93,7 +105,10 @@ object RandomRDDGenerators { /** * :: Experimental :: - * Generates an RDD comprised of i.i.d samples from the standard normal distribution. + * Generates an RDD comprised of i.i.d. samples from the standard normal distribution. + * + * To transform the distribution in the generated RDD from standard normal to some other normal + * N(mean, sigma), use `RandomRDDGenerators.normalRDD(sc, n, p).map(v => mean + sigma * v)`. * * @param sc SparkContext used to create the RDD. * @param size Size of the RDD. @@ -107,9 +122,12 @@ object RandomRDDGenerators { /** * :: Experimental :: - * Generates an RDD comprised of i.i.d samples from the standard normal distribution. + * Generates an RDD comprised of i.i.d. samples from the standard normal distribution. * sc.defaultParallelism used for the number of partitions in the RDD. * + * To transform the distribution in the generated RDD from standard normal to some other normal + * N(mean, sigma), use `RandomRDDGenerators.normalRDD(sc, n).map(v => mean + sigma * v)`. + * * @param sc SparkContext used to create the RDD. * @param size Size of the RDD. * @return RDD[Double] comprised of i.i.d. samples ~ N(0.0, 1.0). @@ -121,7 +139,7 @@ object RandomRDDGenerators { /** * :: Experimental :: - * Generates an RDD comprised of i.i.d samples from the Poisson distribution with the input mean. + * Generates an RDD comprised of i.i.d. samples from the Poisson distribution with the input mean. * * @param sc SparkContext used to create the RDD. * @param mean Mean, or lambda, for the Poisson distribution. @@ -142,7 +160,7 @@ object RandomRDDGenerators { /** * :: Experimental :: - * Generates an RDD comprised of i.i.d samples from the Poisson distribution with the input mean. + * Generates an RDD comprised of i.i.d. samples from the Poisson distribution with the input mean. * * @param sc SparkContext used to create the RDD. * @param mean Mean, or lambda, for the Poisson distribution. @@ -157,7 +175,7 @@ object RandomRDDGenerators { /** * :: Experimental :: - * Generates an RDD comprised of i.i.d samples from the Poisson distribution with the input mean. + * Generates an RDD comprised of i.i.d. samples from the Poisson distribution with the input mean. * sc.defaultParallelism used for the number of partitions in the RDD. * * @param sc SparkContext used to create the RDD. @@ -172,7 +190,7 @@ object RandomRDDGenerators { /** * :: Experimental :: - * Generates an RDD comprised of i.i.d samples produced by the input DistributionGenerator. + * Generates an RDD comprised of i.i.d. samples produced by the input DistributionGenerator. * * @param sc SparkContext used to create the RDD. * @param generator DistributionGenerator used to populate the RDD. @@ -192,7 +210,7 @@ object RandomRDDGenerators { /** * :: Experimental :: - * Generates an RDD comprised of i.i.d samples produced by the input DistributionGenerator. + * Generates an RDD comprised of i.i.d. samples produced by the input DistributionGenerator. * * @param sc SparkContext used to create the RDD. * @param generator DistributionGenerator used to populate the RDD. @@ -210,7 +228,7 @@ object RandomRDDGenerators { /** * :: Experimental :: - * Generates an RDD comprised of i.i.d samples produced by the input DistributionGenerator. + * Generates an RDD comprised of i.i.d. samples produced by the input DistributionGenerator. * sc.defaultParallelism used for the number of partitions in the RDD. * * @param sc SparkContext used to create the RDD. @@ -229,7 +247,7 @@ object RandomRDDGenerators { /** * :: Experimental :: - * Generates an RDD[Vector] with vectors containing i.i.d samples drawn from the + * Generates an RDD[Vector] with vectors containing i.i.d. samples drawn from the * uniform distribution on [0.0 1.0]. * * @param sc SparkContext used to create the RDD. @@ -251,14 +269,14 @@ object RandomRDDGenerators { /** * :: Experimental :: - * Generates an RDD[Vector] with vectors containing i.i.d samples drawn from the + * Generates an RDD[Vector] with vectors containing i.i.d. samples drawn from the * uniform distribution on [0.0 1.0]. * * @param sc SparkContext used to create the RDD. * @param numRows Number of Vectors in the RDD. * @param numCols Number of elements in each Vector. * @param numPartitions Number of partitions in the RDD. - * @return RDD[Vector] with vectors containing i.i.d samples ~ U[0.0, 1.0]. + * @return RDD[Vector] with vectors containing i.i.d. samples ~ U[0.0, 1.0]. */ @Experimental def uniformVectorRDD(sc: SparkContext, @@ -270,14 +288,14 @@ object RandomRDDGenerators { /** * :: Experimental :: - * Generates an RDD[Vector] with vectors containing i.i.d samples drawn from the + * Generates an RDD[Vector] with vectors containing i.i.d. samples drawn from the * uniform distribution on [0.0 1.0]. * sc.defaultParallelism used for the number of partitions in the RDD. * * @param sc SparkContext used to create the RDD. * @param numRows Number of Vectors in the RDD. * @param numCols Number of elements in each Vector. - * @return RDD[Vector] with vectors containing i.i.d samples ~ U[0.0, 1.0]. + * @return RDD[Vector] with vectors containing i.i.d. samples ~ U[0.0, 1.0]. */ @Experimental def uniformVectorRDD(sc: SparkContext, numRows: Long, numCols: Int): RDD[Vector] = { @@ -286,7 +304,7 @@ object RandomRDDGenerators { /** * :: Experimental :: - * Generates an RDD[Vector] with vectors containing i.i.d samples drawn from the + * Generates an RDD[Vector] with vectors containing i.i.d. samples drawn from the * standard normal distribution. * * @param sc SparkContext used to create the RDD. @@ -294,7 +312,7 @@ object RandomRDDGenerators { * @param numCols Number of elements in each Vector. * @param numPartitions Number of partitions in the RDD. * @param seed Seed for the RNG that generates the seed for the generator in each partition. - * @return RDD[Vector] with vectors containing i.i.d samples ~ N(0.0, 1.0). + * @return RDD[Vector] with vectors containing i.i.d. samples ~ N(0.0, 1.0). */ @Experimental def normalVectorRDD(sc: SparkContext, @@ -308,14 +326,14 @@ object RandomRDDGenerators { /** * :: Experimental :: - * Generates an RDD[Vector] with vectors containing i.i.d samples drawn from the + * Generates an RDD[Vector] with vectors containing i.i.d. samples drawn from the * standard normal distribution. * * @param sc SparkContext used to create the RDD. * @param numRows Number of Vectors in the RDD. * @param numCols Number of elements in each Vector. * @param numPartitions Number of partitions in the RDD. - * @return RDD[Vector] with vectors containing i.i.d samples ~ N(0.0, 1.0). + * @return RDD[Vector] with vectors containing i.i.d. samples ~ N(0.0, 1.0). */ @Experimental def normalVectorRDD(sc: SparkContext, @@ -327,14 +345,14 @@ object RandomRDDGenerators { /** * :: Experimental :: - * Generates an RDD[Vector] with vectors containing i.i.d samples drawn from the + * Generates an RDD[Vector] with vectors containing i.i.d. samples drawn from the * standard normal distribution. * sc.defaultParallelism used for the number of partitions in the RDD. * * @param sc SparkContext used to create the RDD. * @param numRows Number of Vectors in the RDD. * @param numCols Number of elements in each Vector. - * @return RDD[Vector] with vectors containing i.i.d samples ~ N(0.0, 1.0). + * @return RDD[Vector] with vectors containing i.i.d. samples ~ N(0.0, 1.0). */ @Experimental def normalVectorRDD(sc: SparkContext, numRows: Long, numCols: Int): RDD[Vector] = { @@ -343,7 +361,7 @@ object RandomRDDGenerators { /** * :: Experimental :: - * Generates an RDD[Vector] with vectors containing i.i.d samples drawn from the + * Generates an RDD[Vector] with vectors containing i.i.d. samples drawn from the * Poisson distribution with the input mean. * * @param sc SparkContext used to create the RDD. @@ -352,7 +370,7 @@ object RandomRDDGenerators { * @param numCols Number of elements in each Vector. * @param numPartitions Number of partitions in the RDD. * @param seed Seed for the RNG that generates the seed for the generator in each partition. - * @return RDD[Vector] with vectors containing i.i.d samples ~ Pois(mean). + * @return RDD[Vector] with vectors containing i.i.d. samples ~ Pois(mean). */ @Experimental def poissonVectorRDD(sc: SparkContext, @@ -367,7 +385,7 @@ object RandomRDDGenerators { /** * :: Experimental :: - * Generates an RDD[Vector] with vectors containing i.i.d samples drawn from the + * Generates an RDD[Vector] with vectors containing i.i.d. samples drawn from the * Poisson distribution with the input mean. * * @param sc SparkContext used to create the RDD. @@ -375,7 +393,7 @@ object RandomRDDGenerators { * @param numRows Number of Vectors in the RDD. * @param numCols Number of elements in each Vector. * @param numPartitions Number of partitions in the RDD. - * @return RDD[Vector] with vectors containing i.i.d samples ~ Pois(mean). + * @return RDD[Vector] with vectors containing i.i.d. samples ~ Pois(mean). */ @Experimental def poissonVectorRDD(sc: SparkContext, @@ -388,7 +406,7 @@ object RandomRDDGenerators { /** * :: Experimental :: - * Generates an RDD[Vector] with vectors containing i.i.d samples drawn from the + * Generates an RDD[Vector] with vectors containing i.i.d. samples drawn from the * Poisson distribution with the input mean. * sc.defaultParallelism used for the number of partitions in the RDD. * @@ -396,7 +414,7 @@ object RandomRDDGenerators { * @param mean Mean, or lambda, for the Poisson distribution. * @param numRows Number of Vectors in the RDD. * @param numCols Number of elements in each Vector. - * @return RDD[Vector] with vectors containing i.i.d samples ~ Pois(mean). + * @return RDD[Vector] with vectors containing i.i.d. samples ~ Pois(mean). */ @Experimental def poissonVectorRDD(sc: SparkContext, @@ -408,7 +426,7 @@ object RandomRDDGenerators { /** * :: Experimental :: - * Generates an RDD[Vector] with vectors containing i.i.d samples produced by the + * Generates an RDD[Vector] with vectors containing i.i.d. samples produced by the * input DistributionGenerator. * * @param sc SparkContext used to create the RDD. @@ -417,7 +435,7 @@ object RandomRDDGenerators { * @param numCols Number of elements in each Vector. * @param numPartitions Number of partitions in the RDD. * @param seed Seed for the RNG that generates the seed for the generator in each partition. - * @return RDD[Vector] with vectors containing i.i.d samples produced by generator. + * @return RDD[Vector] with vectors containing i.i.d. samples produced by generator. */ @Experimental def randomVectorRDD(sc: SparkContext, @@ -431,7 +449,7 @@ object RandomRDDGenerators { /** * :: Experimental :: - * Generates an RDD[Vector] with vectors containing i.i.d samples produced by the + * Generates an RDD[Vector] with vectors containing i.i.d. samples produced by the * input DistributionGenerator. * * @param sc SparkContext used to create the RDD. @@ -439,7 +457,7 @@ object RandomRDDGenerators { * @param numRows Number of Vectors in the RDD. * @param numCols Number of elements in each Vector. * @param numPartitions Number of partitions in the RDD. - * @return RDD[Vector] with vectors containing i.i.d samples produced by generator. + * @return RDD[Vector] with vectors containing i.i.d. samples produced by generator. */ @Experimental def randomVectorRDD(sc: SparkContext, @@ -452,7 +470,7 @@ object RandomRDDGenerators { /** * :: Experimental :: - * Generates an RDD[Vector] with vectors containing i.i.d samples produced by the + * Generates an RDD[Vector] with vectors containing i.i.d. samples produced by the * input DistributionGenerator. * sc.defaultParallelism used for the number of partitions in the RDD. * @@ -460,7 +478,7 @@ object RandomRDDGenerators { * @param generator DistributionGenerator used to populate the RDD. * @param numRows Number of Vectors in the RDD. * @param numCols Number of elements in each Vector. - * @return RDD[Vector] with vectors containing i.i.d samples produced by generator. + * @return RDD[Vector] with vectors containing i.i.d. samples produced by generator. */ @Experimental def randomVectorRDD(sc: SparkContext, diff --git a/python/pyspark/__init__.py b/python/pyspark/__init__.py index 312c75d112cbf..c58555fc9d2c5 100644 --- a/python/pyspark/__init__.py +++ b/python/pyspark/__init__.py @@ -49,6 +49,16 @@ Main entry point for accessing data stored in Apache Hive.. """ +# The following block allows us to import python's random instead of mllib.random for scripts in +# mllib that depend on top level pyspark packages, which transitively depend on python's random. +# Since Python's import logic looks for modules in the current package first, we eliminate +# mllib.random as a candidate for C{import random} by removing the first search path, the script's +# location, in order to force the loader to look in Python's top-level modules for C{random}. +import sys +s = sys.path.pop(0) +import random +sys.path.insert(0, s) + from pyspark.conf import SparkConf from pyspark.context import SparkContext from pyspark.sql import SQLContext diff --git a/python/pyspark/mllib/linalg.py b/python/pyspark/mllib/linalg.py index 71f4ad1a8d44e..54720c2324ca6 100644 --- a/python/pyspark/mllib/linalg.py +++ b/python/pyspark/mllib/linalg.py @@ -255,4 +255,8 @@ def _test(): exit(-1) if __name__ == "__main__": + # remove current path from list of search paths to avoid importing mllib.random + # for C{import random}, which is done in an external dependency of pyspark during doctests. + import sys + sys.path.pop(0) _test() diff --git a/python/pyspark/mllib/random.py b/python/pyspark/mllib/random.py new file mode 100644 index 0000000000000..36e710dbae7a8 --- /dev/null +++ b/python/pyspark/mllib/random.py @@ -0,0 +1,182 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +Python package for random data generation. +""" + + +from pyspark.rdd import RDD +from pyspark.mllib._common import _deserialize_double, _deserialize_double_vector +from pyspark.serializers import NoOpSerializer + +class RandomRDDGenerators: + """ + Generator methods for creating RDDs comprised of i.i.d samples from + some distribution. + """ + + @staticmethod + def uniformRDD(sc, size, numPartitions=None, seed=None): + """ + Generates an RDD comprised of i.i.d. samples from the + uniform distribution on [0.0, 1.0]. + + To transform the distribution in the generated RDD from U[0.0, 1.0] + to U[a, b], use + C{RandomRDDGenerators.uniformRDD(sc, n, p, seed)\ + .map(lambda v: a + (b - a) * v)} + + >>> x = RandomRDDGenerators.uniformRDD(sc, 100).collect() + >>> len(x) + 100 + >>> max(x) <= 1.0 and min(x) >= 0.0 + True + >>> RandomRDDGenerators.uniformRDD(sc, 100, 4).getNumPartitions() + 4 + >>> parts = RandomRDDGenerators.uniformRDD(sc, 100, seed=4).getNumPartitions() + >>> parts == sc.defaultParallelism + True + """ + jrdd = sc._jvm.PythonMLLibAPI().uniformRDD(sc._jsc, size, numPartitions, seed) + uniform = RDD(jrdd, sc, NoOpSerializer()) + return uniform.map(lambda bytes: _deserialize_double(bytearray(bytes))) + + @staticmethod + def normalRDD(sc, size, numPartitions=None, seed=None): + """ + Generates an RDD comprised of i.i.d samples from the standard normal + distribution. + + To transform the distribution in the generated RDD from standard normal + to some other normal N(mean, sigma), use + C{RandomRDDGenerators.normal(sc, n, p, seed)\ + .map(lambda v: mean + sigma * v)} + + >>> x = RandomRDDGenerators.normalRDD(sc, 1000, seed=1L) + >>> stats = x.stats() + >>> stats.count() + 1000L + >>> abs(stats.mean() - 0.0) < 0.1 + True + >>> abs(stats.stdev() - 1.0) < 0.1 + True + """ + jrdd = sc._jvm.PythonMLLibAPI().normalRDD(sc._jsc, size, numPartitions, seed) + normal = RDD(jrdd, sc, NoOpSerializer()) + return normal.map(lambda bytes: _deserialize_double(bytearray(bytes))) + + @staticmethod + def poissonRDD(sc, mean, size, numPartitions=None, seed=None): + """ + Generates an RDD comprised of i.i.d samples from the Poisson + distribution with the input mean. + + >>> mean = 100.0 + >>> x = RandomRDDGenerators.poissonRDD(sc, mean, 1000, seed=1L) + >>> stats = x.stats() + >>> stats.count() + 1000L + >>> abs(stats.mean() - mean) < 0.5 + True + >>> from math import sqrt + >>> abs(stats.stdev() - sqrt(mean)) < 0.5 + True + """ + jrdd = sc._jvm.PythonMLLibAPI().poissonRDD(sc._jsc, mean, size, numPartitions, seed) + poisson = RDD(jrdd, sc, NoOpSerializer()) + return poisson.map(lambda bytes: _deserialize_double(bytearray(bytes))) + + @staticmethod + def uniformVectorRDD(sc, numRows, numCols, numPartitions=None, seed=None): + """ + Generates an RDD comprised of vectors containing i.i.d samples drawn + from the uniform distribution on [0.0 1.0]. + + >>> import numpy as np + >>> mat = np.matrix(RandomRDDGenerators.uniformVectorRDD(sc, 10, 10).collect()) + >>> mat.shape + (10, 10) + >>> mat.max() <= 1.0 and mat.min() >= 0.0 + True + >>> RandomRDDGenerators.uniformVectorRDD(sc, 10, 10, 4).getNumPartitions() + 4 + """ + jrdd = sc._jvm.PythonMLLibAPI() \ + .uniformVectorRDD(sc._jsc, numRows, numCols, numPartitions, seed) + uniform = RDD(jrdd, sc, NoOpSerializer()) + return uniform.map(lambda bytes: _deserialize_double_vector(bytearray(bytes))) + + @staticmethod + def normalVectorRDD(sc, numRows, numCols, numPartitions=None, seed=None): + """ + Generates an RDD comprised of vectors containing i.i.d samples drawn + from the standard normal distribution. + + >>> import numpy as np + >>> mat = np.matrix(RandomRDDGenerators.normalVectorRDD(sc, 100, 100, seed=1L).collect()) + >>> mat.shape + (100, 100) + >>> abs(mat.mean() - 0.0) < 0.1 + True + >>> abs(mat.std() - 1.0) < 0.1 + True + """ + jrdd = sc._jvm.PythonMLLibAPI() \ + .normalVectorRDD(sc._jsc, numRows, numCols, numPartitions, seed) + normal = RDD(jrdd, sc, NoOpSerializer()) + return normal.map(lambda bytes: _deserialize_double_vector(bytearray(bytes))) + + @staticmethod + def poissonVectorRDD(sc, mean, numRows, numCols, numPartitions=None, seed=None): + """ + Generates an RDD comprised of vectors containing i.i.d samples drawn + from the Poisson distribution with the input mean. + + >>> import numpy as np + >>> mean = 100.0 + >>> rdd = RandomRDDGenerators.poissonVectorRDD(sc, mean, 100, 100, seed=1L) + >>> mat = np.mat(rdd.collect()) + >>> mat.shape + (100, 100) + >>> abs(mat.mean() - mean) < 0.5 + True + >>> from math import sqrt + >>> abs(mat.std() - sqrt(mean)) < 0.5 + True + """ + jrdd = sc._jvm.PythonMLLibAPI() \ + .poissonVectorRDD(sc._jsc, mean, numRows, numCols, numPartitions, seed) + poisson = RDD(jrdd, sc, NoOpSerializer()) + return poisson.map(lambda bytes: _deserialize_double_vector(bytearray(bytes))) + + +def _test(): + import doctest + from pyspark.context import SparkContext + globs = globals().copy() + # The small batch size here ensures that we see multiple batches, + # even in these small test examples: + globs['sc'] = SparkContext('local[2]', 'PythonTest', batchSize=2) + (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) + globs['sc'].stop() + if failure_count: + exit(-1) + + +if __name__ == "__main__": + _test() diff --git a/python/run-tests b/python/run-tests index 29f755fc0dcd3..5049e15ce5f8a 100755 --- a/python/run-tests +++ b/python/run-tests @@ -67,6 +67,7 @@ run_test "pyspark/mllib/_common.py" run_test "pyspark/mllib/classification.py" run_test "pyspark/mllib/clustering.py" run_test "pyspark/mllib/linalg.py" +run_test "pyspark/mllib/random.py" run_test "pyspark/mllib/recommendation.py" run_test "pyspark/mllib/regression.py" run_test "pyspark/mllib/tests.py" From b124de584a45b7ebde9fbe10128db429c56aeaee Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Thu, 31 Jul 2014 20:51:48 -0700 Subject: [PATCH 069/170] [SPARK-2756] [mllib] Decision tree bug fixes MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit (1) Inconsistent aggregate (agg) indexing for unordered features. (2) Fixed gain calculations for edge cases. (3) One-off error in choosing thresholds for continuous features for small datasets. (4) (not a bug) Changed meaning of tree depth by 1 to fit scikit-learn and rpart. (Depth 1 used to mean 1 leaf node; depth 0 now means 1 leaf node.) Other updates, to help with tests: * Updated DecisionTreeRunner to print more info. * Added utility functions to DecisionTreeModel: toString, depth, numNodes * Improved internal DecisionTree documentation Bug fix details: (1) Indexing was inconsistent for aggregate calculations for unordered features (in multiclass classification with categorical features, where the features had few enough values such that they could be considered unordered, i.e., isSpaceSufficientForAllCategoricalSplits=true). * updateBinForUnorderedFeature indexed agg as (node, feature, featureValue, binIndex), where ** featureValue was from arr (so it was a feature value) ** binIndex was in [0,…, 2^(maxFeatureValue-1)-1) * The rest of the code indexed agg as (node, feature, binIndex, label). * Corrected this bug by changing updateBinForUnorderedFeature to use the second indexing pattern. Unit tests in DecisionTreeSuite * Updated a few tests to train a model and test its training accuracy, which catches the indexing bug from updateBinForUnorderedFeature() discussed above. * Added new test (“stump with categorical variables for multiclass classification, with just enough bins”) to test bin extremes. (2) Bug fix: calculateGainForSplit (for classification): * It used to return dummy prediction values when either the right or left children had 0 weight. These were incorrect for multiclass classification. It has been corrected. Updated impurities to allow for count = 0. This was related to the above bug fix for calculateGainForSplit (for classification). Small updates to documentation and coding style. (3) Bug fix: Off-by-1 when finding thresholds for splits for continuous features. * Exhibited bug in new test in DecisionTreeSuite: “stump with 1 continuous variable for binary classification, to check off-by-1 error” * Description: When finding thresholds for possible splits for continuous features in DecisionTree.findSplitsBins, the thresholds were set according to individual training examples’ feature values. * Fix: The threshold is set to be the average of 2 consecutive (sorted) examples’ feature values. E.g.: If the old code set the threshold using example i, the new code sets the threshold using exam * Note: In 4 DecisionTreeSuite tests with all labels identical, removed check of threshold since it is somewhat arbitrary. CC: mengxr manishamde Please let me know if I missed something! Author: Joseph K. Bradley Closes #1673 from jkbradley/decisiontree-bugfix and squashes the following commits: 2b20c61 [Joseph K. Bradley] Small doc and style updates dab0b67 [Joseph K. Bradley] Added documentation for DecisionTree internals 8bb8aa0 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into decisiontree-bugfix 978cfcf [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into decisiontree-bugfix 6eed482 [Joseph K. Bradley] In DecisionTree: Changed from using procedural syntax for functions returning Unit to explicitly writing Unit return type. 376dca2 [Joseph K. Bradley] Updated meaning of maxDepth by 1 to fit scikit-learn and rpart. * In code, replaced usages of maxDepth <-- maxDepth + 1 * In params, replace settings of maxDepth <-- maxDepth - 1 59750f8 [Joseph K. Bradley] * Updated Strategy to check numClassesForClassification only if algo=Classification. * Updates based on comments: ** DecisionTreeRunner *** Made dataFormat arg default to libsvm ** Small cleanups ** tree.Node: Made recursive helper methods private, and renamed them. 52e17c5 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into decisiontree-bugfix da50db7 [Joseph K. Bradley] Added one more test to DecisionTreeSuite: stump with 2 continuous variables for binary classification. Caused problems in past, but fixed now. 8ea8750 [Joseph K. Bradley] Bug fix: Off-by-1 when finding thresholds for splits for continuous features. 2283df8 [Joseph K. Bradley] 2 bug fixes. 73fbea2 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into decisiontree-bugfix 5f920a1 [Joseph K. Bradley] Demonstration of bug before submitting fix: Updated DecisionTreeSuite so that 3 tests fail. Will describe bug in next commit. --- .../examples/mllib/DecisionTreeRunner.scala | 92 +++- .../spark/mllib/tree/DecisionTree.scala | 408 +++++++++++------- .../mllib/tree/configuration/Strategy.scala | 7 +- .../spark/mllib/tree/impurity/Entropy.scala | 6 +- .../spark/mllib/tree/impurity/Gini.scala | 6 +- .../spark/mllib/tree/impurity/Impurity.scala | 4 +- .../spark/mllib/tree/impurity/Variance.scala | 6 +- .../mllib/tree/model/DecisionTreeModel.scala | 31 +- .../apache/spark/mllib/tree/model/Node.scala | 56 +++ .../spark/mllib/tree/DecisionTreeSuite.scala | 115 ++++- 10 files changed, 538 insertions(+), 193 deletions(-) diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala index 6db9bf3cf5be6..cf3d2cca81ff6 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala @@ -21,7 +21,6 @@ import scopt.OptionParser import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.SparkContext._ -import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.{DecisionTree, impurity} import org.apache.spark.mllib.tree.configuration.{Algo, Strategy} @@ -36,6 +35,9 @@ import org.apache.spark.rdd.RDD * ./bin/run-example org.apache.spark.examples.mllib.DecisionTreeRunner [options] * }}} * If you use it as a template to create your own app, please use `spark-submit` to submit your app. + * + * Note: This script treats all features as real-valued (not categorical). + * To include categorical features, modify categoricalFeaturesInfo. */ object DecisionTreeRunner { @@ -48,11 +50,12 @@ object DecisionTreeRunner { case class Params( input: String = null, + dataFormat: String = "libsvm", algo: Algo = Classification, - numClassesForClassification: Int = 2, - maxDepth: Int = 5, + maxDepth: Int = 4, impurity: ImpurityType = Gini, - maxBins: Int = 100) + maxBins: Int = 100, + fracTest: Double = 0.2) def main(args: Array[String]) { val defaultParams = Params() @@ -69,25 +72,31 @@ object DecisionTreeRunner { opt[Int]("maxDepth") .text(s"max depth of the tree, default: ${defaultParams.maxDepth}") .action((x, c) => c.copy(maxDepth = x)) - opt[Int]("numClassesForClassification") - .text(s"number of classes for classification, " - + s"default: ${defaultParams.numClassesForClassification}") - .action((x, c) => c.copy(numClassesForClassification = x)) opt[Int]("maxBins") .text(s"max number of bins, default: ${defaultParams.maxBins}") .action((x, c) => c.copy(maxBins = x)) + opt[Double]("fracTest") + .text(s"fraction of data to hold out for testing, default: ${defaultParams.fracTest}") + .action((x, c) => c.copy(fracTest = x)) + opt[String]("") + .text("data format: libsvm (default), dense (deprecated in Spark v1.1)") + .action((x, c) => c.copy(dataFormat = x)) arg[String]("") .text("input paths to labeled examples in dense format (label,f0 f1 f2 ...)") .required() .action((x, c) => c.copy(input = x)) checkConfig { params => - if (params.algo == Classification && - (params.impurity == Gini || params.impurity == Entropy)) { - success - } else if (params.algo == Regression && params.impurity == Variance) { - success + if (params.fracTest < 0 || params.fracTest > 1) { + failure(s"fracTest ${params.fracTest} value incorrect; should be in [0,1].") } else { - failure(s"Algo ${params.algo} is not compatible with impurity ${params.impurity}.") + if (params.algo == Classification && + (params.impurity == Gini || params.impurity == Entropy)) { + success + } else if (params.algo == Regression && params.impurity == Variance) { + success + } else { + failure(s"Algo ${params.algo} is not compatible with impurity ${params.impurity}.") + } } } } @@ -100,16 +109,57 @@ object DecisionTreeRunner { } def run(params: Params) { + val conf = new SparkConf().setAppName("DecisionTreeRunner") val sc = new SparkContext(conf) // Load training data and cache it. - val examples = MLUtils.loadLabeledPoints(sc, params.input).cache() + val origExamples = params.dataFormat match { + case "dense" => MLUtils.loadLabeledPoints(sc, params.input).cache() + case "libsvm" => MLUtils.loadLibSVMFile(sc, params.input).cache() + } + // For classification, re-index classes if needed. + val (examples, numClasses) = params.algo match { + case Classification => { + // classCounts: class --> # examples in class + val classCounts = origExamples.map(_.label).countByValue() + val sortedClasses = classCounts.keys.toList.sorted + val numClasses = classCounts.size + // classIndexMap: class --> index in 0,...,numClasses-1 + val classIndexMap = { + if (classCounts.keySet != Set(0.0, 1.0)) { + sortedClasses.zipWithIndex.toMap + } else { + Map[Double, Int]() + } + } + val examples = { + if (classIndexMap.isEmpty) { + origExamples + } else { + origExamples.map(lp => LabeledPoint(classIndexMap(lp.label), lp.features)) + } + } + val numExamples = examples.count() + println(s"numClasses = $numClasses.") + println(s"Per-class example fractions, counts:") + println(s"Class\tFrac\tCount") + sortedClasses.foreach { c => + val frac = classCounts(c) / numExamples.toDouble + println(s"$c\t$frac\t${classCounts(c)}") + } + (examples, numClasses) + } + case Regression => + (origExamples, 0) + case _ => + throw new IllegalArgumentException("Algo ${params.algo} not supported.") + } - val splits = examples.randomSplit(Array(0.8, 0.2)) + // Split into training, test. + val splits = examples.randomSplit(Array(1.0 - params.fracTest, params.fracTest)) val training = splits(0).cache() val test = splits(1).cache() - val numTraining = training.count() val numTest = test.count() @@ -129,17 +179,19 @@ object DecisionTreeRunner { impurity = impurityCalculator, maxDepth = params.maxDepth, maxBins = params.maxBins, - numClassesForClassification = params.numClassesForClassification) + numClassesForClassification = numClasses) val model = DecisionTree.train(training, strategy) + println(model) + if (params.algo == Classification) { val accuracy = accuracyScore(model, test) - println(s"Test accuracy = $accuracy.") + println(s"Test accuracy = $accuracy") } if (params.algo == Regression) { val mse = meanSquaredError(model, test) - println(s"Test mean squared error = $mse.") + println(s"Test mean squared error = $mse") } sc.stop() diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index ad32e3f4560fe..7d123dd6ae996 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -31,8 +31,8 @@ import org.apache.spark.util.random.XORShiftRandom /** * :: Experimental :: - * A class that implements a decision tree algorithm for classification and regression. It - * supports both continuous and categorical features. + * A class which implements a decision tree learning algorithm for classification and regression. + * It supports both continuous and categorical features. * @param strategy The configuration parameters for the tree algorithm which specify the type * of algorithm (classification, regression, etc.), feature type (continuous, * categorical), depth of the tree, quantile calculation strategy, etc. @@ -42,8 +42,8 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo /** * Method to train a decision tree model over an RDD - * @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data - * @return a DecisionTreeModel that can be used for prediction + * @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] + * @return DecisionTreeModel that can be used for prediction */ def train(input: RDD[LabeledPoint]): DecisionTreeModel = { @@ -60,7 +60,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo // depth of the decision tree val maxDepth = strategy.maxDepth // the max number of nodes possible given the depth of the tree - val maxNumNodes = math.pow(2, maxDepth).toInt - 1 + val maxNumNodes = math.pow(2, maxDepth + 1).toInt - 1 // Initialize an array to hold filters applied to points for each node. val filters = new Array[List[Filter]](maxNumNodes) // The filter at the top node is an empty list. @@ -100,7 +100,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo var level = 0 var break = false - while (level < maxDepth && !break) { + while (level <= maxDepth && !break) { logDebug("#####################################") logDebug("level = " + level) @@ -152,7 +152,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo val split = nodeSplitStats._1 val stats = nodeSplitStats._2 val nodeIndex = math.pow(2, level).toInt - 1 + index - val isLeaf = (stats.gain <= 0) || (level == strategy.maxDepth - 1) + val isLeaf = (stats.gain <= 0) || (level == strategy.maxDepth) val node = new Node(nodeIndex, stats.predict, isLeaf, Some(split), None, None, Some(stats)) logDebug("Node = " + node) nodes(nodeIndex) = node @@ -173,7 +173,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo while (i <= 1) { // Calculate the index of the node from the node level and the index at the current level. val nodeIndex = math.pow(2, level + 1).toInt - 1 + 2 * index + i - if (level < maxDepth - 1) { + if (level < maxDepth) { val impurity = if (i == 0) { nodeSplitStats._2.leftImpurity } else { @@ -197,17 +197,16 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo object DecisionTree extends Serializable with Logging { /** - * Method to train a decision tree model where the instances are represented as an RDD of - * (label, features) pairs. The method supports binary classification and regression. For the - * binary classification, the label for each instance should either be 0 or 1 to denote the two - * classes. The parameters for the algorithm are specified using the strategy parameter. + * Method to train a decision tree model. + * The method supports binary and multiclass classification and regression. * - * @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data - * for DecisionTree + * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. + * For classification, labels should take values {0, 1, ..., numClasses-1}. + * For regression, labels are real numbers. * @param strategy The configuration parameters for the tree algorithm which specify the type * of algorithm (classification, regression, etc.), feature type (continuous, * categorical), depth of the tree, quantile calculation strategy, etc. - * @return a DecisionTreeModel that can be used for prediction + * @return DecisionTreeModel that can be used for prediction */ def train(input: RDD[LabeledPoint], strategy: Strategy): DecisionTreeModel = { new DecisionTree(strategy).train(input) @@ -219,12 +218,14 @@ object DecisionTree extends Serializable with Logging { * binary classification, the label for each instance should either be 0 or 1 to denote the two * classes. * - * @param input input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as - * training data + * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. + * For classification, labels should take values {0, 1, ..., numClasses-1}. + * For regression, labels are real numbers. * @param algo algorithm, classification or regression * @param impurity impurity criterion used for information gain calculation - * @param maxDepth maxDepth maximum depth of the tree - * @return a DecisionTreeModel that can be used for prediction + * @param maxDepth Maximum depth of the tree. + * E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes. + * @return DecisionTreeModel that can be used for prediction */ def train( input: RDD[LabeledPoint], @@ -241,13 +242,15 @@ object DecisionTree extends Serializable with Logging { * binary classification, the label for each instance should either be 0 or 1 to denote the two * classes. * - * @param input input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as - * training data + * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. + * For classification, labels should take values {0, 1, ..., numClasses-1}. + * For regression, labels are real numbers. * @param algo algorithm, classification or regression * @param impurity impurity criterion used for information gain calculation - * @param maxDepth maxDepth maximum depth of the tree + * @param maxDepth Maximum depth of the tree. + * E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes. * @param numClassesForClassification number of classes for classification. Default value of 2. - * @return a DecisionTreeModel that can be used for prediction + * @return DecisionTreeModel that can be used for prediction */ def train( input: RDD[LabeledPoint], @@ -266,11 +269,13 @@ object DecisionTree extends Serializable with Logging { * 1 to denote the two classes. The method also supports categorical features inputs where the * number of categories can specified using the categoricalFeaturesInfo option. * - * @param input input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as - * training data for DecisionTree + * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. + * For classification, labels should take values {0, 1, ..., numClasses-1}. + * For regression, labels are real numbers. * @param algo classification or regression * @param impurity criterion used for information gain calculation - * @param maxDepth maximum depth of the tree + * @param maxDepth Maximum depth of the tree. + * E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes. * @param numClassesForClassification number of classes for classification. Default value of 2. * @param maxBins maximum number of bins used for splitting features * @param quantileCalculationStrategy algorithm for calculating quantiles @@ -279,7 +284,7 @@ object DecisionTree extends Serializable with Logging { * an entry (n -> k) implies the feature n is categorical with k * categories 0, 1, 2, ... , k-1. It's important to note that * features are zero-indexed. - * @return a DecisionTreeModel that can be used for prediction + * @return DecisionTreeModel that can be used for prediction */ def train( input: RDD[LabeledPoint], @@ -301,11 +306,10 @@ object DecisionTree extends Serializable with Logging { * Returns an array of optimal splits for all nodes at a given level. Splits the task into * multiple groups if the level-wise training task could lead to memory overflow. * - * @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data - * for DecisionTree + * @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] * @param parentImpurities Impurities for all parent nodes for the current level * @param strategy [[org.apache.spark.mllib.tree.configuration.Strategy]] instance containing - * parameters for construction the DecisionTree + * parameters for constructing the DecisionTree * @param level Level of the tree * @param filters Filters for all nodes at a given level * @param splits possible splits for all features @@ -348,11 +352,10 @@ object DecisionTree extends Serializable with Logging { /** * Returns an array of optimal splits for a group of nodes at a given level * - * @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data - * for DecisionTree + * @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] * @param parentImpurities Impurities for all parent nodes for the current level * @param strategy [[org.apache.spark.mllib.tree.configuration.Strategy]] instance containing - * parameters for construction the DecisionTree + * parameters for constructing the DecisionTree * @param level Level of the tree * @param filters Filters for all nodes at a given level * @param splits possible splits for all features @@ -373,7 +376,7 @@ object DecisionTree extends Serializable with Logging { groupIndex: Int = 0): Array[(Split, InformationGainStats)] = { /* - * The high-level description for the best split optimizations are noted here. + * The high-level descriptions of the best split optimizations are noted here. * * *Level-wise training* * We perform bin calculations for all nodes at the given level to avoid making multiple @@ -396,18 +399,27 @@ object DecisionTree extends Serializable with Logging { * drastically reduce the communication overhead. */ - // common calculations for multiple nested methods + // Common calculations for multiple nested methods: + + // numNodes: Number of nodes in this (level of tree, group), + // where nodes at deeper (larger) levels may be divided into groups. val numNodes = math.pow(2, level).toInt / numGroups logDebug("numNodes = " + numNodes) + // Find the number of features by looking at the first sample. val numFeatures = input.first().features.size logDebug("numFeatures = " + numFeatures) + + // numBins: Number of bins = 1 + number of possible splits val numBins = bins(0).length logDebug("numBins = " + numBins) + val numClasses = strategy.numClassesForClassification logDebug("numClasses = " + numClasses) + val isMulticlassClassification = strategy.isMulticlassClassification logDebug("isMulticlassClassification = " + isMulticlassClassification) + val isMulticlassClassificationWithCategoricalFeatures = strategy.isMulticlassWithCategoricalFeatures logDebug("isMultiClassWithCategoricalFeatures = " + @@ -465,10 +477,13 @@ object DecisionTree extends Serializable with Logging { } /** - * Find bin for one feature. + * Find bin for one (labeledPoint, feature). */ - def findBin(featureIndex: Int, labeledPoint: LabeledPoint, - isFeatureContinuous: Boolean, isSpaceSufficientForAllCategoricalSplits: Boolean): Int = { + def findBin( + featureIndex: Int, + labeledPoint: LabeledPoint, + isFeatureContinuous: Boolean, + isSpaceSufficientForAllCategoricalSplits: Boolean): Int = { val binForFeatures = bins(featureIndex) val feature = labeledPoint.features(featureIndex) @@ -535,7 +550,9 @@ object DecisionTree extends Serializable with Logging { } else { // Perform sequential search to find bin for categorical features. val binIndex = { - if (isMulticlassClassification && isSpaceSufficientForAllCategoricalSplits) { + val isUnorderedFeature = + isMulticlassClassification && isSpaceSufficientForAllCategoricalSplits + if (isUnorderedFeature) { sequentialBinSearchForUnorderedCategoricalFeatureInClassification() } else { sequentialBinSearchForOrderedCategoricalFeatureInClassification() @@ -555,6 +572,14 @@ object DecisionTree extends Serializable with Logging { * where b_ij is an integer between 0 and numBins - 1 for regressions and binary * classification and the categorical feature value in multiclass classification. * Invalid sample is denoted by noting bin for feature 1 as -1. + * + * For unordered features, the "bin index" returned is actually the feature value (category). + * + * @return Array of size 1 + numFeatures * numNodes, where + * arr(0) = label for labeledPoint, and + * arr(1 + numFeatures * nodeIndex + featureIndex) = + * bin index for this labeledPoint + * (or InvalidBinIndex if labeledPoint is not handled by this node) */ def findBinsForLevel(labeledPoint: LabeledPoint): Array[Double] = { // Calculate bin index and label per feature per node. @@ -598,9 +623,21 @@ object DecisionTree extends Serializable with Logging { // Find feature bins for all nodes at a level. val binMappedRDD = input.map(x => findBinsForLevel(x)) - def updateBinForOrderedFeature(arr: Array[Double], agg: Array[Double], nodeIndex: Int, - label: Double, featureIndex: Int) = { - + /** + * Increment aggregate in location for (node, feature, bin, label). + * + * @param arr Bin mapping from findBinsForLevel. arr(0) stores the class label. + * Array of size 1 + (numFeatures * numNodes). + * @param agg Array storing aggregate calculation, of size: + * numClasses * numBins * numFeatures * numNodes. + * Indexed by (node, feature, bin, label) where label is the least significant bit. + */ + def updateBinForOrderedFeature( + arr: Array[Double], + agg: Array[Double], + nodeIndex: Int, + label: Double, + featureIndex: Int): Unit = { // Find the bin index for this feature. val arrShift = 1 + numFeatures * nodeIndex val arrIndex = arrShift + featureIndex @@ -612,44 +649,58 @@ object DecisionTree extends Serializable with Logging { agg(aggIndex + labelInt) = agg(aggIndex + labelInt) + 1 } - def updateBinForUnorderedFeature(nodeIndex: Int, featureIndex: Int, arr: Array[Double], - label: Double, agg: Array[Double], rightChildShift: Int) = { + /** + * Increment aggregate in location for (nodeIndex, featureIndex, [bins], label), + * where [bins] ranges over all bins. + * Updates left or right side of aggregate depending on split. + * + * @param arr arr(0) = label. + * arr(1 + featureIndex + nodeIndex * numFeatures) = feature value (category) + * @param agg Indexed by (left/right, node, feature, bin, label) + * where label is the least significant bit. + * The left/right specifier is a 0/1 index indicating left/right child info. + * @param rightChildShift Offset for right side of agg. + */ + def updateBinForUnorderedFeature( + nodeIndex: Int, + featureIndex: Int, + arr: Array[Double], + label: Double, + agg: Array[Double], + rightChildShift: Int): Unit = { // Find the bin index for this feature. - val arrShift = 1 + numFeatures * nodeIndex - val arrIndex = arrShift + featureIndex + val arrIndex = 1 + numFeatures * nodeIndex + featureIndex + val featureValue = arr(arrIndex).toInt // Update the left or right count for one bin. - val aggShift = numClasses * numBins * numFeatures * nodeIndex - val aggIndex - = aggShift + numClasses * featureIndex * numBins + arr(arrIndex).toInt * numClasses + val aggShift = + numClasses * numBins * numFeatures * nodeIndex + + numClasses * numBins * featureIndex + + label.toInt // Find all matching bins and increment their values val featureCategories = strategy.categoricalFeaturesInfo(featureIndex) val numCategoricalBins = math.pow(2.0, featureCategories - 1).toInt - 1 var binIndex = 0 while (binIndex < numCategoricalBins) { - val labelInt = label.toInt - if (bins(featureIndex)(binIndex).highSplit.categories.contains(labelInt)) { - agg(aggIndex + binIndex) - = agg(aggIndex + binIndex) + 1 + val aggIndex = aggShift + binIndex * numClasses + if (bins(featureIndex)(binIndex).highSplit.categories.contains(featureValue)) { + agg(aggIndex) += 1 } else { - agg(rightChildShift + aggIndex + binIndex) - = agg(rightChildShift + aggIndex + binIndex) + 1 + agg(rightChildShift + aggIndex) += 1 } binIndex += 1 } } /** - * Performs a sequential aggregation over a partition for classification. For l nodes, - * k features, either the left count or the right count of one of the p bins is - * incremented based upon whether the feature is classified as 0 or 1. + * Helper for binSeqOp. * - * @param agg Array[Double] storing aggregate calculation of size - * numClasses * numSplits * numFeatures*numNodes for classification - * @param arr Array[Double] of size 1 + (numFeatures * numNodes) - * @return Array[Double] storing aggregate calculation of size - * 2 * numSplits * numFeatures * numNodes for classification + * @param arr Bin mapping from findBinsForLevel. arr(0) stores the class label. + * Array of size 1 + (numFeatures * numNodes). + * @param agg Array storing aggregate calculation, of size: + * numClasses * numBins * numFeatures * numNodes. + * Indexed by (node, feature, bin, label) where label is the least significant bit. */ - def orderedClassificationBinSeqOp(arr: Array[Double], agg: Array[Double]) = { + def binaryOrNotCategoricalBinSeqOp(arr: Array[Double], agg: Array[Double]): Unit = { // Iterate over all nodes. var nodeIndex = 0 while (nodeIndex < numNodes) { @@ -671,17 +722,21 @@ object DecisionTree extends Serializable with Logging { } /** - * Performs a sequential aggregation over a partition for classification. For l nodes, - * k features, either the left count or the right count of one of the p bins is - * incremented based upon whether the feature is classified as 0 or 1. + * Helper for binSeqOp. * - * @param agg Array[Double] storing aggregate calculation of size - * numClasses * numSplits * numFeatures*numNodes for classification - * @param arr Array[Double] of size 1 + (numFeatures * numNodes) - * @return Array[Double] storing aggregate calculation of size - * 2 * numClasses * numSplits * numFeatures * numNodes for classification + * @param arr Bin mapping from findBinsForLevel. arr(0) stores the class label. + * Array of size 1 + (numFeatures * numNodes). + * For ordered features, + * arr(1 + featureIndex + nodeIndex * numFeatures) = bin index. + * For unordered features, + * arr(1 + featureIndex + nodeIndex * numFeatures) = feature value (category). + * @param agg Array storing aggregate calculation. + * For ordered features, this is of size: + * numClasses * numBins * numFeatures * numNodes. + * For unordered features, this is of size: + * 2 * numClasses * numBins * numFeatures * numNodes. */ - def unorderedClassificationBinSeqOp(arr: Array[Double], agg: Array[Double]) = { + def multiclassWithCategoricalBinSeqOp(arr: Array[Double], agg: Array[Double]): Unit = { // Iterate over all nodes. var nodeIndex = 0 while (nodeIndex < numNodes) { @@ -717,16 +772,17 @@ object DecisionTree extends Serializable with Logging { } /** - * Performs a sequential aggregation over a partition for regression. For l nodes, k features, + * Performs a sequential aggregation over a partition for regression. + * For l nodes, k features, * the count, sum, sum of squares of one of the p bins is incremented. * - * @param agg Array[Double] storing aggregate calculation of size - * 3 * numSplits * numFeatures * numNodes for classification - * @param arr Array[Double] of size 1 + (numFeatures * numNodes) - * @return Array[Double] storing aggregate calculation of size - * 3 * numSplits * numFeatures * numNodes for regression + * @param agg Array storing aggregate calculation, updated by this function. + * Size: 3 * numBins * numFeatures * numNodes + * @param arr Bin mapping from findBinsForLevel. + * Array of size 1 + (numFeatures * numNodes). + * @return agg */ - def regressionBinSeqOp(arr: Array[Double], agg: Array[Double]) = { + def regressionBinSeqOp(arr: Array[Double], agg: Array[Double]): Unit = { // Iterate over all nodes. var nodeIndex = 0 while (nodeIndex < numNodes) { @@ -757,14 +813,30 @@ object DecisionTree extends Serializable with Logging { /** * Performs a sequential aggregation over a partition. + * For l nodes, k features, + * For classification: + * Either the left count or the right count of one of the bins is + * incremented based upon whether the feature is classified as 0 or 1. + * For regression: + * The count, sum, sum of squares of one of the bins is incremented. + * + * @param agg Array storing aggregate calculation, updated by this function. + * Size for classification: + * numClasses * numBins * numFeatures * numNodes for ordered features, or + * 2 * numClasses * numBins * numFeatures * numNodes for unordered features. + * Size for regression: + * 3 * numBins * numFeatures * numNodes. + * @param arr Bin mapping from findBinsForLevel. + * Array of size 1 + (numFeatures * numNodes). + * @return agg */ def binSeqOp(agg: Array[Double], arr: Array[Double]): Array[Double] = { strategy.algo match { case Classification => if(isMulticlassClassificationWithCategoricalFeatures) { - unorderedClassificationBinSeqOp(arr, agg) + multiclassWithCategoricalBinSeqOp(arr, agg) } else { - orderedClassificationBinSeqOp(arr, agg) + binaryOrNotCategoricalBinSeqOp(arr, agg) } case Regression => regressionBinSeqOp(arr, agg) } @@ -815,20 +887,10 @@ object DecisionTree extends Serializable with Logging { topImpurity: Double): InformationGainStats = { strategy.algo match { case Classification => - var classIndex = 0 - val leftCounts: Array[Double] = new Array[Double](numClasses) - val rightCounts: Array[Double] = new Array[Double](numClasses) - var leftTotalCount = 0.0 - var rightTotalCount = 0.0 - while (classIndex < numClasses) { - val leftClassCount = leftNodeAgg(featureIndex)(splitIndex)(classIndex) - val rightClassCount = rightNodeAgg(featureIndex)(splitIndex)(classIndex) - leftCounts(classIndex) = leftClassCount - leftTotalCount += leftClassCount - rightCounts(classIndex) = rightClassCount - rightTotalCount += rightClassCount - classIndex += 1 - } + val leftCounts: Array[Double] = leftNodeAgg(featureIndex)(splitIndex) + val rightCounts: Array[Double] = rightNodeAgg(featureIndex)(splitIndex) + val leftTotalCount = leftCounts.sum + val rightTotalCount = rightCounts.sum val impurity = { if (level > 0) { @@ -845,33 +907,17 @@ object DecisionTree extends Serializable with Logging { } } - if (leftTotalCount == 0) { - return new InformationGainStats(0, topImpurity, topImpurity, Double.MinValue, 1) - } - if (rightTotalCount == 0) { - return new InformationGainStats(0, topImpurity, Double.MinValue, topImpurity, 1) - } - - val leftImpurity = strategy.impurity.calculate(leftCounts, leftTotalCount) - val rightImpurity = strategy.impurity.calculate(rightCounts, rightTotalCount) - - val leftWeight = leftTotalCount / (leftTotalCount + rightTotalCount) - val rightWeight = rightTotalCount / (leftTotalCount + rightTotalCount) - - val gain = { - if (level > 0) { - impurity - leftWeight * leftImpurity - rightWeight * rightImpurity - } else { - impurity - leftWeight * leftImpurity - rightWeight * rightImpurity - } - } - val totalCount = leftTotalCount + rightTotalCount + if (totalCount == 0) { + // Return arbitrary prediction. + return new InformationGainStats(0, topImpurity, topImpurity, topImpurity, 0) + } // Sum of count for each label - val leftRightCounts: Array[Double] - = leftCounts.zip(rightCounts) - .map{case (leftCount, rightCount) => leftCount + rightCount} + val leftRightCounts: Array[Double] = + leftCounts.zip(rightCounts).map { case (leftCount, rightCount) => + leftCount + rightCount + } def indexOfLargestArrayElement(array: Array[Double]): Int = { val result = array.foldLeft(-1, Double.MinValue, 0) { @@ -885,6 +931,22 @@ object DecisionTree extends Serializable with Logging { val predict = indexOfLargestArrayElement(leftRightCounts) val prob = leftRightCounts(predict) / totalCount + val leftImpurity = if (leftTotalCount == 0) { + topImpurity + } else { + strategy.impurity.calculate(leftCounts, leftTotalCount) + } + val rightImpurity = if (rightTotalCount == 0) { + topImpurity + } else { + strategy.impurity.calculate(rightCounts, rightTotalCount) + } + + val leftWeight = leftTotalCount / totalCount + val rightWeight = rightTotalCount / totalCount + + val gain = impurity - leftWeight * leftImpurity - rightWeight * rightImpurity + new InformationGainStats(gain, impurity, leftImpurity, rightImpurity, predict, prob) case Regression => val leftCount = leftNodeAgg(featureIndex)(splitIndex)(0) @@ -937,10 +999,18 @@ object DecisionTree extends Serializable with Logging { /** * Extracts left and right split aggregates. - * @param binData Array[Double] of size 2*numFeatures*numSplits - * @return (leftNodeAgg, rightNodeAgg) tuple of type (Array[Array[Array[Double\]\]\], - * Array[Array[Array[Double\]\]\]) where each array is of size(numFeature, - * (numBins - 1), numClasses) + * @param binData Aggregate array slice from getBinDataForNode. + * For classification: + * For unordered features, this is leftChildData ++ rightChildData, + * each of which is indexed by (feature, split/bin, class), + * with class being the least significant bit. + * For ordered features, this is of size numClasses * numBins * numFeatures. + * For regression: + * This is of size 2 * numFeatures * numBins. + * @return (leftNodeAgg, rightNodeAgg) pair of arrays. + * For classification, each array is of size (numFeatures, (numBins - 1), numClasses). + * For regression, each array is of size (numFeatures, (numBins - 1), 3). + * */ def extractLeftRightNodeAggregates( binData: Array[Double]): (Array[Array[Array[Double]]], Array[Array[Array[Double]]]) = { @@ -983,6 +1053,11 @@ object DecisionTree extends Serializable with Logging { } } + /** + * Reshape binData for this feature. + * Indexes binData as (feature, split, class) with class as the least significant bit. + * @param leftNodeAgg leftNodeAgg(featureIndex)(splitIndex)(classIndex) = aggregate value + */ def findAggForUnorderedFeatureClassification( leftNodeAgg: Array[Array[Array[Double]]], rightNodeAgg: Array[Array[Array[Double]]], @@ -1107,7 +1182,7 @@ object DecisionTree extends Serializable with Logging { /** * Find the best split for a node. - * @param binData Array[Double] of size 2 * numSplits * numFeatures + * @param binData Bin data slice for this node, given by getBinDataForNode. * @param nodeImpurity impurity of the top node * @return tuple of split and information gain */ @@ -1133,7 +1208,7 @@ object DecisionTree extends Serializable with Logging { while (featureIndex < numFeatures) { // Iterate over all splits. var splitIndex = 0 - val maxSplitIndex : Double = { + val maxSplitIndex: Double = { val isFeatureContinuous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty if (isFeatureContinuous) { numBins - 1 @@ -1162,8 +1237,8 @@ object DecisionTree extends Serializable with Logging { (bestFeatureIndex, bestSplitIndex, bestGainStats) } + logDebug("best split = " + splits(bestFeatureIndex)(bestSplitIndex)) logDebug("best split bin = " + bins(bestFeatureIndex)(bestSplitIndex)) - logDebug("best split bin = " + splits(bestFeatureIndex)(bestSplitIndex)) (splits(bestFeatureIndex)(bestSplitIndex), gainStats) } @@ -1214,8 +1289,17 @@ object DecisionTree extends Serializable with Logging { bestSplits } - private def getElementsPerNode(numFeatures: Int, numBins: Int, numClasses: Int, - isMulticlassClassificationWithCategoricalFeatures: Boolean, algo: Algo): Int = { + /** + * Get the number of values to be stored per node in the bin aggregates. + * + * @param numBins Number of bins = 1 + number of possible splits. + */ + private def getElementsPerNode( + numFeatures: Int, + numBins: Int, + numClasses: Int, + isMulticlassClassificationWithCategoricalFeatures: Boolean, + algo: Algo): Int = { algo match { case Classification => if (isMulticlassClassificationWithCategoricalFeatures) { @@ -1228,18 +1312,40 @@ object DecisionTree extends Serializable with Logging { } /** - * Returns split and bins for decision tree calculation. - * @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data - * for DecisionTree + * Returns splits and bins for decision tree calculation. + * Continuous and categorical features are handled differently. + * + * Continuous features: + * For each feature, there are numBins - 1 possible splits representing the possible binary + * decisions at each node in the tree. + * + * Categorical features: + * For each feature, there is 1 bin per split. + * Splits and bins are handled in 2 ways: + * (a) For multiclass classification with a low-arity feature + * (i.e., if isMulticlass && isSpaceSufficientForAllCategoricalSplits), + * the feature is split based on subsets of categories. + * There are 2^(maxFeatureValue - 1) - 1 splits. + * (b) For regression and binary classification, + * and for multiclass classification with a high-arity feature, + * there is one split per category. + + * Categorical case (a) features are called unordered features. + * Other cases are called ordered features. + * + * @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] * @param strategy [[org.apache.spark.mllib.tree.configuration.Strategy]] instance containing - * parameters for construction the DecisionTree - * @return a tuple of (splits,bins) where splits is an Array of [org.apache.spark.mllib.tree - * .model.Split] of size (numFeatures, numSplits-1) and bins is an Array of [org.apache - * .spark.mllib.tree.model.Bin] of size (numFeatures, numSplits1) + * parameters for construction the DecisionTree + * @return A tuple of (splits,bins). + * Splits is an Array of [[org.apache.spark.mllib.tree.model.Split]] + * of size (numFeatures, numBins - 1). + * Bins is an Array of [[org.apache.spark.mllib.tree.model.Bin]] + * of size (numFeatures, numBins). */ protected[tree] def findSplitsBins( input: RDD[LabeledPoint], strategy: Strategy): (Array[Array[Split]], Array[Array[Bin]]) = { + val count = input.count() // Find the number of features by looking at the first sample @@ -1271,7 +1377,8 @@ object DecisionTree extends Serializable with Logging { logDebug("fraction of data used for calculating quantiles = " + fraction) // sampled input for RDD calculation - val sampledInput = input.sample(false, fraction, new XORShiftRandom().nextInt()).collect() + val sampledInput = + input.sample(withReplacement = false, fraction, new XORShiftRandom().nextInt()).collect() val numSamples = sampledInput.length val stride: Double = numSamples.toDouble / numBins @@ -1294,8 +1401,10 @@ object DecisionTree extends Serializable with Logging { val stride: Double = numSamples.toDouble / numBins logDebug("stride = " + stride) for (index <- 0 until numBins - 1) { - val sampleIndex = (index + 1) * stride.toInt - val split = new Split(featureIndex, featureSamples(sampleIndex), Continuous, List()) + val sampleIndex = index * stride.toInt + // Set threshold halfway in between 2 samples. + val threshold = (featureSamples(sampleIndex) + featureSamples(sampleIndex + 1)) / 2.0 + val split = new Split(featureIndex, threshold, Continuous, List()) splits(featureIndex)(index) = split } } else { // Categorical feature @@ -1304,8 +1413,10 @@ object DecisionTree extends Serializable with Logging { = numBins > math.pow(2, featureCategories.toInt - 1) - 1 // Use different bin/split calculation strategy for categorical features in multiclass - // classification that satisfy the space constraint - if (isMulticlassClassification && isSpaceSufficientForAllCategoricalSplits) { + // classification that satisfy the space constraint. + val isUnorderedFeature = + isMulticlassClassification && isSpaceSufficientForAllCategoricalSplits + if (isUnorderedFeature) { // 2^(maxFeatureValue- 1) - 1 combinations var index = 0 while (index < math.pow(2.0, featureCategories - 1).toInt - 1) { @@ -1330,8 +1441,13 @@ object DecisionTree extends Serializable with Logging { } index += 1 } - } else { - + } else { // ordered feature + /* For a given categorical feature, use a subsample of the data + * to choose how to arrange possible splits. + * This examines each category and computes a centroid. + * These centroids are later used to sort the possible splits. + * centroidForCategories is a mapping: category (for the given feature) --> centroid + */ val centroidForCategories = { if (isMulticlassClassification) { // For categorical variables in multiclass classification, @@ -1341,7 +1457,7 @@ object DecisionTree extends Serializable with Logging { .groupBy(_._1) .mapValues(x => x.groupBy(_._2).mapValues(x => x.size.toDouble)) .map(x => (x._1, x._2.values.toArray)) - .map(x => (x._1, strategy.impurity.calculate(x._2,x._2.sum))) + .map(x => (x._1, strategy.impurity.calculate(x._2, x._2.sum))) } else { // regression or binary classification // For categorical variables in regression and binary classification, // each bin is a category. The bins are sorted and they @@ -1352,7 +1468,7 @@ object DecisionTree extends Serializable with Logging { } } - logDebug("centriod for categories = " + centroidForCategories.mkString(",")) + logDebug("centroid for categories = " + centroidForCategories.mkString(",")) // Check for missing categorical variables and putting them last in the sorted list. val fullCentroidForCategories = scala.collection.mutable.Map[Double,Double]() @@ -1367,7 +1483,7 @@ object DecisionTree extends Serializable with Logging { // bins sorted by centroids val categoriesSortedByCentroid = fullCentroidForCategories.toList.sortBy(_._2) - logDebug("centriod for categorical variable = " + categoriesSortedByCentroid) + logDebug("centroid for categorical variable = " + categoriesSortedByCentroid) var categoriesForSplit = List[Double]() categoriesSortedByCentroid.iterator.zipWithIndex.foreach { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala index 7c027ac2fda6b..5c65b537b6867 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala @@ -27,7 +27,8 @@ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._ * Stores all the configuration options for tree construction * @param algo classification or regression * @param impurity criterion used for information gain calculation - * @param maxDepth maximum depth of the tree + * @param maxDepth Maximum depth of the tree. + * E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes. * @param numClassesForClassification number of classes for classification. Default value is 2 * leads to binary classification * @param maxBins maximum number of bins used for splitting features @@ -52,7 +53,9 @@ class Strategy ( val categoricalFeaturesInfo: Map[Int, Int] = Map[Int, Int](), val maxMemoryInMB: Int = 128) extends Serializable { - require(numClassesForClassification >= 2) + if (algo == Classification) { + require(numClassesForClassification >= 2) + } val isMulticlassClassification = numClassesForClassification > 2 val isMulticlassWithCategoricalFeatures = isMulticlassClassification && (categoricalFeaturesInfo.size > 0) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala index a0e2d91762782..9297c20596527 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala @@ -34,10 +34,13 @@ object Entropy extends Impurity { * information calculation for multiclass classification * @param counts Array[Double] with counts for each label * @param totalCount sum of counts for all labels - * @return information value + * @return information value, or 0 if totalCount = 0 */ @DeveloperApi override def calculate(counts: Array[Double], totalCount: Double): Double = { + if (totalCount == 0) { + return 0 + } val numClasses = counts.length var impurity = 0.0 var classIndex = 0 @@ -58,6 +61,7 @@ object Entropy extends Impurity { * @param count number of instances * @param sum sum of labels * @param sumSquares summation of squares of the labels + * @return information value, or 0 if count = 0 */ @DeveloperApi override def calculate(count: Double, sum: Double, sumSquares: Double): Double = diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala index 48144b5e6d1e4..2874bcf496484 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala @@ -33,10 +33,13 @@ object Gini extends Impurity { * information calculation for multiclass classification * @param counts Array[Double] with counts for each label * @param totalCount sum of counts for all labels - * @return information value + * @return information value, or 0 if totalCount = 0 */ @DeveloperApi override def calculate(counts: Array[Double], totalCount: Double): Double = { + if (totalCount == 0) { + return 0 + } val numClasses = counts.length var impurity = 1.0 var classIndex = 0 @@ -54,6 +57,7 @@ object Gini extends Impurity { * @param count number of instances * @param sum sum of labels * @param sumSquares summation of squares of the labels + * @return information value, or 0 if count = 0 */ @DeveloperApi override def calculate(count: Double, sum: Double, sumSquares: Double): Double = diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala index 7b2a9320cc21d..92b0c7b4a6fbc 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala @@ -31,7 +31,7 @@ trait Impurity extends Serializable { * information calculation for multiclass classification * @param counts Array[Double] with counts for each label * @param totalCount sum of counts for all labels - * @return information value + * @return information value, or 0 if totalCount = 0 */ @DeveloperApi def calculate(counts: Array[Double], totalCount: Double): Double @@ -42,7 +42,7 @@ trait Impurity extends Serializable { * @param count number of instances * @param sum sum of labels * @param sumSquares summation of squares of the labels - * @return information value + * @return information value, or 0 if count = 0 */ @DeveloperApi def calculate(count: Double, sum: Double, sumSquares: Double): Double diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala index 97149a99ead59..698a1a2a8e899 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala @@ -31,7 +31,7 @@ object Variance extends Impurity { * information calculation for multiclass classification * @param counts Array[Double] with counts for each label * @param totalCount sum of counts for all labels - * @return information value + * @return information value, or 0 if totalCount = 0 */ @DeveloperApi override def calculate(counts: Array[Double], totalCount: Double): Double = @@ -43,9 +43,13 @@ object Variance extends Impurity { * @param count number of instances * @param sum sum of labels * @param sumSquares summation of squares of the labels + * @return information value, or 0 if count = 0 */ @DeveloperApi override def calculate(count: Double, sum: Double, sumSquares: Double): Double = { + if (count == 0) { + return 0 + } val squaredLoss = sumSquares - (sum * sum) / count squaredLoss / count } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala index bf692ca8c4bd7..3d3406b5d5f22 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala @@ -24,7 +24,8 @@ import org.apache.spark.mllib.linalg.Vector /** * :: Experimental :: - * Model to store the decision tree parameters + * Decision tree model for classification or regression. + * This model stores the decision tree structure and parameters. * @param topNode root node * @param algo algorithm type -- classification or regression */ @@ -50,4 +51,32 @@ class DecisionTreeModel(val topNode: Node, val algo: Algo) extends Serializable def predict(features: RDD[Vector]): RDD[Double] = { features.map(x => predict(x)) } + + /** + * Get number of nodes in tree, including leaf nodes. + */ + def numNodes: Int = { + 1 + topNode.numDescendants + } + + /** + * Get depth of tree. + * E.g.: Depth 0 means 1 leaf node. Depth 1 means 1 internal node and 2 leaf nodes. + */ + def depth: Int = { + topNode.subtreeDepth + } + + /** + * Print full model. + */ + override def toString: String = algo match { + case Classification => + s"DecisionTreeModel classifier\n" + topNode.subtreeToString(2) + case Regression => + s"DecisionTreeModel regressor\n" + topNode.subtreeToString(2) + case _ => throw new IllegalArgumentException( + s"DecisionTreeModel given unknown algo parameter: $algo.") + } + } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala index 682f213f411a7..944f11c2c2e4f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala @@ -91,4 +91,60 @@ class Node ( } } } + + /** + * Get the number of nodes in tree below this node, including leaf nodes. + * E.g., if this is a leaf, returns 0. If both children are leaves, returns 2. + */ + private[tree] def numDescendants: Int = { + if (isLeaf) { + 0 + } else { + 2 + leftNode.get.numDescendants + rightNode.get.numDescendants + } + } + + /** + * Get depth of tree from this node. + * E.g.: Depth 0 means this is a leaf node. + */ + private[tree] def subtreeDepth: Int = { + if (isLeaf) { + 0 + } else { + 1 + math.max(leftNode.get.subtreeDepth, rightNode.get.subtreeDepth) + } + } + + /** + * Recursive print function. + * @param indentFactor The number of spaces to add to each level of indentation. + */ + private[tree] def subtreeToString(indentFactor: Int = 0): String = { + + def splitToString(split: Split, left: Boolean): String = { + split.featureType match { + case Continuous => if (left) { + s"(feature ${split.feature} <= ${split.threshold})" + } else { + s"(feature ${split.feature} > ${split.threshold})" + } + case Categorical => if (left) { + s"(feature ${split.feature} in ${split.categories.mkString("{",",","}")})" + } else { + s"(feature ${split.feature} not in ${split.categories.mkString("{",",","}")})" + } + } + } + val prefix: String = " " * indentFactor + if (isLeaf) { + prefix + s"Predict: $predict\n" + } else { + prefix + s"If ${splitToString(split.get, left=true)}\n" + + leftNode.get.subtreeToString(indentFactor + 1) + + prefix + s"Else ${splitToString(split.get, left=false)}\n" + + rightNode.get.subtreeToString(indentFactor + 1) + } + } + } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala index 5961a618c59d9..10462db700628 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala @@ -20,8 +20,7 @@ package org.apache.spark.mllib.tree import org.scalatest.FunSuite import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Variance} -import org.apache.spark.mllib.tree.model.Filter -import org.apache.spark.mllib.tree.model.Split +import org.apache.spark.mllib.tree.model.{DecisionTreeModel, Filter, Split} import org.apache.spark.mllib.tree.configuration.{FeatureType, Strategy} import org.apache.spark.mllib.tree.configuration.Algo._ import org.apache.spark.mllib.tree.configuration.FeatureType._ @@ -31,6 +30,18 @@ import org.apache.spark.mllib.regression.LabeledPoint class DecisionTreeSuite extends FunSuite with LocalSparkContext { + def validateClassifier( + model: DecisionTreeModel, + input: Seq[LabeledPoint], + requiredAccuracy: Double) { + val predictions = input.map(x => model.predict(x.features)) + val numOffPredictions = predictions.zip(input).count { case (prediction, expected) => + prediction != expected.label + } + val accuracy = (input.length - numOffPredictions).toDouble / input.length + assert(accuracy >= requiredAccuracy) + } + test("split and bin calculation") { val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel1() assert(arr.length === 1000) @@ -50,7 +61,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val strategy = new Strategy( Classification, Gini, - maxDepth = 3, + maxDepth = 2, numClassesForClassification = 2, maxBins = 100, categoricalFeaturesInfo = Map(0 -> 2, 1-> 2)) @@ -130,7 +141,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val strategy = new Strategy( Classification, Gini, - maxDepth = 3, + maxDepth = 2, numClassesForClassification = 2, maxBins = 100, categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3)) @@ -236,7 +247,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { test("extract categories from a number for multiclass classification") { val l = DecisionTree.extractMultiClassCategories(13, 10) assert(l.length === 3) - assert(List(3.0, 2.0, 0.0).toSeq == l.toSeq) + assert(List(3.0, 2.0, 0.0).toSeq === l.toSeq) } test("split and bin calculations for unordered categorical variables with multiclass " + @@ -247,7 +258,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val strategy = new Strategy( Classification, Gini, - maxDepth = 3, + maxDepth = 2, numClassesForClassification = 100, maxBins = 100, categoricalFeaturesInfo = Map(0 -> 3, 1-> 3)) @@ -341,7 +352,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val strategy = new Strategy( Classification, Gini, - maxDepth = 3, + maxDepth = 2, numClassesForClassification = 100, maxBins = 100, categoricalFeaturesInfo = Map(0 -> 10, 1-> 10)) @@ -397,7 +408,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { Classification, Gini, numClassesForClassification = 2, - maxDepth = 3, + maxDepth = 2, maxBins = 100, categoricalFeaturesInfo = Map(0 -> 3, 1-> 3)) val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) @@ -413,7 +424,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val stats = bestSplits(0)._2 assert(stats.gain > 0) assert(stats.predict === 1) - assert(stats.prob == 0.6) + assert(stats.prob === 0.6) assert(stats.impurity > 0.2) } @@ -424,7 +435,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val strategy = new Strategy( Regression, Variance, - maxDepth = 3, + maxDepth = 2, maxBins = 100, categoricalFeaturesInfo = Map(0 -> 3, 1-> 3)) val (splits, bins) = DecisionTree.findSplitsBins(rdd,strategy) @@ -439,7 +450,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val stats = bestSplits(0)._2 assert(stats.gain > 0) - assert(stats.predict == 0.6) + assert(stats.predict === 0.6) assert(stats.impurity > 0.2) } @@ -460,7 +471,6 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { Array[List[Filter]](), splits, bins, 10) assert(bestSplits.length === 1) assert(bestSplits(0)._1.feature === 0) - assert(bestSplits(0)._1.threshold === 10) assert(bestSplits(0)._2.gain === 0) assert(bestSplits(0)._2.leftImpurity === 0) assert(bestSplits(0)._2.rightImpurity === 0) @@ -483,7 +493,6 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { Array[List[Filter]](), splits, bins, 10) assert(bestSplits.length === 1) assert(bestSplits(0)._1.feature === 0) - assert(bestSplits(0)._1.threshold === 10) assert(bestSplits(0)._2.gain === 0) assert(bestSplits(0)._2.leftImpurity === 0) assert(bestSplits(0)._2.rightImpurity === 0) @@ -507,7 +516,6 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { Array[List[Filter]](), splits, bins, 10) assert(bestSplits.length === 1) assert(bestSplits(0)._1.feature === 0) - assert(bestSplits(0)._1.threshold === 10) assert(bestSplits(0)._2.gain === 0) assert(bestSplits(0)._2.leftImpurity === 0) assert(bestSplits(0)._2.rightImpurity === 0) @@ -531,7 +539,6 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { Array[List[Filter]](), splits, bins, 10) assert(bestSplits.length === 1) assert(bestSplits(0)._1.feature === 0) - assert(bestSplits(0)._1.threshold === 10) assert(bestSplits(0)._2.gain === 0) assert(bestSplits(0)._2.leftImpurity === 0) assert(bestSplits(0)._2.rightImpurity === 0) @@ -587,7 +594,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { test("stump with categorical variables for multiclass classification") { val arr = DecisionTreeSuite.generateCategoricalDataPointsForMulticlass() val input = sc.parallelize(arr) - val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 5, + val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4, numClassesForClassification = 3, categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3)) assert(strategy.isMulticlassClassification) val (splits, bins) = DecisionTree.findSplitsBins(input, strategy) @@ -602,12 +609,78 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(bestSplit.featureType === Categorical) } + test("stump with 1 continuous variable for binary classification, to check off-by-1 error") { + val arr = new Array[LabeledPoint](4) + arr(0) = new LabeledPoint(0.0, Vectors.dense(0.0)) + arr(1) = new LabeledPoint(1.0, Vectors.dense(1.0)) + arr(2) = new LabeledPoint(1.0, Vectors.dense(2.0)) + arr(3) = new LabeledPoint(1.0, Vectors.dense(3.0)) + val input = sc.parallelize(arr) + val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4, + numClassesForClassification = 2) + + val model = DecisionTree.train(input, strategy) + validateClassifier(model, arr, 1.0) + assert(model.numNodes === 3) + assert(model.depth === 1) + } + + test("stump with 2 continuous variables for binary classification") { + val arr = new Array[LabeledPoint](4) + arr(0) = new LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 0.0)))) + arr(1) = new LabeledPoint(1.0, Vectors.sparse(2, Seq((1, 1.0)))) + arr(2) = new LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 0.0)))) + arr(3) = new LabeledPoint(1.0, Vectors.sparse(2, Seq((1, 2.0)))) + + val input = sc.parallelize(arr) + val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4, + numClassesForClassification = 2) + + val model = DecisionTree.train(input, strategy) + validateClassifier(model, arr, 1.0) + assert(model.numNodes === 3) + assert(model.depth === 1) + assert(model.topNode.split.get.feature === 1) + } + + test("stump with categorical variables for multiclass classification, with just enough bins") { + val maxBins = math.pow(2, 3 - 1).toInt // just enough bins to allow unordered features + val arr = DecisionTreeSuite.generateCategoricalDataPointsForMulticlass() + val input = sc.parallelize(arr) + val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4, + numClassesForClassification = 3, categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3)) + assert(strategy.isMulticlassClassification) + + val model = DecisionTree.train(input, strategy) + validateClassifier(model, arr, 1.0) + assert(model.numNodes === 3) + assert(model.depth === 1) + + val (splits, bins) = DecisionTree.findSplitsBins(input, strategy) + val bestSplits = DecisionTree.findBestSplits(input, new Array(31), strategy, 0, + Array[List[Filter]](), splits, bins, 10) + + assert(bestSplits.length === 1) + val bestSplit = bestSplits(0)._1 + assert(bestSplit.feature === 0) + assert(bestSplit.categories.length === 1) + assert(bestSplit.categories.contains(1)) + assert(bestSplit.featureType === Categorical) + val gain = bestSplits(0)._2 + assert(gain.leftImpurity === 0) + assert(gain.rightImpurity === 0) + } + test("stump with continuous variables for multiclass classification") { val arr = DecisionTreeSuite.generateContinuousDataPointsForMulticlass() val input = sc.parallelize(arr) - val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 5, + val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4, numClassesForClassification = 3) assert(strategy.isMulticlassClassification) + + val model = DecisionTree.train(input, strategy) + validateClassifier(model, arr, 0.9) + val (splits, bins) = DecisionTree.findSplitsBins(input, strategy) val bestSplits = DecisionTree.findBestSplits(input, new Array(31), strategy, 0, Array[List[Filter]](), splits, bins, 10) @@ -625,9 +698,13 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { test("stump with continuous + categorical variables for multiclass classification") { val arr = DecisionTreeSuite.generateContinuousDataPointsForMulticlass() val input = sc.parallelize(arr) - val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 5, + val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4, numClassesForClassification = 3, categoricalFeaturesInfo = Map(0 -> 3)) assert(strategy.isMulticlassClassification) + + val model = DecisionTree.train(input, strategy) + validateClassifier(model, arr, 0.9) + val (splits, bins) = DecisionTree.findSplitsBins(input, strategy) val bestSplits = DecisionTree.findBestSplits(input, new Array(31), strategy, 0, Array[List[Filter]](), splits, bins, 10) @@ -644,7 +721,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { test("stump with categorical variables for ordered multiclass classification") { val arr = DecisionTreeSuite.generateCategoricalDataPointsForMulticlassForOrderedFeatures() val input = sc.parallelize(arr) - val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 5, + val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4, numClassesForClassification = 3, categoricalFeaturesInfo = Map(0 -> 10, 1 -> 10)) assert(strategy.isMulticlassClassification) val (splits, bins) = DecisionTree.findSplitsBins(input, strategy) From 9632719c9ef16ad95af4f3b85ae72d54b02b0f90 Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Thu, 31 Jul 2014 21:02:11 -0700 Subject: [PATCH 070/170] [SPARK-2779] [SQL] asInstanceOf[Map[...]] should use scala.collection.Map instead of scala.collection.immutable.Map Since we let users create Rows. It makes sense to accept mutable Maps as values of MapType columns. JIRA: https://issues.apache.org/jira/browse/SPARK-2779 Author: Yin Huai Closes #1705 from yhuai/SPARK-2779 and squashes the following commits: 00d72fd [Yin Huai] Use scala.collection.Map. --- .../catalyst/expressions/complexTypes.scala | 2 ++ .../sql/catalyst/expressions/generators.scala | 2 ++ .../org/apache/spark/sql/json/JsonRDD.scala | 1 + .../org/apache/spark/sql/SQLQuerySuite.scala | 19 +++++++++++++++++++ 4 files changed, 24 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala index 72add5e20e8b4..c1154eb81c319 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.expressions +import scala.collection.Map + import org.apache.spark.sql.catalyst.types._ /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala index 422839dab770d..3d41acb79e5fd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.expressions +import scala.collection.Map + import org.apache.spark.sql.catalyst.trees import org.apache.spark.sql.catalyst.types._ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala index bd29ee421bbc4..70db1ebd3a3e1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.json +import scala.collection.Map import scala.collection.convert.Wrappers.{JMapWrapper, JListWrapper} import scala.math.BigDecimal diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index bebb490645420..5c571d35d1bb9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -505,5 +505,24 @@ class SQLQuerySuite extends QueryTest { (2, null) :: (3, null) :: (4, 2147483644) :: Nil) + + // The value of a MapType column can be a mutable map. + val rowRDD3 = unparsedStrings.map { r => + val values = r.split(",").map(_.trim) + val v4 = try values(3).toInt catch { + case _: NumberFormatException => null + } + Row(Row(values(0).toInt, values(2).toBoolean), scala.collection.mutable.Map(values(1) -> v4)) + } + + val schemaRDD3 = applySchema(rowRDD3, schema2) + schemaRDD3.registerAsTable("applySchema3") + + checkAnswer( + sql("SELECT f1.f11, f2['D4'] FROM applySchema3"), + (1, null) :: + (2, null) :: + (3, null) :: + (4, 2147483644) :: Nil) } } From 9998efab96a4fdc927818eaae53c04f946c4cf13 Mon Sep 17 00:00:00 2001 From: GuoQiang Li Date: Thu, 31 Jul 2014 21:06:57 -0700 Subject: [PATCH 071/170] SPARK-2766: ScalaReflectionSuite throw an llegalArgumentException in JDK 6 Author: GuoQiang Li Closes #1683 from witgo/SPARK-2766 and squashes the following commits: d0db00c [GuoQiang Li] ScalaReflectionSuite throw an llegalArgumentException in JDK 6 --- .../org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala index e030d6e13d472..e75373d5a74a7 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala @@ -182,7 +182,7 @@ class ScalaReflectionSuite extends FunSuite { assert(DecimalType === typeOfObject(BigDecimal("1.7976931348623157E318"))) // TimestampType - assert(TimestampType === typeOfObject(java.sql.Timestamp.valueOf("2014-7-25 10:26:00"))) + assert(TimestampType === typeOfObject(java.sql.Timestamp.valueOf("2014-07-25 10:26:00"))) // NullType assert(NullType === typeOfObject(null)) From b19008320bdf7064e764db04c43ef003a3ce0ecd Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Thu, 31 Jul 2014 21:14:08 -0700 Subject: [PATCH 072/170] [SPARK-2777][MLLIB] change ALS factors storage level to MEMORY_AND_DISK Now the factors are persisted in memory only. If they get kicked off by later jobs, we might have to start the computation from very beginning. A better solution is changing the storage level to `MEMORY_AND_DISK`. srowen Author: Xiangrui Meng Closes #1700 from mengxr/als-level and squashes the following commits: c103d76 [Xiangrui Meng] change ALS factors storage level to MEMORY_AND_DISK --- .../scala/org/apache/spark/mllib/recommendation/ALS.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala index d208cfb917f3d..36d262fed425a 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala @@ -290,8 +290,8 @@ class ALS private ( val usersOut = unblockFactors(users, userOutLinks) val productsOut = unblockFactors(products, productOutLinks) - usersOut.setName("usersOut").persist() - productsOut.setName("productsOut").persist() + usersOut.setName("usersOut").persist(StorageLevel.MEMORY_AND_DISK) + productsOut.setName("productsOut").persist(StorageLevel.MEMORY_AND_DISK) // Materialize usersOut and productsOut. usersOut.count() From c4755403e7d670176d81211813b6515dec76bee2 Mon Sep 17 00:00:00 2001 From: Doris Xin Date: Thu, 31 Jul 2014 21:23:35 -0700 Subject: [PATCH 073/170] [SPARK-2782][mllib] Bug fix for getRanks in SpearmanCorrelation getRanks computes the wrong rank when numPartition >= size in the input RDDs before this patch. added units to address this bug. Author: Doris Xin Closes #1710 from dorx/correlationBug and squashes the following commits: 733def4 [Doris Xin] bugs and reviewer comments. 31db920 [Doris Xin] revert unnecessary change 043ff83 [Doris Xin] bug fix for spearman corner case --- .../apache/spark/mllib/stat/Statistics.scala | 22 ++++++++++------ .../correlation/SpearmanCorrelation.scala | 18 ++++++------- .../spark/mllib/stat/CorrelationSuite.scala | 25 +++++++++++++++++++ 3 files changed, 47 insertions(+), 18 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala index 68f3867ba6c11..9d6de9b6e1f60 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala @@ -30,7 +30,7 @@ object Statistics { /** * Compute the Pearson correlation matrix for the input RDD of Vectors. - * Returns NaN if either vector has 0 variance. + * Columns with 0 covariance produce NaN entries in the correlation matrix. * * @param X an RDD[Vector] for which the correlation matrix is to be computed. * @return Pearson correlation matrix comparing columns in X. @@ -39,7 +39,7 @@ object Statistics { /** * Compute the correlation matrix for the input RDD of Vectors using the specified method. - * Methods currently supported: `pearson` (default), `spearman` + * Methods currently supported: `pearson` (default), `spearman`. * * Note that for Spearman, a rank correlation, we need to create an RDD[Double] for each column * and sort it in order to retrieve the ranks and then join the columns back into an RDD[Vector], @@ -55,20 +55,26 @@ object Statistics { /** * Compute the Pearson correlation for the input RDDs. - * Columns with 0 covariance produce NaN entries in the correlation matrix. + * Returns NaN if either vector has 0 variance. + * + * Note: the two input RDDs need to have the same number of partitions and the same number of + * elements in each partition. * - * @param x RDD[Double] of the same cardinality as y - * @param y RDD[Double] of the same cardinality as x + * @param x RDD[Double] of the same cardinality as y. + * @param y RDD[Double] of the same cardinality as x. * @return A Double containing the Pearson correlation between the two input RDD[Double]s */ def corr(x: RDD[Double], y: RDD[Double]): Double = Correlations.corr(x, y) /** * Compute the correlation for the input RDDs using the specified method. - * Methods currently supported: pearson (default), spearman + * Methods currently supported: `pearson` (default), `spearman`. + * + * Note: the two input RDDs need to have the same number of partitions and the same number of + * elements in each partition. * - * @param x RDD[Double] of the same cardinality as y - * @param y RDD[Double] of the same cardinality as x + * @param x RDD[Double] of the same cardinality as y. + * @param y RDD[Double] of the same cardinality as x. * @param method String specifying the method to use for computing correlation. * Supported: `pearson` (default), `spearman` *@return A Double containing the correlation between the two input RDD[Double]s using the diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/correlation/SpearmanCorrelation.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/correlation/SpearmanCorrelation.scala index 1f7de630e778c..9bd0c2cd05de4 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/correlation/SpearmanCorrelation.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/correlation/SpearmanCorrelation.scala @@ -89,20 +89,18 @@ private[stat] object SpearmanCorrelation extends Correlation with Logging { val ranks: RDD[(Long, Double)] = sorted.mapPartitions { iter => // add an extra element to signify the end of the list so that flatMap can flush the last // batch of duplicates - val padded = iter ++ - Iterator[((Double, Long), Long)](((Double.NaN, -1L), -1L)) - var lastVal = 0.0 - var firstRank = 0.0 - val idBuffer = new ArrayBuffer[Long]() + val end = -1L + val padded = iter ++ Iterator[((Double, Long), Long)](((Double.NaN, end), end)) + val firstEntry = padded.next() + var lastVal = firstEntry._1._1 + var firstRank = firstEntry._2.toDouble + val idBuffer = ArrayBuffer(firstEntry._1._2) padded.flatMap { case ((v, id), rank) => - if (v == lastVal && id != Long.MinValue) { + if (v == lastVal && id != end) { idBuffer += id Iterator.empty } else { - val entries = if (idBuffer.size == 0) { - // edge case for the first value matching the initial value of lastVal - Iterator.empty - } else if (idBuffer.size == 1) { + val entries = if (idBuffer.size == 1) { Iterator((idBuffer(0), firstRank)) } else { val averageRank = firstRank + (idBuffer.size - 1.0) / 2.0 diff --git a/mllib/src/test/scala/org/apache/spark/mllib/stat/CorrelationSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/stat/CorrelationSuite.scala index bce4251426df7..a3f76f77a5dcc 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/stat/CorrelationSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/stat/CorrelationSuite.scala @@ -31,6 +31,7 @@ class CorrelationSuite extends FunSuite with LocalSparkContext { // test input data val xData = Array(1.0, 0.0, -2.0) val yData = Array(4.0, 5.0, 3.0) + val zeros = new Array[Double](3) val data = Seq( Vectors.dense(1.0, 0.0, 0.0, -2.0), Vectors.dense(4.0, 5.0, 0.0, 3.0), @@ -46,6 +47,18 @@ class CorrelationSuite extends FunSuite with LocalSparkContext { val p1 = Statistics.corr(x, y, "pearson") assert(approxEqual(expected, default)) assert(approxEqual(expected, p1)) + + // numPartitions >= size for input RDDs + for (numParts <- List(xData.size, xData.size * 2)) { + val x1 = sc.parallelize(xData, numParts) + val y1 = sc.parallelize(yData, numParts) + val p2 = Statistics.corr(x1, y1) + assert(approxEqual(expected, p2)) + } + + // RDD of zero variance + val z = sc.parallelize(zeros) + assert(Statistics.corr(x, z).isNaN()) } test("corr(x, y) spearman") { @@ -54,6 +67,18 @@ class CorrelationSuite extends FunSuite with LocalSparkContext { val expected = 0.5 val s1 = Statistics.corr(x, y, "spearman") assert(approxEqual(expected, s1)) + + // numPartitions >= size for input RDDs + for (numParts <- List(xData.size, xData.size * 2)) { + val x1 = sc.parallelize(xData, numParts) + val y1 = sc.parallelize(yData, numParts) + val s2 = Statistics.corr(x1, y1, "spearman") + assert(approxEqual(expected, s2)) + } + + // RDD of zero variance => zero variance in ranks + val z = sc.parallelize(zeros) + assert(Statistics.corr(x, z, "spearman").isNaN()) } test("corr(X) default, pearson") { From 2cdc3e5c6f5601086590a0cebf40a48f7560d02e Mon Sep 17 00:00:00 2001 From: Haoyuan Li Date: Thu, 31 Jul 2014 22:53:42 -0700 Subject: [PATCH 074/170] [SPARK-2702][Core] Upgrade Tachyon dependency to 0.5.0 Author: Haoyuan Li Closes #1651 from haoyuan/upgrade-tachyon and squashes the following commits: 6f3f98f [Haoyuan Li] upgrade tachyon to 0.5.0 --- core/pom.xml | 4 ++-- make-distribution.sh | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/core/pom.xml b/core/pom.xml index 04d4b9cc1068e..7c60cf10c3dc2 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -192,8 +192,8 @@ org.tachyonproject - tachyon - 0.4.1-thrift + tachyon-client + 0.5.0 org.apache.hadoop diff --git a/make-distribution.sh b/make-distribution.sh index 0a3283ecec6f8..1441497b3995a 100755 --- a/make-distribution.sh +++ b/make-distribution.sh @@ -128,7 +128,7 @@ if [[ ! "$JAVA_VERSION" =~ "1.6" && -z "$SKIP_JAVA_TEST" ]]; then if [[ ! $REPLY =~ ^[Yy]$ ]]; then echo "Okay, exiting." exit 1 - fi + fi fi if [ "$NAME" == "none" ]; then @@ -173,7 +173,7 @@ cp $FWDIR/examples/target/scala*/spark-examples*.jar "$DISTDIR/lib/" # Copy example sources (needed for python and SQL) mkdir -p "$DISTDIR/examples/src/main" -cp -r $FWDIR/examples/src/main "$DISTDIR/examples/src/" +cp -r $FWDIR/examples/src/main "$DISTDIR/examples/src/" if [ "$SPARK_HIVE" == "true" ]; then cp $FWDIR/lib_managed/jars/datanucleus*.jar "$DISTDIR/lib/" @@ -199,7 +199,7 @@ cp -r "$FWDIR/ec2" "$DISTDIR" # Download and copy in tachyon, if requested if [ "$SPARK_TACHYON" == "true" ]; then - TACHYON_VERSION="0.4.1" + TACHYON_VERSION="0.5.0" TACHYON_URL="https://github.com/amplab/tachyon/releases/download/v${TACHYON_VERSION}/tachyon-${TACHYON_VERSION}-bin.tar.gz" TMPD=`mktemp -d 2>/dev/null || mktemp -d -t 'disttmp'` From 149910111331133d52e0cb01b256f7f731b436ad Mon Sep 17 00:00:00 2001 From: Prashant Sharma Date: Thu, 31 Jul 2014 22:57:13 -0700 Subject: [PATCH 075/170] SPARK-2632, SPARK-2576. Fixed by only importing what is necessary during class definition. Without this patch, it imports everything available in the scope. ```scala scala> val a = 10l val a = 10l a: Long = 10 scala> import a._ import a._ import a._ scala> case class A(a: Int) // show case class A(a: Int) // show class $read extends Serializable { def () = { super.; () }; class $iwC extends Serializable { def () = { super.; () }; class $iwC extends Serializable { def () = { super.; () }; import org.apache.spark.SparkContext._; class $iwC extends Serializable { def () = { super.; () }; val $VAL5 = $line5.$read.INSTANCE; import $VAL5.$iw.$iw.$iw.$iw.a; class $iwC extends Serializable { def () = { super.; () }; import a._; class $iwC extends Serializable { def () = { super.; () }; class $iwC extends Serializable { def () = { super.; () }; case class A extends scala.Product with scala.Serializable { val a: Int = _; def (a: Int) = { super.; () } } }; val $iw = new $iwC. }; val $iw = new $iwC. }; val $iw = new $iwC. }; val $iw = new $iwC. }; val $iw = new $iwC. }; val $iw = new $iwC. } object $read extends scala.AnyRef { def () = { super.; () }; val INSTANCE = new $read. } defined class A ``` With this patch, it just imports only the necessary. ```scala scala> val a = 10l val a = 10l a: Long = 10 scala> import a._ import a._ import a._ scala> case class A(a: Int) // show case class A(a: Int) // show class $read extends Serializable { def () = { super.; () }; class $iwC extends Serializable { def () = { super.; () }; class $iwC extends Serializable { def () = { super.; () }; case class A extends scala.Product with scala.Serializable { val a: Int = _; def (a: Int) = { super.; () } } }; val $iw = new $iwC. }; val $iw = new $iwC. } object $read extends scala.AnyRef { def () = { super.; () }; val INSTANCE = new $read. } defined class A scala> ``` This patch also adds a `:fallback` mode on being enabled it will restore the spark-shell's 1.0.0 behaviour. Author: Prashant Sharma Author: Yin Huai Author: Prashant Sharma Closes #1635 from ScrapCodes/repl-fix-necessary-imports and squashes the following commits: b1968d2 [Prashant Sharma] Added toschemaRDD to test case. 0b712bb [Yin Huai] Add a REPL test to test importing a method. 02ad8ff [Yin Huai] Add a REPL test for importing SQLContext.createSchemaRDD. ed6d0c7 [Prashant Sharma] Added a fallback mode, incase users run into issues while using repl. b63d3b2 [Prashant Sharma] SPARK-2632, SPARK-2576. Fixed by only importing what is necessary during class definition. --- repl/pom.xml | 6 +++++ .../org/apache/spark/repl/SparkILoop.scala | 17 ++++++++++++ .../org/apache/spark/repl/SparkIMain.scala | 7 ++++- .../org/apache/spark/repl/SparkImports.scala | 15 ++++++++--- .../org/apache/spark/repl/ReplSuite.scala | 27 +++++++++++++++++++ 5 files changed, 67 insertions(+), 5 deletions(-) diff --git a/repl/pom.xml b/repl/pom.xml index 4ebb1b82f0e8c..68f4504450778 100644 --- a/repl/pom.xml +++ b/repl/pom.xml @@ -55,6 +55,12 @@ ${project.version} runtime + + org.apache.spark + spark-sql_${scala.binary.version} + ${project.version} + test + org.eclipse.jetty jetty-server diff --git a/repl/src/main/scala/org/apache/spark/repl/SparkILoop.scala b/repl/src/main/scala/org/apache/spark/repl/SparkILoop.scala index 6f9fa0d9f2b25..42c7e511dc3f5 100644 --- a/repl/src/main/scala/org/apache/spark/repl/SparkILoop.scala +++ b/repl/src/main/scala/org/apache/spark/repl/SparkILoop.scala @@ -230,6 +230,20 @@ class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter, case xs => xs find (_.name == cmd) } } + private var fallbackMode = false + + private def toggleFallbackMode() { + val old = fallbackMode + fallbackMode = !old + System.setProperty("spark.repl.fallback", fallbackMode.toString) + echo(s""" + |Switched ${if (old) "off" else "on"} fallback mode without restarting. + | If you have defined classes in the repl, it would + |be good to redefine them incase you plan to use them. If you still run + |into issues it would be good to restart the repl and turn on `:fallback` + |mode as first command. + """.stripMargin) + } /** Show the history */ lazy val historyCommand = new LoopCommand("history", "show the history (optional num is commands to show)") { @@ -299,6 +313,9 @@ class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter, nullary("reset", "reset the repl to its initial state, forgetting all session entries", resetCommand), shCommand, nullary("silent", "disable/enable automatic printing of results", verbosity), + nullary("fallback", """ + |disable/enable advanced repl changes, these fix some issues but may introduce others. + |This mode will be removed once these fixes stablize""".stripMargin, toggleFallbackMode), cmd("type", "[-v] ", "display the type of an expression without evaluating it", typeCommand), nullary("warnings", "show the suppressed warnings from the most recent line which had any", warningsCommand) ) diff --git a/repl/src/main/scala/org/apache/spark/repl/SparkIMain.scala b/repl/src/main/scala/org/apache/spark/repl/SparkIMain.scala index 3842c291d0b7b..f60bbb4662af1 100644 --- a/repl/src/main/scala/org/apache/spark/repl/SparkIMain.scala +++ b/repl/src/main/scala/org/apache/spark/repl/SparkIMain.scala @@ -892,11 +892,16 @@ import org.apache.spark.util.Utils def definedTypeSymbol(name: String) = definedSymbols(newTypeName(name)) def definedTermSymbol(name: String) = definedSymbols(newTermName(name)) + val definedClasses = handlers.exists { + case _: ClassHandler => true + case _ => false + } + /** Code to import bound names from previous lines - accessPath is code to * append to objectName to access anything bound by request. */ val SparkComputedImports(importsPreamble, importsTrailer, accessPath) = - importsCode(referencedNames.toSet) + importsCode(referencedNames.toSet, definedClasses) /** Code to access a variable with the specified name */ def fullPath(vname: String) = { diff --git a/repl/src/main/scala/org/apache/spark/repl/SparkImports.scala b/repl/src/main/scala/org/apache/spark/repl/SparkImports.scala index 9099e052f5796..193a42dcded12 100644 --- a/repl/src/main/scala/org/apache/spark/repl/SparkImports.scala +++ b/repl/src/main/scala/org/apache/spark/repl/SparkImports.scala @@ -108,8 +108,9 @@ trait SparkImports { * last one imported is actually usable. */ case class SparkComputedImports(prepend: String, append: String, access: String) + def fallback = System.getProperty("spark.repl.fallback", "false").toBoolean - protected def importsCode(wanted: Set[Name]): SparkComputedImports = { + protected def importsCode(wanted: Set[Name], definedClass: Boolean): SparkComputedImports = { /** Narrow down the list of requests from which imports * should be taken. Removes requests which cannot contribute * useful imports for the specified set of wanted names. @@ -124,8 +125,14 @@ trait SparkImports { // Single symbol imports might be implicits! See bug #1752. Rather than // try to finesse this, we will mimic all imports for now. def keepHandler(handler: MemberHandler) = handler match { - case _: ImportHandler => true - case x => x.definesImplicit || (x.definedNames exists wanted) + /* This case clause tries to "precisely" import only what is required. And in this + * it may miss out on some implicits, because implicits are not known in `wanted`. Thus + * it is suitable for defining classes. AFAIK while defining classes implicits are not + * needed.*/ + case h: ImportHandler if definedClass && !fallback => + h.importedNames.exists(x => wanted.contains(x)) + case _: ImportHandler => true + case x => x.definesImplicit || (x.definedNames exists wanted) } reqs match { @@ -182,7 +189,7 @@ trait SparkImports { // ambiguity errors will not be generated. Also, quote // the name of the variable, so that we don't need to // handle quoting keywords separately. - case x: ClassHandler => + case x: ClassHandler if !fallback => // I am trying to guess if the import is a defined class // This is an ugly hack, I am not 100% sure of the consequences. // Here we, let everything but "defined classes" use the import with val. diff --git a/repl/src/test/scala/org/apache/spark/repl/ReplSuite.scala b/repl/src/test/scala/org/apache/spark/repl/ReplSuite.scala index e2d8d5ff38dbe..c8763eb277052 100644 --- a/repl/src/test/scala/org/apache/spark/repl/ReplSuite.scala +++ b/repl/src/test/scala/org/apache/spark/repl/ReplSuite.scala @@ -256,6 +256,33 @@ class ReplSuite extends FunSuite { assertDoesNotContain("error:", output) assertDoesNotContain("Exception", output) } + + test("SPARK-2576 importing SQLContext.createSchemaRDD.") { + // We need to use local-cluster to test this case. + val output = runInterpreter("local-cluster[1,1,512]", + """ + |val sqlContext = new org.apache.spark.sql.SQLContext(sc) + |import sqlContext.createSchemaRDD + |case class TestCaseClass(value: Int) + |sc.parallelize(1 to 10).map(x => TestCaseClass(x)).toSchemaRDD.collect + """.stripMargin) + assertDoesNotContain("error:", output) + assertDoesNotContain("Exception", output) + } + + test("SPARK-2632 importing a method from non serializable class and not using it.") { + val output = runInterpreter("local", + """ + |class TestClass() { def testMethod = 3 } + |val t = new TestClass + |import t.testMethod + |case class TestCaseClass(value: Int) + |sc.parallelize(1 to 10).map(x => TestCaseClass(x)).collect + """.stripMargin) + assertDoesNotContain("error:", output) + assertDoesNotContain("Exception", output) + } + if (System.getenv("MESOS_NATIVE_LIBRARY") != null) { test("running on Mesos") { val output = runInterpreter("localquiet", From cb9e7d5aff2ce9cb501a2825651224311263ce20 Mon Sep 17 00:00:00 2001 From: Sandy Ryza Date: Thu, 31 Jul 2014 23:12:38 -0700 Subject: [PATCH 076/170] SPARK-2738. Remove redundant imports in BlockManagerSuite Author: Sandy Ryza Closes #1642 from sryza/sandy-spark-2738 and squashes the following commits: a923e4e [Sandy Ryza] SPARK-2738. Remove redundant imports in BlockManagerSuite --- .../scala/org/apache/spark/storage/BlockManagerSuite.scala | 3 --- 1 file changed, 3 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index dd4fd535d3577..58ea0cc30e954 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -21,9 +21,6 @@ import java.nio.{ByteBuffer, MappedByteBuffer} import java.util.Arrays import akka.actor._ -import org.apache.spark.SparkConf -import org.apache.spark.serializer.{JavaSerializer, KryoSerializer} -import org.apache.spark.util.{AkkaUtils, ByteBufferInputStream, SizeEstimator, Utils} import org.mockito.Mockito.{mock, when} import org.scalatest.{BeforeAndAfter, FunSuite, PrivateMethodTester} import org.scalatest.concurrent.Eventually._ From 8ff4417f70198ba2d848157f9da4e1e7e18f4fca Mon Sep 17 00:00:00 2001 From: Kousuke Saruta Date: Fri, 1 Aug 2014 00:01:30 -0700 Subject: [PATCH 077/170] [SPARK-2670] FetchFailedException should be thrown when local fetch has failed Author: Kousuke Saruta Closes #1578 from sarutak/SPARK-2670 and squashes the following commits: 85c8938 [Kousuke Saruta] Removed useless results.put for fail fast e8713cc [Kousuke Saruta] Merge branch 'master' of git://git.apache.org/spark into SPARK-2670 d353984 [Kousuke Saruta] Refined assertion messages in BlockFetcherIteratorSuite.scala 03bcb02 [Kousuke Saruta] Merge branch 'SPARK-2670' of github.com:sarutak/spark into SPARK-2670 5d05855 [Kousuke Saruta] Merge branch 'master' of git://git.apache.org/spark into SPARK-2670 4fca130 [Kousuke Saruta] Added test cases for BasicBlockFetcherIterator b7b8250 [Kousuke Saruta] Modified BasicBlockFetchIterator to fail fast when local fetch error has been occurred a3a9be1 [Kousuke Saruta] Modified BlockFetcherIterator for SPARK-2670 460dc01 [Kousuke Saruta] Merge branch 'master' of git://git.apache.org/spark into SPARK-2670 e310c0b [Kousuke Saruta] Modified BlockFetcherIterator to handle local fetch failure as fatch fail --- .../spark/storage/BlockFetcherIterator.scala | 19 ++- .../storage/BlockFetcherIteratorSuite.scala | 140 ++++++++++++++++++ 2 files changed, 151 insertions(+), 8 deletions(-) create mode 100644 core/src/test/scala/org/apache/spark/storage/BlockFetcherIteratorSuite.scala diff --git a/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala index 69905a960a2ca..ccf830e118ee7 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala @@ -200,14 +200,17 @@ object BlockFetcherIterator { // these all at once because they will just memory-map some files, so they won't consume // any memory that might exceed our maxBytesInFlight for (id <- localBlocksToFetch) { - getLocalFromDisk(id, serializer) match { - case Some(iter) => { - // Pass 0 as size since it's not in flight - results.put(new FetchResult(id, 0, () => iter)) - logDebug("Got local block " + id) - } - case None => { - throw new BlockException(id, "Could not get block " + id + " from local machine") + try { + // getLocalFromDisk never return None but throws BlockException + val iter = getLocalFromDisk(id, serializer).get + // Pass 0 as size since it's not in flight + results.put(new FetchResult(id, 0, () => iter)) + logDebug("Got local block " + id) + } catch { + case e: Exception => { + logError(s"Error occurred while fetching local blocks", e) + results.put(new FetchResult(id, -1, null)) + return } } } diff --git a/core/src/test/scala/org/apache/spark/storage/BlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockFetcherIteratorSuite.scala new file mode 100644 index 0000000000000..8dca2ebb312f5 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/storage/BlockFetcherIteratorSuite.scala @@ -0,0 +1,140 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.storage + +import org.scalatest.{FunSuite, Matchers} +import org.scalatest.PrivateMethodTester._ + +import org.mockito.Mockito._ +import org.mockito.Matchers.{any, eq => meq} +import org.mockito.stubbing.Answer +import org.mockito.invocation.InvocationOnMock + +import org.apache.spark._ +import org.apache.spark.storage.BlockFetcherIterator._ +import org.apache.spark.network.{ConnectionManager, ConnectionManagerId, + Message} + +class BlockFetcherIteratorSuite extends FunSuite with Matchers { + + test("block fetch from local fails using BasicBlockFetcherIterator") { + val blockManager = mock(classOf[BlockManager]) + val connManager = mock(classOf[ConnectionManager]) + doReturn(connManager).when(blockManager).connectionManager + doReturn(BlockManagerId("test-client", "test-client", 1, 0)).when(blockManager).blockManagerId + + doReturn((48 * 1024 * 1024).asInstanceOf[Long]).when(blockManager).maxBytesInFlight + + val blIds = Array[BlockId]( + ShuffleBlockId(0,0,0), + ShuffleBlockId(0,1,0), + ShuffleBlockId(0,2,0), + ShuffleBlockId(0,3,0), + ShuffleBlockId(0,4,0)) + + val optItr = mock(classOf[Option[Iterator[Any]]]) + val answer = new Answer[Option[Iterator[Any]]] { + override def answer(invocation: InvocationOnMock) = Option[Iterator[Any]] { + throw new Exception + } + } + + // 3rd block is going to fail + doReturn(optItr).when(blockManager).getLocalFromDisk(meq(blIds(0)), any()) + doReturn(optItr).when(blockManager).getLocalFromDisk(meq(blIds(1)), any()) + doAnswer(answer).when(blockManager).getLocalFromDisk(meq(blIds(2)), any()) + doReturn(optItr).when(blockManager).getLocalFromDisk(meq(blIds(3)), any()) + doReturn(optItr).when(blockManager).getLocalFromDisk(meq(blIds(4)), any()) + + val bmId = BlockManagerId("test-client", "test-client",1 , 0) + val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( + (bmId, blIds.map(blId => (blId, 1.asInstanceOf[Long])).toSeq) + ) + + val iterator = new BasicBlockFetcherIterator(blockManager, + blocksByAddress, null) + + iterator.initialize() + + // 3rd getLocalFromDisk invocation should be failed + verify(blockManager, times(3)).getLocalFromDisk(any(), any()) + + assert(iterator.hasNext, "iterator should have 5 elements but actually has no elements") + // the 2nd element of the tuple returned by iterator.next should be defined when fetching successfully + assert(iterator.next._2.isDefined, "1st element should be defined but is not actually defined") + assert(iterator.hasNext, "iterator should have 5 elements but actually has 1 element") + assert(iterator.next._2.isDefined, "2nd element should be defined but is not actually defined") + assert(iterator.hasNext, "iterator should have 5 elements but actually has 2 elements") + // 3rd fetch should be failed + assert(!iterator.next._2.isDefined, "3rd element should not be defined but is actually defined") + assert(iterator.hasNext, "iterator should have 5 elements but actually has 3 elements") + // Don't call next() after fetching non-defined element even if thare are rest of elements in the iterator. + // Otherwise, BasicBlockFetcherIterator hangs up. + } + + + test("block fetch from local succeed using BasicBlockFetcherIterator") { + val blockManager = mock(classOf[BlockManager]) + val connManager = mock(classOf[ConnectionManager]) + doReturn(connManager).when(blockManager).connectionManager + doReturn(BlockManagerId("test-client", "test-client", 1, 0)).when(blockManager).blockManagerId + + doReturn((48 * 1024 * 1024).asInstanceOf[Long]).when(blockManager).maxBytesInFlight + + val blIds = Array[BlockId]( + ShuffleBlockId(0,0,0), + ShuffleBlockId(0,1,0), + ShuffleBlockId(0,2,0), + ShuffleBlockId(0,3,0), + ShuffleBlockId(0,4,0)) + + val optItr = mock(classOf[Option[Iterator[Any]]]) + + // All blocks should be fetched successfully + doReturn(optItr).when(blockManager).getLocalFromDisk(meq(blIds(0)), any()) + doReturn(optItr).when(blockManager).getLocalFromDisk(meq(blIds(1)), any()) + doReturn(optItr).when(blockManager).getLocalFromDisk(meq(blIds(2)), any()) + doReturn(optItr).when(blockManager).getLocalFromDisk(meq(blIds(3)), any()) + doReturn(optItr).when(blockManager).getLocalFromDisk(meq(blIds(4)), any()) + + val bmId = BlockManagerId("test-client", "test-client",1 , 0) + val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( + (bmId, blIds.map(blId => (blId, 1.asInstanceOf[Long])).toSeq) + ) + + val iterator = new BasicBlockFetcherIterator(blockManager, + blocksByAddress, null) + + iterator.initialize() + + // getLocalFromDis should be invoked for all of 5 blocks + verify(blockManager, times(5)).getLocalFromDisk(any(), any()) + + assert(iterator.hasNext, "iterator should have 5 elements but actually has no elements") + assert(iterator.next._2.isDefined, "All elements should be defined but 1st element is not actually defined") + assert(iterator.hasNext, "iterator should have 5 elements but actually has 1 element") + assert(iterator.next._2.isDefined, "All elements should be defined but 2nd element is not actually defined") + assert(iterator.hasNext, "iterator should have 5 elements but actually has 2 elements") + assert(iterator.next._2.isDefined, "All elements should be defined but 3rd element is not actually defined") + assert(iterator.hasNext, "iterator should have 5 elements but actually has 3 elements") + assert(iterator.next._2.isDefined, "All elements should be defined but 4th element is not actually defined") + assert(iterator.hasNext, "iterator should have 5 elements but actually has 4 elements") + assert(iterator.next._2.isDefined, "All elements should be defined but 5th element is not actually defined") + } + +} From 72e33699732496fa71e8c8b0de2203b908423fb2 Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Fri, 1 Aug 2014 00:16:18 -0700 Subject: [PATCH 078/170] SPARK-983. Support external sorting in sortByKey() This patch simply uses the ExternalSorter class from sort-based shuffle. Closes #931 and Closes #1090 Author: Matei Zaharia Closes #1677 from mateiz/spark-983 and squashes the following commits: 96b3fda [Matei Zaharia] SPARK-983. Support external sorting in sortByKey() --- .../shuffle/hash/HashShuffleReader.scala | 22 +++++++++---------- .../util/collection/ExternalSorterSuite.scala | 10 +++++++++ 2 files changed, 20 insertions(+), 12 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala index e32ad9c036ad4..7c9dc8e5f88ef 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala @@ -20,6 +20,7 @@ package org.apache.spark.shuffle.hash import org.apache.spark.{InterruptibleIterator, TaskContext} import org.apache.spark.serializer.Serializer import org.apache.spark.shuffle.{BaseShuffleHandle, ShuffleReader} +import org.apache.spark.util.collection.ExternalSorter private[spark] class HashShuffleReader[K, C]( handle: BaseShuffleHandle[K, _, C], @@ -35,8 +36,8 @@ private[spark] class HashShuffleReader[K, C]( /** Read the combined key-values for this reduce task */ override def read(): Iterator[Product2[K, C]] = { - val iter = BlockStoreShuffleFetcher.fetch(handle.shuffleId, startPartition, context, - Serializer.getSerializer(dep.serializer)) + val ser = Serializer.getSerializer(dep.serializer) + val iter = BlockStoreShuffleFetcher.fetch(handle.shuffleId, startPartition, context, ser) val aggregatedIter: Iterator[Product2[K, C]] = if (dep.aggregator.isDefined) { if (dep.mapSideCombine) { @@ -54,16 +55,13 @@ private[spark] class HashShuffleReader[K, C]( // Sort the output if there is a sort ordering defined. dep.keyOrdering match { case Some(keyOrd: Ordering[K]) => - // Define a Comparator for the whole record based on the key Ordering. - val cmp = new Ordering[Product2[K, C]] { - override def compare(o1: Product2[K, C], o2: Product2[K, C]): Int = { - keyOrd.compare(o1._1, o2._1) - } - } - val sortBuffer: Array[Product2[K, C]] = aggregatedIter.toArray - // TODO: do external sort. - scala.util.Sorting.quickSort(sortBuffer)(cmp) - sortBuffer.iterator + // Create an ExternalSorter to sort the data. Note that if spark.shuffle.spill is disabled, + // the ExternalSorter won't spill to disk. + val sorter = new ExternalSorter[K, C, C](ordering = Some(keyOrd), serializer = Some(ser)) + sorter.write(aggregatedIter) + context.taskMetrics.memoryBytesSpilled += sorter.memoryBytesSpilled + context.taskMetrics.diskBytesSpilled += sorter.diskBytesSpilled + sorter.iterator case None => aggregatedIter } diff --git a/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala index ddb5df40360e9..65a71e5a83698 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala @@ -190,6 +190,11 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext { fail(s"Value 2 for ${i} was wrong: expected ${expected}, got ${seq2.toSet}") } } + + // sortByKey - should spill ~17 times + val rddE = sc.parallelize(0 until 100000).map(i => (i/4, i)) + val resultE = rddE.sortByKey().collect().toSeq + assert(resultE === (0 until 100000).map(i => (i/4, i)).toSeq) } test("spilling in local cluster with many reduce tasks") { @@ -256,6 +261,11 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext { fail(s"Value 2 for ${i} was wrong: expected ${expected}, got ${seq2.toSet}") } } + + // sortByKey - should spill ~8 times per executor + val rddE = sc.parallelize(0 until 100000).map(i => (i/4, i)) + val resultE = rddE.sortByKey().collect().toSeq + assert(resultE === (0 until 100000).map(i => (i/4, i)).toSeq) } test("cleanup of intermediate files in sorter") { From f1957e11652a537efd40771f843591a4c9341014 Mon Sep 17 00:00:00 2001 From: Rahul Singhal Date: Fri, 1 Aug 2014 00:33:15 -0700 Subject: [PATCH 079/170] SPARK-2134: Report metrics before application finishes Author: Rahul Singhal Closes #1076 from rahulsinghaliitd/SPARK-2134 and squashes the following commits: 15f18b6 [Rahul Singhal] SPARK-2134: Report metrics before application finishes --- core/src/main/scala/org/apache/spark/SparkContext.scala | 1 + .../main/scala/org/apache/spark/deploy/master/Master.scala | 2 ++ .../main/scala/org/apache/spark/deploy/worker/Worker.scala | 1 + .../apache/spark/executor/CoarseGrainedExecutorBackend.scala | 1 + core/src/main/scala/org/apache/spark/executor/Executor.scala | 4 ++++ .../main/scala/org/apache/spark/metrics/MetricsSystem.scala | 4 ++++ .../scala/org/apache/spark/metrics/sink/ConsoleSink.scala | 4 ++++ .../main/scala/org/apache/spark/metrics/sink/CsvSink.scala | 4 ++++ .../scala/org/apache/spark/metrics/sink/GraphiteSink.scala | 4 ++++ .../main/scala/org/apache/spark/metrics/sink/JmxSink.scala | 2 ++ .../scala/org/apache/spark/metrics/sink/MetricsServlet.scala | 2 ++ core/src/main/scala/org/apache/spark/metrics/sink/Sink.scala | 1 + .../scala/org/apache/spark/metrics/sink/GangliaSink.scala | 4 ++++ 13 files changed, 34 insertions(+) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index b25f081761a64..f5a0549834a0d 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -990,6 +990,7 @@ class SparkContext(config: SparkConf) extends Logging { val dagSchedulerCopy = dagScheduler dagScheduler = null if (dagSchedulerCopy != null) { + env.metricsSystem.report() metadataCleaner.cancel() cleaner.foreach(_.stop()) dagSchedulerCopy.stop() diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index 21f8667819c44..a70ecdb375373 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -154,6 +154,8 @@ private[spark] class Master( } override def postStop() { + masterMetricsSystem.report() + applicationMetricsSystem.report() // prevent the CompleteRecovery message sending to restarted master if (recoveryCompletionTask != null) { recoveryCompletionTask.cancel() diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala index ce425443051b0..fb5252da96519 100755 --- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala @@ -357,6 +357,7 @@ private[spark] class Worker( } override def postStop() { + metricsSystem.report() registrationRetryTimer.foreach(_.cancel()) executors.values.foreach(_.kill()) drivers.values.foreach(_.kill()) diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala index 860b47e056451..af736de405397 100644 --- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala @@ -88,6 +88,7 @@ private[spark] class CoarseGrainedExecutorBackend( case StopExecutor => logInfo("Driver commanded a shutdown") + executor.stop() context.stop(self) context.system.shutdown() } diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index 3b69bc4ca4142..99d650a3636e2 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -121,6 +121,10 @@ private[spark] class Executor( } } + def stop(): Unit = { + env.metricsSystem.report() + } + /** Get the Yarn approved local directories. */ private def getYarnLocalDirs(): String = { // Hadoop 0.23 and 2.x have different Environment variable names for the diff --git a/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala b/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala index 651511da1b7fe..6ef817d0e587e 100644 --- a/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala +++ b/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala @@ -91,6 +91,10 @@ private[spark] class MetricsSystem private (val instance: String, sinks.foreach(_.stop) } + def report(): Unit = { + sinks.foreach(_.report()) + } + def registerSource(source: Source) { sources += source try { diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/ConsoleSink.scala b/core/src/main/scala/org/apache/spark/metrics/sink/ConsoleSink.scala index 05852f1f98993..81b9056b40fb8 100644 --- a/core/src/main/scala/org/apache/spark/metrics/sink/ConsoleSink.scala +++ b/core/src/main/scala/org/apache/spark/metrics/sink/ConsoleSink.scala @@ -57,5 +57,9 @@ private[spark] class ConsoleSink(val property: Properties, val registry: MetricR override def stop() { reporter.stop() } + + override def report() { + reporter.report() + } } diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/CsvSink.scala b/core/src/main/scala/org/apache/spark/metrics/sink/CsvSink.scala index 542dce65366b2..9d5f2ae9328ad 100644 --- a/core/src/main/scala/org/apache/spark/metrics/sink/CsvSink.scala +++ b/core/src/main/scala/org/apache/spark/metrics/sink/CsvSink.scala @@ -66,5 +66,9 @@ private[spark] class CsvSink(val property: Properties, val registry: MetricRegis override def stop() { reporter.stop() } + + override def report() { + reporter.report() + } } diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/GraphiteSink.scala b/core/src/main/scala/org/apache/spark/metrics/sink/GraphiteSink.scala index aeb4ad44a0647..d7b5f5c40efae 100644 --- a/core/src/main/scala/org/apache/spark/metrics/sink/GraphiteSink.scala +++ b/core/src/main/scala/org/apache/spark/metrics/sink/GraphiteSink.scala @@ -81,4 +81,8 @@ private[spark] class GraphiteSink(val property: Properties, val registry: Metric override def stop() { reporter.stop() } + + override def report() { + reporter.report() + } } diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/JmxSink.scala b/core/src/main/scala/org/apache/spark/metrics/sink/JmxSink.scala index ed27234b4e760..2588fe2c9edb8 100644 --- a/core/src/main/scala/org/apache/spark/metrics/sink/JmxSink.scala +++ b/core/src/main/scala/org/apache/spark/metrics/sink/JmxSink.scala @@ -35,4 +35,6 @@ private[spark] class JmxSink(val property: Properties, val registry: MetricRegis reporter.stop() } + override def report() { } + } diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/MetricsServlet.scala b/core/src/main/scala/org/apache/spark/metrics/sink/MetricsServlet.scala index 571539ba5e467..2f65bc8b46609 100644 --- a/core/src/main/scala/org/apache/spark/metrics/sink/MetricsServlet.scala +++ b/core/src/main/scala/org/apache/spark/metrics/sink/MetricsServlet.scala @@ -57,4 +57,6 @@ private[spark] class MetricsServlet(val property: Properties, val registry: Metr override def start() { } override def stop() { } + + override def report() { } } diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/Sink.scala b/core/src/main/scala/org/apache/spark/metrics/sink/Sink.scala index 6f2b5a06027ea..0d83d8c425ca4 100644 --- a/core/src/main/scala/org/apache/spark/metrics/sink/Sink.scala +++ b/core/src/main/scala/org/apache/spark/metrics/sink/Sink.scala @@ -20,4 +20,5 @@ package org.apache.spark.metrics.sink private[spark] trait Sink { def start: Unit def stop: Unit + def report(): Unit } diff --git a/extras/spark-ganglia-lgpl/src/main/scala/org/apache/spark/metrics/sink/GangliaSink.scala b/extras/spark-ganglia-lgpl/src/main/scala/org/apache/spark/metrics/sink/GangliaSink.scala index d03d7774e8c80..3b1880e143513 100644 --- a/extras/spark-ganglia-lgpl/src/main/scala/org/apache/spark/metrics/sink/GangliaSink.scala +++ b/extras/spark-ganglia-lgpl/src/main/scala/org/apache/spark/metrics/sink/GangliaSink.scala @@ -82,5 +82,9 @@ class GangliaSink(val property: Properties, val registry: MetricRegistry, override def stop() { reporter.stop() } + + override def report() { + reporter.report() + } } From 284771efbef2d6b22212afd49dd62732a2cf52a8 Mon Sep 17 00:00:00 2001 From: Ye Xianjin Date: Fri, 1 Aug 2014 00:34:39 -0700 Subject: [PATCH 080/170] [Spark 2557] fix LOCAL_N_REGEX in createTaskScheduler and make local-n and local-n-failures consistent [SPARK-2557](https://issues.apache.org/jira/browse/SPARK-2557) Author: Ye Xianjin Closes #1464 from advancedxy/SPARK-2557 and squashes the following commits: d844d67 [Ye Xianjin] add local-*-n-failures, bad-local-n, bad-local-n-failures test case 3bbc668 [Ye Xianjin] fix LOCAL_N_REGEX regular expression and make local_n_failures accept * as all cores on the computer --- .../scala/org/apache/spark/SparkContext.scala | 10 +++++--- .../SparkContextSchedulerCreationSuite.scala | 23 +++++++++++++++++++ 2 files changed, 30 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index f5a0549834a0d..0e513568b0243 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -1452,9 +1452,9 @@ object SparkContext extends Logging { /** Creates a task scheduler based on a given master URL. Extracted for testing. */ private def createTaskScheduler(sc: SparkContext, master: String): TaskScheduler = { // Regular expression used for local[N] and local[*] master formats - val LOCAL_N_REGEX = """local\[([0-9\*]+)\]""".r + val LOCAL_N_REGEX = """local\[([0-9]+|\*)\]""".r // Regular expression for local[N, maxRetries], used in tests with failing tasks - val LOCAL_N_FAILURES_REGEX = """local\[([0-9]+)\s*,\s*([0-9]+)\]""".r + val LOCAL_N_FAILURES_REGEX = """local\[([0-9]+|\*)\s*,\s*([0-9]+)\]""".r // Regular expression for simulating a Spark cluster of [N, cores, memory] locally val LOCAL_CLUSTER_REGEX = """local-cluster\[\s*([0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*]""".r // Regular expression for connecting to Spark deploy clusters @@ -1484,8 +1484,12 @@ object SparkContext extends Logging { scheduler case LOCAL_N_FAILURES_REGEX(threads, maxFailures) => + def localCpuCount = Runtime.getRuntime.availableProcessors() + // local[*, M] means the number of cores on the computer with M failures + // local[N, M] means exactly N threads with M failures + val threadCount = if (threads == "*") localCpuCount else threads.toInt val scheduler = new TaskSchedulerImpl(sc, maxFailures.toInt, isLocal = true) - val backend = new LocalBackend(scheduler, threads.toInt) + val backend = new LocalBackend(scheduler, threadCount) scheduler.initialize(backend) scheduler diff --git a/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala b/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala index 67e3be21c3c93..4b727e50dbe67 100644 --- a/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala @@ -68,6 +68,15 @@ class SparkContextSchedulerCreationSuite } } + test("local-*-n-failures") { + val sched = createTaskScheduler("local[* ,2]") + assert(sched.maxTaskFailures === 2) + sched.backend match { + case s: LocalBackend => assert(s.totalCores === Runtime.getRuntime.availableProcessors()) + case _ => fail() + } + } + test("local-n-failures") { val sched = createTaskScheduler("local[4, 2]") assert(sched.maxTaskFailures === 2) @@ -77,6 +86,20 @@ class SparkContextSchedulerCreationSuite } } + test("bad-local-n") { + val e = intercept[SparkException] { + createTaskScheduler("local[2*]") + } + assert(e.getMessage.contains("Could not parse Master URL")) + } + + test("bad-local-n-failures") { + val e = intercept[SparkException] { + createTaskScheduler("local[2*,4]") + } + assert(e.getMessage.contains("Could not parse Master URL")) + } + test("local-default-parallelism") { val defaultParallelism = System.getProperty("spark.default.parallelism") System.setProperty("spark.default.parallelism", "16") From a32f0fb73a739c56208cafcd9f08618fb6dd8859 Mon Sep 17 00:00:00 2001 From: jerryshao Date: Fri, 1 Aug 2014 04:32:46 -0700 Subject: [PATCH 081/170] [SPARK-2103][Streaming] Change to ClassTag for KafkaInputDStream and fix reflection issue This PR updates previous Manifest for KafkaInputDStream's Decoder to ClassTag, also fix the problem addressed in [SPARK-2103](https://issues.apache.org/jira/browse/SPARK-2103). Previous Java interface cannot actually get the type of Decoder, so when using this Manifest to reconstruct the decode object will meet reflection exception. Also for other two Java interfaces, ClassTag[String] is useless because calling Scala API will get the right implicit ClassTag. Current Kafka unit test cannot actually verify the interface. I've tested these interfaces in my local and distribute settings. Author: jerryshao Closes #1508 from jerryshao/SPARK-2103 and squashes the following commits: e90c37b [jerryshao] Add Mima excludes 7529810 [jerryshao] Change Manifest to ClassTag for KafkaInputDStream's Decoder and fix Decoder construct issue when using Java API --- .../streaming/kafka/KafkaInputDStream.scala | 14 +++++++------- .../spark/streaming/kafka/KafkaUtils.scala | 16 +++++----------- project/MimaExcludes.scala | 7 ++++++- 3 files changed, 18 insertions(+), 19 deletions(-) diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaInputDStream.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaInputDStream.scala index 38095e88dcea9..e20e2c8f26991 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaInputDStream.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaInputDStream.scala @@ -18,7 +18,7 @@ package org.apache.spark.streaming.kafka import scala.collection.Map -import scala.reflect.ClassTag +import scala.reflect.{classTag, ClassTag} import java.util.Properties import java.util.concurrent.Executors @@ -48,8 +48,8 @@ private[streaming] class KafkaInputDStream[ K: ClassTag, V: ClassTag, - U <: Decoder[_]: Manifest, - T <: Decoder[_]: Manifest]( + U <: Decoder[_]: ClassTag, + T <: Decoder[_]: ClassTag]( @transient ssc_ : StreamingContext, kafkaParams: Map[String, String], topics: Map[String, Int], @@ -66,8 +66,8 @@ private[streaming] class KafkaReceiver[ K: ClassTag, V: ClassTag, - U <: Decoder[_]: Manifest, - T <: Decoder[_]: Manifest]( + U <: Decoder[_]: ClassTag, + T <: Decoder[_]: ClassTag]( kafkaParams: Map[String, String], topics: Map[String, Int], storageLevel: StorageLevel @@ -103,10 +103,10 @@ class KafkaReceiver[ tryZookeeperConsumerGroupCleanup(zkConnect, kafkaParams("group.id")) } - val keyDecoder = manifest[U].runtimeClass.getConstructor(classOf[VerifiableProperties]) + val keyDecoder = classTag[U].runtimeClass.getConstructor(classOf[VerifiableProperties]) .newInstance(consumerConfig.props) .asInstanceOf[Decoder[K]] - val valueDecoder = manifest[T].runtimeClass.getConstructor(classOf[VerifiableProperties]) + val valueDecoder = classTag[T].runtimeClass.getConstructor(classOf[VerifiableProperties]) .newInstance(consumerConfig.props) .asInstanceOf[Decoder[V]] diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala index 86bb91f362d29..48668f763e41e 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala @@ -65,7 +65,7 @@ object KafkaUtils { * in its own thread. * @param storageLevel Storage level to use for storing the received objects */ - def createStream[K: ClassTag, V: ClassTag, U <: Decoder[_]: Manifest, T <: Decoder[_]: Manifest]( + def createStream[K: ClassTag, V: ClassTag, U <: Decoder[_]: ClassTag, T <: Decoder[_]: ClassTag]( ssc: StreamingContext, kafkaParams: Map[String, String], topics: Map[String, Int], @@ -89,8 +89,6 @@ object KafkaUtils { groupId: String, topics: JMap[String, JInt] ): JavaPairReceiverInputDStream[String, String] = { - implicit val cmt: ClassTag[String] = - implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[String]] createStream(jssc.ssc, zkQuorum, groupId, Map(topics.mapValues(_.intValue()).toSeq: _*)) } @@ -111,8 +109,6 @@ object KafkaUtils { topics: JMap[String, JInt], storageLevel: StorageLevel ): JavaPairReceiverInputDStream[String, String] = { - implicit val cmt: ClassTag[String] = - implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[String]] createStream(jssc.ssc, zkQuorum, groupId, Map(topics.mapValues(_.intValue()).toSeq: _*), storageLevel) } @@ -140,13 +136,11 @@ object KafkaUtils { topics: JMap[String, JInt], storageLevel: StorageLevel ): JavaPairReceiverInputDStream[K, V] = { - implicit val keyCmt: ClassTag[K] = - implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[K]] - implicit val valueCmt: ClassTag[V] = - implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[V]] + implicit val keyCmt: ClassTag[K] = ClassTag(keyTypeClass) + implicit val valueCmt: ClassTag[V] = ClassTag(valueTypeClass) - implicit val keyCmd: Manifest[U] = implicitly[Manifest[AnyRef]].asInstanceOf[Manifest[U]] - implicit val valueCmd: Manifest[T] = implicitly[Manifest[AnyRef]].asInstanceOf[Manifest[T]] + implicit val keyCmd: ClassTag[U] = ClassTag(keyDecoderClass) + implicit val valueCmd: ClassTag[T] = ClassTag(valueDecoderClass) createStream[K, V, U, T]( jssc.ssc, kafkaParams.toMap, Map(topics.mapValues(_.intValue()).toSeq: _*), storageLevel) diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 5a835f58207cf..537ca0dcf267d 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -71,7 +71,12 @@ object MimaExcludes { "org.apache.spark.storage.TachyonStore.putValues") ) ++ Seq( - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.flume.FlumeReceiver.this") + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.streaming.flume.FlumeReceiver.this"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]( + "org.apache.spark.streaming.kafka.KafkaUtils.createStream"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]( + "org.apache.spark.streaming.kafka.KafkaReceiver.this") ) ++ Seq( // Ignore some private methods in ALS. ProblemFilters.exclude[MissingMethodProblem]( From 82d209d43fb543c174e640667de15b00c7fb5d35 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Fri, 1 Aug 2014 07:32:53 -0700 Subject: [PATCH 082/170] SPARK-2768 [MLLIB] Add product, user recommend method to MatrixFactorizationModel MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Right now, `MatrixFactorizationModel` can only predict a score for one or more `(user,product)` tuples. As a comment in the file notes, it would be more useful to expose a recommend method, that computes top N scoring products for a user (or vice versa – users for a product). (This also corrects some long lines in the Java ALS test suite.) As you can see, it's a little messy to access the class from Java. Should there be a Java-friendly wrapper for it? with a pointer about where that should go, I could add that. Author: Sean Owen Closes #1687 from srowen/SPARK-2768 and squashes the following commits: b349675 [Sean Owen] Additional review changes c9edb04 [Sean Owen] Updates from code review 7bc35f9 [Sean Owen] Add recommend methods to MatrixFactorizationModel --- .../MatrixFactorizationModel.scala | 44 ++++++++++- .../mllib/recommendation/JavaALSSuite.java | 75 ++++++++++++++----- 2 files changed, 100 insertions(+), 19 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala index 899286d235a9d..a1a76fcbe9f9c 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala @@ -65,6 +65,48 @@ class MatrixFactorizationModel private[mllib] ( } } + /** + * Recommends products to a user. + * + * @param user the user to recommend products to + * @param num how many products to return. The number returned may be less than this. + * @return [[Rating]] objects, each of which contains the given user ID, a product ID, and a + * "score" in the rating field. Each represents one recommended product, and they are sorted + * by score, decreasing. The first returned is the one predicted to be most strongly + * recommended to the user. The score is an opaque value that indicates how strongly + * recommended the product is. + */ + def recommendProducts(user: Int, num: Int): Array[Rating] = + recommend(userFeatures.lookup(user).head, productFeatures, num) + .map(t => Rating(user, t._1, t._2)) + + /** + * Recommends users to a product. That is, this returns users who are most likely to be + * interested in a product. + * + * @param product the product to recommend users to + * @param num how many users to return. The number returned may be less than this. + * @return [[Rating]] objects, each of which contains a user ID, the given product ID, and a + * "score" in the rating field. Each represents one recommended user, and they are sorted + * by score, decreasing. The first returned is the one predicted to be most strongly + * recommended to the product. The score is an opaque value that indicates how strongly + * recommended the user is. + */ + def recommendUsers(product: Int, num: Int): Array[Rating] = + recommend(productFeatures.lookup(product).head, userFeatures, num) + .map(t => Rating(t._1, product, t._2)) + + private def recommend( + recommendToFeatures: Array[Double], + recommendableFeatures: RDD[(Int, Array[Double])], + num: Int): Array[(Int, Double)] = { + val recommendToVector = new DoubleMatrix(recommendToFeatures) + val scored = recommendableFeatures.map { case (id,features) => + (id, recommendToVector.dot(new DoubleMatrix(features))) + } + scored.top(num)(Ordering.by(_._2)) + } + /** * :: DeveloperApi :: * Predict the rating of many users for many products. @@ -80,6 +122,4 @@ class MatrixFactorizationModel private[mllib] ( predict(usersProducts).map(rate => pythonAPI.serializeRating(rate)) } - // TODO: Figure out what other good bulk prediction methods would look like. - // Probably want a way to get the top users for a product or vice-versa. } diff --git a/mllib/src/test/java/org/apache/spark/mllib/recommendation/JavaALSSuite.java b/mllib/src/test/java/org/apache/spark/mllib/recommendation/JavaALSSuite.java index bf2365f82044c..f6ca9643227f8 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/recommendation/JavaALSSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/recommendation/JavaALSSuite.java @@ -20,6 +20,11 @@ import java.io.Serializable; import java.util.List; +import scala.Tuple2; +import scala.Tuple3; + +import org.jblas.DoubleMatrix; + import org.junit.After; import org.junit.Assert; import org.junit.Before; @@ -28,8 +33,6 @@ import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; -import org.jblas.DoubleMatrix; - public class JavaALSSuite implements Serializable { private transient JavaSparkContext sc; @@ -44,21 +47,28 @@ public void tearDown() { sc = null; } - static void validatePrediction(MatrixFactorizationModel model, int users, int products, int features, - DoubleMatrix trueRatings, double matchThreshold, boolean implicitPrefs, DoubleMatrix truePrefs) { + static void validatePrediction( + MatrixFactorizationModel model, + int users, + int products, + int features, + DoubleMatrix trueRatings, + double matchThreshold, + boolean implicitPrefs, + DoubleMatrix truePrefs) { DoubleMatrix predictedU = new DoubleMatrix(users, features); - List> userFeatures = model.userFeatures().toJavaRDD().collect(); + List> userFeatures = model.userFeatures().toJavaRDD().collect(); for (int i = 0; i < features; ++i) { - for (scala.Tuple2 userFeature : userFeatures) { + for (Tuple2 userFeature : userFeatures) { predictedU.put((Integer)userFeature._1(), i, userFeature._2()[i]); } } DoubleMatrix predictedP = new DoubleMatrix(products, features); - List> productFeatures = + List> productFeatures = model.productFeatures().toJavaRDD().collect(); for (int i = 0; i < features; ++i) { - for (scala.Tuple2 productFeature : productFeatures) { + for (Tuple2 productFeature : productFeatures) { predictedP.put((Integer)productFeature._1(), i, productFeature._2()[i]); } } @@ -75,7 +85,8 @@ static void validatePrediction(MatrixFactorizationModel model, int users, int pr } } } else { - // For implicit prefs we use the confidence-weighted RMSE to test (ref Mahout's implicit ALS tests) + // For implicit prefs we use the confidence-weighted RMSE to test + // (ref Mahout's implicit ALS tests) double sqErr = 0.0; double denom = 0.0; for (int u = 0; u < users; ++u) { @@ -100,7 +111,7 @@ public void runALSUsingStaticMethods() { int iterations = 15; int users = 50; int products = 100; - scala.Tuple3, DoubleMatrix, DoubleMatrix> testData = ALSSuite.generateRatingsAsJavaList( + Tuple3, DoubleMatrix, DoubleMatrix> testData = ALSSuite.generateRatingsAsJavaList( users, products, features, 0.7, false, false); JavaRDD data = sc.parallelize(testData._1()); @@ -114,14 +125,14 @@ public void runALSUsingConstructor() { int iterations = 15; int users = 100; int products = 200; - scala.Tuple3, DoubleMatrix, DoubleMatrix> testData = ALSSuite.generateRatingsAsJavaList( + Tuple3, DoubleMatrix, DoubleMatrix> testData = ALSSuite.generateRatingsAsJavaList( users, products, features, 0.7, false, false); JavaRDD data = sc.parallelize(testData._1()); MatrixFactorizationModel model = new ALS().setRank(features) - .setIterations(iterations) - .run(data.rdd()); + .setIterations(iterations) + .run(data.rdd()); validatePrediction(model, users, products, features, testData._2(), 0.3, false, testData._3()); } @@ -131,7 +142,7 @@ public void runImplicitALSUsingStaticMethods() { int iterations = 15; int users = 80; int products = 160; - scala.Tuple3, DoubleMatrix, DoubleMatrix> testData = ALSSuite.generateRatingsAsJavaList( + Tuple3, DoubleMatrix, DoubleMatrix> testData = ALSSuite.generateRatingsAsJavaList( users, products, features, 0.7, true, false); JavaRDD data = sc.parallelize(testData._1()); @@ -145,7 +156,7 @@ public void runImplicitALSUsingConstructor() { int iterations = 15; int users = 100; int products = 200; - scala.Tuple3, DoubleMatrix, DoubleMatrix> testData = ALSSuite.generateRatingsAsJavaList( + Tuple3, DoubleMatrix, DoubleMatrix> testData = ALSSuite.generateRatingsAsJavaList( users, products, features, 0.7, true, false); JavaRDD data = sc.parallelize(testData._1()); @@ -163,12 +174,42 @@ public void runImplicitALSWithNegativeWeight() { int iterations = 15; int users = 80; int products = 160; - scala.Tuple3, DoubleMatrix, DoubleMatrix> testData = ALSSuite.generateRatingsAsJavaList( + Tuple3, DoubleMatrix, DoubleMatrix> testData = ALSSuite.generateRatingsAsJavaList( users, products, features, 0.7, true, true); JavaRDD data = sc.parallelize(testData._1()); - MatrixFactorizationModel model = ALS.trainImplicit(data.rdd(), features, iterations); + MatrixFactorizationModel model = new ALS().setRank(features) + .setIterations(iterations) + .setImplicitPrefs(true) + .setSeed(8675309L) + .run(data.rdd()); validatePrediction(model, users, products, features, testData._2(), 0.4, true, testData._3()); } + @Test + public void runRecommend() { + int features = 5; + int iterations = 10; + int users = 200; + int products = 50; + Tuple3, DoubleMatrix, DoubleMatrix> testData = ALSSuite.generateRatingsAsJavaList( + users, products, features, 0.7, true, false); + JavaRDD data = sc.parallelize(testData._1()); + MatrixFactorizationModel model = new ALS().setRank(features) + .setIterations(iterations) + .setImplicitPrefs(true) + .setSeed(8675309L) + .run(data.rdd()); + validateRecommendations(model.recommendProducts(1, 10), 10); + validateRecommendations(model.recommendUsers(1, 20), 20); + } + + private static void validateRecommendations(Rating[] recommendations, int howMany) { + Assert.assertEquals(howMany, recommendations.length); + for (int i = 1; i < recommendations.length; i++) { + Assert.assertTrue(recommendations[i-1].rating() >= recommendations[i].rating()); + } + Assert.assertTrue(recommendations[0].rating() > 0.7); + } + } From 0dacb1adb5e6118bd218537bee71926344cd9fb0 Mon Sep 17 00:00:00 2001 From: witgo Date: Fri, 1 Aug 2014 07:47:44 -0700 Subject: [PATCH 083/170] [SPARK-1997] update breeze to version 0.8.1 `breeze 0.8.1` dependent on `scala-logging-slf4j 2.1.1` The relevant code on #1369 Author: witgo Closes #940 from witgo/breeze-8.0.1 and squashes the following commits: 65cc65e [witgo] update breeze to version 0.8.1 --- mllib/pom.xml | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/mllib/pom.xml b/mllib/pom.xml index cb0fa7b97cb15..9e15ca6ab836c 100644 --- a/mllib/pom.xml +++ b/mllib/pom.xml @@ -52,7 +52,7 @@ org.scalanlp breeze_${scala.binary.version} - 0.7 + 0.8.1 @@ -60,6 +60,10 @@ junit junit + + org.apache.commons + commons-math3 + From 5328c0aaa09911c848f9b3e1e1f2397bef932d0f Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Fri, 1 Aug 2014 10:00:46 -0700 Subject: [PATCH 084/170] [HOTFIX] downgrade breeze version to 0.7 breeze-0.8.1 causes dependency issues, as discussed in #940 . Author: Xiangrui Meng Closes #1718 from mengxr/revert-breeze and squashes the following commits: 99c4681 [Xiangrui Meng] downgrade breeze version to 0.7 --- mllib/pom.xml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mllib/pom.xml b/mllib/pom.xml index 9e15ca6ab836c..45046eca5b18c 100644 --- a/mllib/pom.xml +++ b/mllib/pom.xml @@ -52,7 +52,7 @@ org.scalanlp breeze_${scala.binary.version} - 0.8.1 + 0.7 From 8d338f64c4eda45d22ae33f61ef7928011cc2846 Mon Sep 17 00:00:00 2001 From: Sandy Ryza Date: Fri, 1 Aug 2014 11:08:39 -0700 Subject: [PATCH 085/170] SPARK-2099. Report progress while task is running. This is a sketch of a patch that allows the UI to show metrics for tasks that have not yet completed. It adds a heartbeat every 2 seconds from the executors to the driver, reporting metrics for all of the executor's tasks. It still needs unit tests, polish, and cluster testing, but I wanted to put it up to get feedback on the approach. Author: Sandy Ryza Closes #1056 from sryza/sandy-spark-2099 and squashes the following commits: 93b9fdb [Sandy Ryza] Up heartbeat interval to 10 seconds and other tidying 132aec7 [Sandy Ryza] Heartbeat and HeartbeatResponse are already Serializable as case classes 38dffde [Sandy Ryza] Additional review feedback and restore test that was removed in BlockManagerSuite 51fa396 [Sandy Ryza] Remove hostname race, add better comments about threading, and some stylistic improvements 3084f10 [Sandy Ryza] Make TaskUIData a case class again 3bda974 [Sandy Ryza] Stylistic fixes 0dae734 [Sandy Ryza] SPARK-2099. Report progress while task is running. --- .../org/apache/spark/HeartbeatReceiver.scala | 46 +++++++ .../scala/org/apache/spark/SparkContext.scala | 4 + .../scala/org/apache/spark/SparkEnv.scala | 8 +- .../org/apache/spark/executor/Executor.scala | 55 +++++++- .../apache/spark/executor/TaskMetrics.scala | 10 +- .../apache/spark/scheduler/DAGScheduler.scala | 21 +++- .../spark/scheduler/SparkListener.scala | 11 ++ .../spark/scheduler/SparkListenerBus.scala | 2 + .../org/apache/spark/scheduler/Task.scala | 3 + .../spark/scheduler/TaskScheduler.scala | 10 ++ .../spark/scheduler/TaskSchedulerImpl.scala | 23 ++++ .../spark/scheduler/local/LocalBackend.scala | 9 +- .../apache/spark/storage/BlockManager.scala | 25 +--- .../spark/storage/BlockManagerMaster.scala | 43 +------ .../storage/BlockManagerMasterActor.scala | 29 +++-- .../spark/storage/BlockManagerMessages.scala | 6 +- .../spark/ui/jobs/JobProgressListener.scala | 117 +++++++++++------- .../org/apache/spark/ui/jobs/UIData.scala | 9 +- .../org/apache/spark/util/AkkaUtils.scala | 66 +++++++++- .../SparkContextSchedulerCreationSuite.scala | 6 +- .../spark/scheduler/DAGSchedulerSuite.scala | 5 + .../spark/storage/BlockManagerSuite.scala | 23 ++-- .../ui/jobs/JobProgressListenerSuite.scala | 86 ++++++++++++- docs/configuration.md | 7 ++ 24 files changed, 467 insertions(+), 157 deletions(-) create mode 100644 core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala diff --git a/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala b/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala new file mode 100644 index 0000000000000..24ccce21b62ca --- /dev/null +++ b/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +import akka.actor.Actor +import org.apache.spark.executor.TaskMetrics +import org.apache.spark.storage.BlockManagerId +import org.apache.spark.scheduler.TaskScheduler + +/** + * A heartbeat from executors to the driver. This is a shared message used by several internal + * components to convey liveness or execution information for in-progress tasks. + */ +private[spark] case class Heartbeat( + executorId: String, + taskMetrics: Array[(Long, TaskMetrics)], // taskId -> TaskMetrics + blockManagerId: BlockManagerId) + +private[spark] case class HeartbeatResponse(reregisterBlockManager: Boolean) + +/** + * Lives in the driver to receive heartbeats from executors.. + */ +private[spark] class HeartbeatReceiver(scheduler: TaskScheduler) extends Actor { + override def receive = { + case Heartbeat(executorId, taskMetrics, blockManagerId) => + val response = HeartbeatResponse( + !scheduler.executorHeartbeatReceived(executorId, taskMetrics, blockManagerId)) + sender ! response + } +} diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 0e513568b0243..5f75c1dd2cb68 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -36,6 +36,7 @@ import org.apache.hadoop.mapred.{FileInputFormat, InputFormat, JobConf, Sequence import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat, Job => NewHadoopJob} import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat => NewFileInputFormat} import org.apache.mesos.MesosNativeLibrary +import akka.actor.Props import org.apache.spark.annotation.{DeveloperApi, Experimental} import org.apache.spark.broadcast.Broadcast @@ -307,6 +308,8 @@ class SparkContext(config: SparkConf) extends Logging { // Create and start the scheduler private[spark] var taskScheduler = SparkContext.createTaskScheduler(this, master) + private val heartbeatReceiver = env.actorSystem.actorOf( + Props(new HeartbeatReceiver(taskScheduler)), "HeartbeatReceiver") @volatile private[spark] var dagScheduler: DAGScheduler = _ try { dagScheduler = new DAGScheduler(this) @@ -992,6 +995,7 @@ class SparkContext(config: SparkConf) extends Logging { if (dagSchedulerCopy != null) { env.metricsSystem.report() metadataCleaner.cancel() + env.actorSystem.stop(heartbeatReceiver) cleaner.foreach(_.stop()) dagSchedulerCopy.stop() taskScheduler = null diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index 6ee731b22c03c..92c809d854167 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -193,13 +193,7 @@ object SparkEnv extends Logging { logInfo("Registering " + name) actorSystem.actorOf(Props(newActor), name = name) } else { - val driverHost: String = conf.get("spark.driver.host", "localhost") - val driverPort: Int = conf.getInt("spark.driver.port", 7077) - Utils.checkHost(driverHost, "Expected hostname") - val url = s"akka.tcp://spark@$driverHost:$driverPort/user/$name" - val timeout = AkkaUtils.lookupTimeout(conf) - logInfo(s"Connecting to $name: $url") - Await.result(actorSystem.actorSelection(url).resolveOne(timeout), timeout) + AkkaUtils.makeDriverRef(name, conf, actorSystem) } } diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index 99d650a3636e2..1bb1b4aae91bb 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -23,7 +23,7 @@ import java.nio.ByteBuffer import java.util.concurrent._ import scala.collection.JavaConversions._ -import scala.collection.mutable.HashMap +import scala.collection.mutable.{ArrayBuffer, HashMap} import org.apache.spark._ import org.apache.spark.scheduler._ @@ -48,6 +48,8 @@ private[spark] class Executor( private val EMPTY_BYTE_BUFFER = ByteBuffer.wrap(new Array[Byte](0)) + @volatile private var isStopped = false + // No ip or host:port - just hostname Utils.checkHost(slaveHostname, "Expected executed slave to be a hostname") // must not have port specified. @@ -107,6 +109,8 @@ private[spark] class Executor( // Maintains the list of running tasks. private val runningTasks = new ConcurrentHashMap[Long, TaskRunner] + startDriverHeartbeater() + def launchTask( context: ExecutorBackend, taskId: Long, taskName: String, serializedTask: ByteBuffer) { val tr = new TaskRunner(context, taskId, taskName, serializedTask) @@ -121,8 +125,10 @@ private[spark] class Executor( } } - def stop(): Unit = { + def stop() { env.metricsSystem.report() + isStopped = true + threadPool.shutdown() } /** Get the Yarn approved local directories. */ @@ -141,11 +147,12 @@ private[spark] class Executor( } class TaskRunner( - execBackend: ExecutorBackend, taskId: Long, taskName: String, serializedTask: ByteBuffer) + execBackend: ExecutorBackend, val taskId: Long, taskName: String, serializedTask: ByteBuffer) extends Runnable { @volatile private var killed = false - @volatile private var task: Task[Any] = _ + @volatile var task: Task[Any] = _ + @volatile var attemptedTask: Option[Task[Any]] = None def kill(interruptThread: Boolean) { logInfo(s"Executor is trying to kill $taskName (TID $taskId)") @@ -162,7 +169,6 @@ private[spark] class Executor( val ser = SparkEnv.get.closureSerializer.newInstance() logInfo(s"Running $taskName (TID $taskId)") execBackend.statusUpdate(taskId, TaskState.RUNNING, EMPTY_BYTE_BUFFER) - var attemptedTask: Option[Task[Any]] = None var taskStart: Long = 0 def gcTime = ManagementFactory.getGarbageCollectorMXBeans.map(_.getCollectionTime).sum val startGCTime = gcTime @@ -204,7 +210,6 @@ private[spark] class Executor( val afterSerialization = System.currentTimeMillis() for (m <- task.metrics) { - m.hostname = Utils.localHostName() m.executorDeserializeTime = taskStart - startTime m.executorRunTime = taskFinish - taskStart m.jvmGCTime = gcTime - startGCTime @@ -354,4 +359,42 @@ private[spark] class Executor( } } } + + def startDriverHeartbeater() { + val interval = conf.getInt("spark.executor.heartbeatInterval", 10000) + val timeout = AkkaUtils.lookupTimeout(conf) + val retryAttempts = AkkaUtils.numRetries(conf) + val retryIntervalMs = AkkaUtils.retryWaitMs(conf) + val heartbeatReceiverRef = AkkaUtils.makeDriverRef("HeartbeatReceiver", conf, env.actorSystem) + + val t = new Thread() { + override def run() { + // Sleep a random interval so the heartbeats don't end up in sync + Thread.sleep(interval + (math.random * interval).asInstanceOf[Int]) + + while (!isStopped) { + val tasksMetrics = new ArrayBuffer[(Long, TaskMetrics)]() + for (taskRunner <- runningTasks.values()) { + if (!taskRunner.attemptedTask.isEmpty) { + Option(taskRunner.task).flatMap(_.metrics).foreach { metrics => + tasksMetrics += ((taskRunner.taskId, metrics)) + } + } + } + + val message = Heartbeat(executorId, tasksMetrics.toArray, env.blockManager.blockManagerId) + val response = AkkaUtils.askWithReply[HeartbeatResponse](message, heartbeatReceiverRef, + retryAttempts, retryIntervalMs, timeout) + if (response.reregisterBlockManager) { + logWarning("Told to re-register on heartbeat") + env.blockManager.reregister() + } + Thread.sleep(interval) + } + } + } + t.setDaemon(true) + t.setName("Driver Heartbeater") + t.start() + } } diff --git a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala index 21fe643b8d71f..56cd8723a3a22 100644 --- a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala +++ b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala @@ -23,6 +23,14 @@ import org.apache.spark.storage.{BlockId, BlockStatus} /** * :: DeveloperApi :: * Metrics tracked during the execution of a task. + * + * This class is used to house metrics both for in-progress and completed tasks. In executors, + * both the task thread and the heartbeat thread write to the TaskMetrics. The heartbeat thread + * reads it to send in-progress metrics, and the task thread reads it to send metrics along with + * the completed task. + * + * So, when adding new fields, take into consideration that the whole object can be serialized for + * shipping off at any time to consumers of the SparkListener interface. */ @DeveloperApi class TaskMetrics extends Serializable { @@ -143,7 +151,7 @@ class ShuffleReadMetrics extends Serializable { /** * Absolute time when this task finished reading shuffle data */ - var shuffleFinishTime: Long = _ + var shuffleFinishTime: Long = -1 /** * Number of blocks fetched in this shuffle by this task (remote or local) diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 50186d097a632..c7e3d7c5f8530 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -29,7 +29,6 @@ import scala.reflect.ClassTag import scala.util.control.NonFatal import akka.actor._ -import akka.actor.OneForOneStrategy import akka.actor.SupervisorStrategy.Stop import akka.pattern.ask import akka.util.Timeout @@ -39,8 +38,9 @@ import org.apache.spark.broadcast.Broadcast import org.apache.spark.executor.TaskMetrics import org.apache.spark.partial.{ApproximateActionListener, ApproximateEvaluator, PartialResult} import org.apache.spark.rdd.RDD -import org.apache.spark.storage.{BlockId, BlockManager, BlockManagerMaster, RDDBlockId} +import org.apache.spark.storage._ import org.apache.spark.util.{CallSite, SystemClock, Clock, Utils} +import org.apache.spark.storage.BlockManagerMessages.BlockManagerHeartbeat /** * The high-level scheduling layer that implements stage-oriented scheduling. It computes a DAG of @@ -154,6 +154,23 @@ class DAGScheduler( eventProcessActor ! CompletionEvent(task, reason, result, accumUpdates, taskInfo, taskMetrics) } + /** + * Update metrics for in-progress tasks and let the master know that the BlockManager is still + * alive. Return true if the driver knows about the given block manager. Otherwise, return false, + * indicating that the block manager should re-register. + */ + def executorHeartbeatReceived( + execId: String, + taskMetrics: Array[(Long, Int, TaskMetrics)], // (taskId, stageId, metrics) + blockManagerId: BlockManagerId): Boolean = { + listenerBus.post(SparkListenerExecutorMetricsUpdate(execId, taskMetrics)) + implicit val timeout = Timeout(600 seconds) + + Await.result( + blockManagerMaster.driverActor ? BlockManagerHeartbeat(blockManagerId), + timeout.duration).asInstanceOf[Boolean] + } + // Called by TaskScheduler when an executor fails. def executorLost(execId: String) { eventProcessActor ! ExecutorLost(execId) diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala index 82163eadd56e9..d01d318633877 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala @@ -75,6 +75,12 @@ case class SparkListenerBlockManagerRemoved(blockManagerId: BlockManagerId) @DeveloperApi case class SparkListenerUnpersistRDD(rddId: Int) extends SparkListenerEvent +@DeveloperApi +case class SparkListenerExecutorMetricsUpdate( + execId: String, + taskMetrics: Seq[(Long, Int, TaskMetrics)]) + extends SparkListenerEvent + @DeveloperApi case class SparkListenerApplicationStart(appName: String, time: Long, sparkUser: String) extends SparkListenerEvent @@ -158,6 +164,11 @@ trait SparkListener { * Called when the application ends */ def onApplicationEnd(applicationEnd: SparkListenerApplicationEnd) { } + + /** + * Called when the driver receives task metrics from an executor in a heartbeat. + */ + def onExecutorMetricsUpdate(executorMetricsUpdate: SparkListenerExecutorMetricsUpdate) { } } /** diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala index ed9fb24bc8ce8..e79ffd7a3587d 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala @@ -68,6 +68,8 @@ private[spark] trait SparkListenerBus extends Logging { foreachListener(_.onApplicationStart(applicationStart)) case applicationEnd: SparkListenerApplicationEnd => foreachListener(_.onApplicationEnd(applicationEnd)) + case metricsUpdate: SparkListenerExecutorMetricsUpdate => + foreachListener(_.onExecutorMetricsUpdate(metricsUpdate)) case SparkListenerShutdown => } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala index 5871edeb856ad..5c5e421404a21 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala @@ -26,6 +26,8 @@ import org.apache.spark.TaskContext import org.apache.spark.executor.TaskMetrics import org.apache.spark.serializer.SerializerInstance import org.apache.spark.util.ByteBufferInputStream +import org.apache.spark.util.Utils + /** * A unit of execution. We have two kinds of Task's in Spark: @@ -44,6 +46,7 @@ private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) ex final def run(attemptId: Long): T = { context = new TaskContext(stageId, partitionId, attemptId, runningLocally = false) + context.taskMetrics.hostname = Utils.localHostName(); taskThread = Thread.currentThread() if (_killed) { kill(interruptThread = false) diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala index 819c35257b5a7..1a0b877c8a5e1 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala @@ -18,6 +18,8 @@ package org.apache.spark.scheduler import org.apache.spark.scheduler.SchedulingMode.SchedulingMode +import org.apache.spark.executor.TaskMetrics +import org.apache.spark.storage.BlockManagerId /** * Low-level task scheduler interface, currently implemented exclusively by TaskSchedulerImpl. @@ -54,4 +56,12 @@ private[spark] trait TaskScheduler { // Get the default level of parallelism to use in the cluster, as a hint for sizing jobs. def defaultParallelism(): Int + + /** + * Update metrics for in-progress tasks and let the master know that the BlockManager is still + * alive. Return true if the driver knows about the given block manager. Otherwise, return false, + * indicating that the block manager should re-register. + */ + def executorHeartbeatReceived(execId: String, taskMetrics: Array[(Long, TaskMetrics)], + blockManagerId: BlockManagerId): Boolean } diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala index be3673c48eda8..d2f764fc22f54 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala @@ -32,6 +32,9 @@ import org.apache.spark._ import org.apache.spark.TaskState.TaskState import org.apache.spark.scheduler.SchedulingMode.SchedulingMode import org.apache.spark.util.Utils +import org.apache.spark.executor.TaskMetrics +import org.apache.spark.storage.BlockManagerId +import akka.actor.Props /** * Schedules tasks for multiple types of clusters by acting through a SchedulerBackend. @@ -320,6 +323,26 @@ private[spark] class TaskSchedulerImpl( } } + /** + * Update metrics for in-progress tasks and let the master know that the BlockManager is still + * alive. Return true if the driver knows about the given block manager. Otherwise, return false, + * indicating that the block manager should re-register. + */ + override def executorHeartbeatReceived( + execId: String, + taskMetrics: Array[(Long, TaskMetrics)], // taskId -> TaskMetrics + blockManagerId: BlockManagerId): Boolean = { + val metricsWithStageIds = taskMetrics.flatMap { + case (id, metrics) => { + taskIdToTaskSetId.get(id) + .flatMap(activeTaskSets.get) + .map(_.stageId) + .map(x => (id, x, metrics)) + } + } + dagScheduler.executorHeartbeatReceived(execId, metricsWithStageIds, blockManagerId) + } + def handleTaskGettingResult(taskSetManager: TaskSetManager, tid: Long) { taskSetManager.handleTaskGettingResult(tid) } diff --git a/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala index 5b897597fa285..3d1cf312ccc97 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala @@ -23,8 +23,9 @@ import akka.actor.{Actor, ActorRef, Props} import org.apache.spark.{Logging, SparkEnv, TaskState} import org.apache.spark.TaskState.TaskState -import org.apache.spark.executor.{Executor, ExecutorBackend} +import org.apache.spark.executor.{TaskMetrics, Executor, ExecutorBackend} import org.apache.spark.scheduler.{SchedulerBackend, TaskSchedulerImpl, WorkerOffer} +import org.apache.spark.storage.BlockManagerId private case class ReviveOffers() @@ -32,6 +33,8 @@ private case class StatusUpdate(taskId: Long, state: TaskState, serializedData: private case class KillTask(taskId: Long, interruptThread: Boolean) +private case class StopExecutor() + /** * Calls to LocalBackend are all serialized through LocalActor. Using an actor makes the calls on * LocalBackend asynchronous, which is necessary to prevent deadlock between LocalBackend @@ -63,6 +66,9 @@ private[spark] class LocalActor( case KillTask(taskId, interruptThread) => executor.killTask(taskId, interruptThread) + + case StopExecutor => + executor.stop() } def reviveOffers() { @@ -91,6 +97,7 @@ private[spark] class LocalBackend(scheduler: TaskSchedulerImpl, val totalCores: } override def stop() { + localActor ! StopExecutor } override def reviveOffers() { diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index d746526639e58..c0a06017945f0 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -116,15 +116,6 @@ private[spark] class BlockManager( private var asyncReregisterTask: Future[Unit] = null private val asyncReregisterLock = new Object - private def heartBeat(): Unit = { - if (!master.sendHeartBeat(blockManagerId)) { - reregister() - } - } - - private val heartBeatFrequency = BlockManager.getHeartBeatFrequency(conf) - private var heartBeatTask: Cancellable = null - private val metadataCleaner = new MetadataCleaner( MetadataCleanerType.BLOCK_MANAGER, this.dropOldNonBroadcastBlocks, conf) private val broadcastCleaner = new MetadataCleaner( @@ -161,11 +152,6 @@ private[spark] class BlockManager( private def initialize(): Unit = { master.registerBlockManager(blockManagerId, maxMemory, slaveActor) BlockManagerWorker.startBlockManagerWorker(this) - if (!BlockManager.getDisableHeartBeatsForTesting(conf)) { - heartBeatTask = actorSystem.scheduler.schedule(0.seconds, heartBeatFrequency.milliseconds) { - Utils.tryOrExit { heartBeat() } - } - } } /** @@ -195,7 +181,7 @@ private[spark] class BlockManager( * * Note that this method must be called without any BlockInfo locks held. */ - private def reregister(): Unit = { + def reregister(): Unit = { // TODO: We might need to rate limit re-registering. logInfo("BlockManager re-registering with master") master.registerBlockManager(blockManagerId, maxMemory, slaveActor) @@ -1065,9 +1051,6 @@ private[spark] class BlockManager( } def stop(): Unit = { - if (heartBeatTask != null) { - heartBeatTask.cancel() - } connectionManager.stop() shuffleBlockManager.stop() diskBlockManager.stop() @@ -1095,12 +1078,6 @@ private[spark] object BlockManager extends Logging { (Runtime.getRuntime.maxMemory * memoryFraction * safetyFraction).toLong } - def getHeartBeatFrequency(conf: SparkConf): Long = - conf.getLong("spark.storage.blockManagerTimeoutIntervalMs", 60000) / 4 - - def getDisableHeartBeatsForTesting(conf: SparkConf): Boolean = - conf.getBoolean("spark.test.disableBlockManagerHeartBeat", false) - /** * Attempt to clean up a ByteBuffer if it is memory-mapped. This uses an *unsafe* Sun API that * might cause errors if one attempts to read from the unmapped buffer, but it's better than diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala index 7897fade2df2b..669307765d1fa 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala @@ -21,7 +21,6 @@ import scala.concurrent.{Await, Future} import scala.concurrent.ExecutionContext.Implicits.global import akka.actor._ -import akka.pattern.ask import org.apache.spark.{Logging, SparkConf, SparkException} import org.apache.spark.storage.BlockManagerMessages._ @@ -29,8 +28,8 @@ import org.apache.spark.util.AkkaUtils private[spark] class BlockManagerMaster(var driverActor: ActorRef, conf: SparkConf) extends Logging { - val AKKA_RETRY_ATTEMPTS: Int = conf.getInt("spark.akka.num.retries", 3) - val AKKA_RETRY_INTERVAL_MS: Int = conf.getInt("spark.akka.retry.wait", 3000) + private val AKKA_RETRY_ATTEMPTS: Int = AkkaUtils.numRetries(conf) + private val AKKA_RETRY_INTERVAL_MS: Int = AkkaUtils.retryWaitMs(conf) val DRIVER_AKKA_ACTOR_NAME = "BlockManagerMaster" @@ -42,15 +41,6 @@ class BlockManagerMaster(var driverActor: ActorRef, conf: SparkConf) extends Log logInfo("Removed " + execId + " successfully in removeExecutor") } - /** - * Send the driver actor a heart beat from the slave. Returns true if everything works out, - * false if the driver does not know about the given block manager, which means the block - * manager should re-register. - */ - def sendHeartBeat(blockManagerId: BlockManagerId): Boolean = { - askDriverWithReply[Boolean](HeartBeat(blockManagerId)) - } - /** Register the BlockManager's id with the driver. */ def registerBlockManager(blockManagerId: BlockManagerId, maxMemSize: Long, slaveActor: ActorRef) { logInfo("Trying to register BlockManager") @@ -223,33 +213,8 @@ class BlockManagerMaster(var driverActor: ActorRef, conf: SparkConf) extends Log * throw a SparkException if this fails. */ private def askDriverWithReply[T](message: Any): T = { - // TODO: Consider removing multiple attempts - if (driverActor == null) { - throw new SparkException("Error sending message to BlockManager as driverActor is null " + - "[message = " + message + "]") - } - var attempts = 0 - var lastException: Exception = null - while (attempts < AKKA_RETRY_ATTEMPTS) { - attempts += 1 - try { - val future = driverActor.ask(message)(timeout) - val result = Await.result(future, timeout) - if (result == null) { - throw new SparkException("BlockManagerMaster returned null") - } - return result.asInstanceOf[T] - } catch { - case ie: InterruptedException => throw ie - case e: Exception => - lastException = e - logWarning("Error sending message to BlockManagerMaster in " + attempts + " attempts", e) - } - Thread.sleep(AKKA_RETRY_INTERVAL_MS) - } - - throw new SparkException( - "Error sending message to BlockManagerMaster [message = " + message + "]", lastException) + AkkaUtils.askWithReply(message, driverActor, AKKA_RETRY_ATTEMPTS, AKKA_RETRY_INTERVAL_MS, + timeout) } } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala index de1cc5539fb48..94f5a4bb2e9cd 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala @@ -52,25 +52,24 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus private val akkaTimeout = AkkaUtils.askTimeout(conf) - val slaveTimeout = conf.get("spark.storage.blockManagerSlaveTimeoutMs", - "" + (BlockManager.getHeartBeatFrequency(conf) * 3)).toLong + val slaveTimeout = conf.getLong("spark.storage.blockManagerSlaveTimeoutMs", + math.max(conf.getInt("spark.executor.heartbeatInterval", 10000) * 3, 45000)) - val checkTimeoutInterval = conf.get("spark.storage.blockManagerTimeoutIntervalMs", - "60000").toLong + val checkTimeoutInterval = conf.getLong("spark.storage.blockManagerTimeoutIntervalMs", + 60000) var timeoutCheckingTask: Cancellable = null override def preStart() { - if (!BlockManager.getDisableHeartBeatsForTesting(conf)) { - import context.dispatcher - timeoutCheckingTask = context.system.scheduler.schedule(0.seconds, - checkTimeoutInterval.milliseconds, self, ExpireDeadHosts) - } + import context.dispatcher + timeoutCheckingTask = context.system.scheduler.schedule(0.seconds, + checkTimeoutInterval.milliseconds, self, ExpireDeadHosts) super.preStart() } def receive = { case RegisterBlockManager(blockManagerId, maxMemSize, slaveActor) => + logInfo("received a register") register(blockManagerId, maxMemSize, slaveActor) sender ! true @@ -129,8 +128,8 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus case ExpireDeadHosts => expireDeadHosts() - case HeartBeat(blockManagerId) => - sender ! heartBeat(blockManagerId) + case BlockManagerHeartbeat(blockManagerId) => + sender ! heartbeatReceived(blockManagerId) case other => logWarning("Got unknown message: " + other) @@ -216,7 +215,7 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus val minSeenTime = now - slaveTimeout val toRemove = new mutable.HashSet[BlockManagerId] for (info <- blockManagerInfo.values) { - if (info.lastSeenMs < minSeenTime) { + if (info.lastSeenMs < minSeenTime && info.blockManagerId.executorId != "") { logWarning("Removing BlockManager " + info.blockManagerId + " with no recent heart beats: " + (now - info.lastSeenMs) + "ms exceeds " + slaveTimeout + "ms") toRemove += info.blockManagerId @@ -230,7 +229,11 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus blockManagerIdByExecutor.get(execId).foreach(removeBlockManager) } - private def heartBeat(blockManagerId: BlockManagerId): Boolean = { + /** + * Return true if the driver knows about the given block manager. Otherwise, return false, + * indicating that the block manager should re-register. + */ + private def heartbeatReceived(blockManagerId: BlockManagerId): Boolean = { if (!blockManagerInfo.contains(blockManagerId)) { blockManagerId.executorId == "" && !isLocal } else { diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala index 2b53bf33b5fba..10b65286fb7db 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala @@ -21,7 +21,7 @@ import java.io.{Externalizable, ObjectInput, ObjectOutput} import akka.actor.ActorRef -private[storage] object BlockManagerMessages { +private[spark] object BlockManagerMessages { ////////////////////////////////////////////////////////////////////////////////// // Messages from the master to slaves. ////////////////////////////////////////////////////////////////////////////////// @@ -53,8 +53,6 @@ private[storage] object BlockManagerMessages { sender: ActorRef) extends ToBlockManagerMaster - case class HeartBeat(blockManagerId: BlockManagerId) extends ToBlockManagerMaster - class UpdateBlockInfo( var blockManagerId: BlockManagerId, var blockId: BlockId, @@ -124,5 +122,7 @@ private[storage] object BlockManagerMessages { case class GetMatchingBlockIds(filter: BlockId => Boolean, askSlaves: Boolean = true) extends ToBlockManagerMaster + case class BlockManagerHeartbeat(blockManagerId: BlockManagerId) extends ToBlockManagerMaster + case object ExpireDeadHosts extends ToBlockManagerMaster } diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala index efb527b4f03e6..da2f5d3172fe2 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala @@ -130,32 +130,16 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { new StageUIData }) - // create executor summary map if necessary - val executorSummaryMap = stageData.executorSummary - executorSummaryMap.getOrElseUpdate(key = info.executorId, op = new ExecutorSummary) - - executorSummaryMap.get(info.executorId).foreach { y => - // first update failed-task, succeed-task - taskEnd.reason match { - case Success => - y.succeededTasks += 1 - case _ => - y.failedTasks += 1 - } - - // update duration - y.taskTime += info.duration - - val metrics = taskEnd.taskMetrics - if (metrics != null) { - metrics.inputMetrics.foreach { y.inputBytes += _.bytesRead } - metrics.shuffleReadMetrics.foreach { y.shuffleRead += _.remoteBytesRead } - metrics.shuffleWriteMetrics.foreach { y.shuffleWrite += _.shuffleBytesWritten } - y.memoryBytesSpilled += metrics.memoryBytesSpilled - y.diskBytesSpilled += metrics.diskBytesSpilled - } + val execSummaryMap = stageData.executorSummary + val execSummary = execSummaryMap.getOrElseUpdate(info.executorId, new ExecutorSummary) + + taskEnd.reason match { + case Success => + execSummary.succeededTasks += 1 + case _ => + execSummary.failedTasks += 1 } - + execSummary.taskTime += info.duration stageData.numActiveTasks -= 1 val (errorMessage, metrics): (Option[String], Option[TaskMetrics]) = @@ -171,28 +155,75 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { (Some(e.toErrorString), None) } + if (!metrics.isEmpty) { + val oldMetrics = stageData.taskData.get(info.taskId).flatMap(_.taskMetrics) + updateAggregateMetrics(stageData, info.executorId, metrics.get, oldMetrics) + } - val taskRunTime = metrics.map(_.executorRunTime).getOrElse(0L) - stageData.executorRunTime += taskRunTime - val inputBytes = metrics.flatMap(_.inputMetrics).map(_.bytesRead).getOrElse(0L) - stageData.inputBytes += inputBytes - - val shuffleRead = metrics.flatMap(_.shuffleReadMetrics).map(_.remoteBytesRead).getOrElse(0L) - stageData.shuffleReadBytes += shuffleRead - - val shuffleWrite = - metrics.flatMap(_.shuffleWriteMetrics).map(_.shuffleBytesWritten).getOrElse(0L) - stageData.shuffleWriteBytes += shuffleWrite - - val memoryBytesSpilled = metrics.map(_.memoryBytesSpilled).getOrElse(0L) - stageData.memoryBytesSpilled += memoryBytesSpilled + val taskData = stageData.taskData.getOrElseUpdate(info.taskId, new TaskUIData(info)) + taskData.taskInfo = info + taskData.taskMetrics = metrics + taskData.errorMessage = errorMessage + } + } - val diskBytesSpilled = metrics.map(_.diskBytesSpilled).getOrElse(0L) - stageData.diskBytesSpilled += diskBytesSpilled + /** + * Upon receiving new metrics for a task, updates the per-stage and per-executor-per-stage + * aggregate metrics by calculating deltas between the currently recorded metrics and the new + * metrics. + */ + def updateAggregateMetrics( + stageData: StageUIData, + execId: String, + taskMetrics: TaskMetrics, + oldMetrics: Option[TaskMetrics]) { + val execSummary = stageData.executorSummary.getOrElseUpdate(execId, new ExecutorSummary) + + val shuffleWriteDelta = + (taskMetrics.shuffleWriteMetrics.map(_.shuffleBytesWritten).getOrElse(0L) + - oldMetrics.flatMap(_.shuffleWriteMetrics).map(_.shuffleBytesWritten).getOrElse(0L)) + stageData.shuffleWriteBytes += shuffleWriteDelta + execSummary.shuffleWrite += shuffleWriteDelta + + val shuffleReadDelta = + (taskMetrics.shuffleReadMetrics.map(_.remoteBytesRead).getOrElse(0L) + - oldMetrics.flatMap(_.shuffleReadMetrics).map(_.remoteBytesRead).getOrElse(0L)) + stageData.shuffleReadBytes += shuffleReadDelta + execSummary.shuffleRead += shuffleReadDelta + + val diskSpillDelta = + taskMetrics.diskBytesSpilled - oldMetrics.map(_.diskBytesSpilled).getOrElse(0L) + stageData.diskBytesSpilled += diskSpillDelta + execSummary.diskBytesSpilled += diskSpillDelta + + val memorySpillDelta = + taskMetrics.memoryBytesSpilled - oldMetrics.map(_.memoryBytesSpilled).getOrElse(0L) + stageData.memoryBytesSpilled += memorySpillDelta + execSummary.memoryBytesSpilled += memorySpillDelta + + val timeDelta = + taskMetrics.executorRunTime - oldMetrics.map(_.executorRunTime).getOrElse(0L) + stageData.executorRunTime += timeDelta + } - stageData.taskData(info.taskId) = new TaskUIData(info, metrics, errorMessage) + override def onExecutorMetricsUpdate(executorMetricsUpdate: SparkListenerExecutorMetricsUpdate) { + for ((taskId, sid, taskMetrics) <- executorMetricsUpdate.taskMetrics) { + val stageData = stageIdToData.getOrElseUpdate(sid, { + logWarning("Metrics update for task in unknown stage " + sid) + new StageUIData + }) + val taskData = stageData.taskData.get(taskId) + taskData.map { t => + if (!t.taskInfo.finished) { + updateAggregateMetrics(stageData, executorMetricsUpdate.execId, taskMetrics, + t.taskMetrics) + + // Overwrite task metrics + t.taskMetrics = Some(taskMetrics) + } + } } - } // end of onTaskEnd + } override def onEnvironmentUpdate(environmentUpdate: SparkListenerEnvironmentUpdate) { synchronized { diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala b/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala index be11a11695b01..2f96f7909c199 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala @@ -55,8 +55,11 @@ private[jobs] object UIData { var executorSummary = new HashMap[String, ExecutorSummary] } + /** + * These are kept mutable and reused throughout a task's lifetime to avoid excessive reallocation. + */ case class TaskUIData( - taskInfo: TaskInfo, - taskMetrics: Option[TaskMetrics] = None, - errorMessage: Option[String] = None) + var taskInfo: TaskInfo, + var taskMetrics: Option[TaskMetrics] = None, + var errorMessage: Option[String] = None) } diff --git a/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala b/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala index 9930c717492f2..feafd654e9e71 100644 --- a/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala +++ b/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala @@ -18,13 +18,16 @@ package org.apache.spark.util import scala.collection.JavaConversions.mapAsJavaMap +import scala.concurrent.Await import scala.concurrent.duration.{Duration, FiniteDuration} -import akka.actor.{ActorSystem, ExtendedActorSystem} +import akka.actor.{Actor, ActorRef, ActorSystem, ExtendedActorSystem} +import akka.pattern.ask + import com.typesafe.config.ConfigFactory import org.apache.log4j.{Level, Logger} -import org.apache.spark.{Logging, SecurityManager, SparkConf} +import org.apache.spark.{SparkException, Logging, SecurityManager, SparkConf} /** * Various utility classes for working with Akka. @@ -124,4 +127,63 @@ private[spark] object AkkaUtils extends Logging { /** Space reserved for extra data in an Akka message besides serialized task or task result. */ val reservedSizeBytes = 200 * 1024 + + /** Returns the configured number of times to retry connecting */ + def numRetries(conf: SparkConf): Int = { + conf.getInt("spark.akka.num.retries", 3) + } + + /** Returns the configured number of milliseconds to wait on each retry */ + def retryWaitMs(conf: SparkConf): Int = { + conf.getInt("spark.akka.retry.wait", 3000) + } + + /** + * Send a message to the given actor and get its result within a default timeout, or + * throw a SparkException if this fails. + */ + def askWithReply[T]( + message: Any, + actor: ActorRef, + retryAttempts: Int, + retryInterval: Int, + timeout: FiniteDuration): T = { + // TODO: Consider removing multiple attempts + if (actor == null) { + throw new SparkException("Error sending message as driverActor is null " + + "[message = " + message + "]") + } + var attempts = 0 + var lastException: Exception = null + while (attempts < retryAttempts) { + attempts += 1 + try { + val future = actor.ask(message)(timeout) + val result = Await.result(future, timeout) + if (result == null) { + throw new SparkException("Actor returned null") + } + return result.asInstanceOf[T] + } catch { + case ie: InterruptedException => throw ie + case e: Exception => + lastException = e + logWarning("Error sending message in " + attempts + " attempts", e) + } + Thread.sleep(retryInterval) + } + + throw new SparkException( + "Error sending message [message = " + message + "]", lastException) + } + + def makeDriverRef(name: String, conf: SparkConf, actorSystem: ActorSystem): ActorRef = { + val driverHost: String = conf.get("spark.driver.host", "localhost") + val driverPort: Int = conf.getInt("spark.driver.port", 7077) + Utils.checkHost(driverHost, "Expected hostname") + val url = s"akka.tcp://spark@$driverHost:$driverPort/user/$name" + val timeout = AkkaUtils.lookupTimeout(conf) + logInfo(s"Connecting to $name: $url") + Await.result(actorSystem.actorSelection(url).resolveOne(timeout), timeout) + } } diff --git a/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala b/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala index 4b727e50dbe67..495a0d48633a4 100644 --- a/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark -import org.scalatest.{FunSuite, PrivateMethodTester} +import org.scalatest.{BeforeAndAfterEach, FunSuite, PrivateMethodTester} import org.apache.spark.scheduler.{TaskScheduler, TaskSchedulerImpl} import org.apache.spark.scheduler.cluster.{SimrSchedulerBackend, SparkDeploySchedulerBackend} @@ -25,12 +25,12 @@ import org.apache.spark.scheduler.cluster.mesos.{CoarseMesosSchedulerBackend, Me import org.apache.spark.scheduler.local.LocalBackend class SparkContextSchedulerCreationSuite - extends FunSuite with PrivateMethodTester with LocalSparkContext with Logging { + extends FunSuite with PrivateMethodTester with Logging with BeforeAndAfterEach { def createTaskScheduler(master: String): TaskSchedulerImpl = { // Create local SparkContext to setup a SparkEnv. We don't actually want to start() the // real schedulers, so we don't want to create a full SparkContext with the desired scheduler. - sc = new SparkContext("local", "test") + val sc = new SparkContext("local", "test") val createTaskSchedulerMethod = PrivateMethod[TaskScheduler]('createTaskScheduler) val sched = SparkContext invokePrivate createTaskSchedulerMethod(sc, master) sched.asInstanceOf[TaskSchedulerImpl] diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index 9021662bcf712..0ce13d015df05 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -29,6 +29,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.scheduler.SchedulingMode.SchedulingMode import org.apache.spark.storage.{BlockId, BlockManagerId, BlockManagerMaster} import org.apache.spark.util.CallSite +import org.apache.spark.executor.TaskMetrics class BuggyDAGEventProcessActor extends Actor { val state = 0 @@ -77,6 +78,8 @@ class DAGSchedulerSuite extends TestKit(ActorSystem("DAGSchedulerSuite")) with F override def schedulingMode: SchedulingMode = SchedulingMode.NONE override def start() = {} override def stop() = {} + override def executorHeartbeatReceived(execId: String, taskMetrics: Array[(Long, TaskMetrics)], + blockManagerId: BlockManagerId): Boolean = true override def submitTasks(taskSet: TaskSet) = { // normally done by TaskSetManager taskSet.tasks.foreach(_.epoch = mapOutputTracker.getEpoch) @@ -342,6 +345,8 @@ class DAGSchedulerSuite extends TestKit(ActorSystem("DAGSchedulerSuite")) with F } override def setDAGScheduler(dagScheduler: DAGScheduler) = {} override def defaultParallelism() = 2 + override def executorHeartbeatReceived(execId: String, taskMetrics: Array[(Long, TaskMetrics)], + blockManagerId: BlockManagerId): Boolean = true } val noKillScheduler = new DAGScheduler( sc, diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index 58ea0cc30e954..0ac0269d7cfc1 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -19,22 +19,28 @@ package org.apache.spark.storage import java.nio.{ByteBuffer, MappedByteBuffer} import java.util.Arrays +import java.util.concurrent.TimeUnit import akka.actor._ +import akka.pattern.ask +import akka.util.Timeout + import org.mockito.Mockito.{mock, when} import org.scalatest.{BeforeAndAfter, FunSuite, PrivateMethodTester} import org.scalatest.concurrent.Eventually._ import org.scalatest.concurrent.Timeouts._ import org.scalatest.Matchers -import org.scalatest.time.SpanSugar._ import org.apache.spark.{MapOutputTrackerMaster, SecurityManager, SparkConf} import org.apache.spark.executor.DataReadMethod import org.apache.spark.scheduler.LiveListenerBus import org.apache.spark.serializer.{JavaSerializer, KryoSerializer} +import org.apache.spark.storage.BlockManagerMessages.BlockManagerHeartbeat import org.apache.spark.util.{AkkaUtils, ByteBufferInputStream, SizeEstimator, Utils} import scala.collection.mutable.ArrayBuffer +import scala.concurrent.Await +import scala.concurrent.duration._ import scala.language.implicitConversions import scala.language.postfixOps @@ -73,7 +79,6 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter oldArch = System.setProperty("os.arch", "amd64") conf.set("os.arch", "amd64") conf.set("spark.test.useCompressedOops", "true") - conf.set("spark.storage.disableBlockManagerHeartBeat", "true") conf.set("spark.driver.port", boundPort.toString) conf.set("spark.storage.unrollFraction", "0.4") conf.set("spark.storage.unrollMemoryThreshold", "512") @@ -341,7 +346,6 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter } test("reregistration on heart beat") { - val heartBeat = PrivateMethod[Unit]('heartBeat) store = makeBlockManager(2000) val a1 = new Array[Byte](400) @@ -353,13 +357,15 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter master.removeExecutor(store.blockManagerId.executorId) assert(master.getLocations("a1").size == 0, "a1 was not removed from master") - store invokePrivate heartBeat() - assert(master.getLocations("a1").size > 0, "a1 was not reregistered with master") + implicit val timeout = Timeout(30, TimeUnit.SECONDS) + val reregister = !Await.result( + master.driverActor ? BlockManagerHeartbeat(store.blockManagerId), + timeout.duration).asInstanceOf[Boolean] + assert(reregister == true) } test("reregistration on block update") { - store = new BlockManager("", actorSystem, master, serializer, 2000, conf, - securityMgr, mapOutputTracker) + store = makeBlockManager(2000) val a1 = new Array[Byte](400) val a2 = new Array[Byte](400) @@ -377,7 +383,6 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter } test("reregistration doesn't dead lock") { - val heartBeat = PrivateMethod[Unit]('heartBeat) store = makeBlockManager(2000) val a1 = new Array[Byte](400) val a2 = List(new Array[Byte](400)) @@ -397,7 +402,7 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter } val t3 = new Thread { override def run() { - store invokePrivate heartBeat() + store.reregister() } } diff --git a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala index 86a271eb67000..cb8252515238e 100644 --- a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala @@ -21,7 +21,8 @@ import org.scalatest.FunSuite import org.scalatest.Matchers import org.apache.spark._ -import org.apache.spark.executor.{ShuffleReadMetrics, TaskMetrics} +import org.apache.spark.{LocalSparkContext, SparkConf, Success} +import org.apache.spark.executor.{ShuffleWriteMetrics, ShuffleReadMetrics, TaskMetrics} import org.apache.spark.scheduler._ import org.apache.spark.util.Utils @@ -129,4 +130,87 @@ class JobProgressListenerSuite extends FunSuite with LocalSparkContext with Matc assert(listener.stageIdToData(task.stageId).numCompleteTasks === 1) assert(listener.stageIdToData(task.stageId).numFailedTasks === failCount) } + + test("test update metrics") { + val conf = new SparkConf() + val listener = new JobProgressListener(conf) + + val taskType = Utils.getFormattedClassName(new ShuffleMapTask(0)) + val execId = "exe-1" + + def makeTaskMetrics(base: Int) = { + val taskMetrics = new TaskMetrics() + val shuffleReadMetrics = new ShuffleReadMetrics() + val shuffleWriteMetrics = new ShuffleWriteMetrics() + taskMetrics.updateShuffleReadMetrics(shuffleReadMetrics) + taskMetrics.shuffleWriteMetrics = Some(shuffleWriteMetrics) + shuffleReadMetrics.remoteBytesRead = base + 1 + shuffleReadMetrics.remoteBlocksFetched = base + 2 + shuffleWriteMetrics.shuffleBytesWritten = base + 3 + taskMetrics.executorRunTime = base + 4 + taskMetrics.diskBytesSpilled = base + 5 + taskMetrics.memoryBytesSpilled = base + 6 + taskMetrics + } + + def makeTaskInfo(taskId: Long, finishTime: Int = 0) = { + val taskInfo = new TaskInfo(taskId, 0, 1, 0L, execId, "host1", TaskLocality.NODE_LOCAL, + false) + taskInfo.finishTime = finishTime + taskInfo + } + + listener.onTaskStart(SparkListenerTaskStart(0, makeTaskInfo(1234L))) + listener.onTaskStart(SparkListenerTaskStart(0, makeTaskInfo(1235L))) + listener.onTaskStart(SparkListenerTaskStart(1, makeTaskInfo(1236L))) + listener.onTaskStart(SparkListenerTaskStart(1, makeTaskInfo(1237L))) + + listener.onExecutorMetricsUpdate(SparkListenerExecutorMetricsUpdate(execId, Array( + (1234L, 0, makeTaskMetrics(0)), + (1235L, 0, makeTaskMetrics(100)), + (1236L, 1, makeTaskMetrics(200))))) + + var stage0Data = listener.stageIdToData.get(0).get + var stage1Data = listener.stageIdToData.get(1).get + assert(stage0Data.shuffleReadBytes == 102) + assert(stage1Data.shuffleReadBytes == 201) + assert(stage0Data.shuffleWriteBytes == 106) + assert(stage1Data.shuffleWriteBytes == 203) + assert(stage0Data.executorRunTime == 108) + assert(stage1Data.executorRunTime == 204) + assert(stage0Data.diskBytesSpilled == 110) + assert(stage1Data.diskBytesSpilled == 205) + assert(stage0Data.memoryBytesSpilled == 112) + assert(stage1Data.memoryBytesSpilled == 206) + assert(stage0Data.taskData.get(1234L).get.taskMetrics.get.shuffleReadMetrics.get + .totalBlocksFetched == 2) + assert(stage0Data.taskData.get(1235L).get.taskMetrics.get.shuffleReadMetrics.get + .totalBlocksFetched == 102) + assert(stage1Data.taskData.get(1236L).get.taskMetrics.get.shuffleReadMetrics.get + .totalBlocksFetched == 202) + + // task that was included in a heartbeat + listener.onTaskEnd(SparkListenerTaskEnd(0, taskType, Success, makeTaskInfo(1234L, 1), + makeTaskMetrics(300))) + // task that wasn't included in a heartbeat + listener.onTaskEnd(SparkListenerTaskEnd(1, taskType, Success, makeTaskInfo(1237L, 1), + makeTaskMetrics(400))) + + stage0Data = listener.stageIdToData.get(0).get + stage1Data = listener.stageIdToData.get(1).get + assert(stage0Data.shuffleReadBytes == 402) + assert(stage1Data.shuffleReadBytes == 602) + assert(stage0Data.shuffleWriteBytes == 406) + assert(stage1Data.shuffleWriteBytes == 606) + assert(stage0Data.executorRunTime == 408) + assert(stage1Data.executorRunTime == 608) + assert(stage0Data.diskBytesSpilled == 410) + assert(stage1Data.diskBytesSpilled == 610) + assert(stage0Data.memoryBytesSpilled == 412) + assert(stage1Data.memoryBytesSpilled == 612) + assert(stage0Data.taskData.get(1234L).get.taskMetrics.get.shuffleReadMetrics.get + .totalBlocksFetched == 302) + assert(stage1Data.taskData.get(1237L).get.taskMetrics.get.shuffleReadMetrics.get + .totalBlocksFetched == 402) + } } diff --git a/docs/configuration.md b/docs/configuration.md index ea69057b5be10..2a71d7b820e5f 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -541,6 +541,13 @@ Apart from these, the following properties are also available, and may be useful output directories. We recommend that users do not disable this except if trying to achieve compatibility with previous versions of Spark. Simply use Hadoop's FileSystem API to delete output directories by hand. + + + + +
Indicates whether the history server should use kerberos to login. This is useful if the history server is accessing HDFS files on a secure Hadoop cluster. If this is - true it looks uses the configs spark.history.kerberos.principal and + true, it uses the configs spark.history.kerberos.principal and spark.history.kerberos.keytab.
spark.executor.heartbeatInterval10000Interval (milliseconds) between each executor's heartbeats to the driver. Heartbeats let + the driver know that the executor is still alive and update it with metrics for in-progress + tasks.
#### Networking From c41fdf04f4beebe36379396b0c4fff3ab7ad3cf4 Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Fri, 1 Aug 2014 11:14:53 -0700 Subject: [PATCH 086/170] [SPARK-2179][SQL] A minor refactoring Java data type APIs (2179 follow-up). It is a follow-up PR of SPARK-2179 (https://issues.apache.org/jira/browse/SPARK-2179). It makes package names of data type APIs more consistent across languages (Scala: `org.apache.spark.sql`, Java: `org.apache.spark.sql.api.java`, Python: `pyspark.sql`). Author: Yin Huai Closes #1712 from yhuai/javaDataType and squashes the following commits: 62eb705 [Yin Huai] Move package-info. add4bcb [Yin Huai] Make the package names of data type classes consistent across languages by moving all Java data type classes to package sql.api.java. --- .../sql/api/java/{types => }/ArrayType.java | 6 +- .../sql/api/java/{types => }/BinaryType.java | 2 +- .../sql/api/java/{types => }/BooleanType.java | 2 +- .../sql/api/java/{types => }/ByteType.java | 2 +- .../sql/api/java/{types => }/DataType.java | 2 +- .../sql/api/java/{types => }/DecimalType.java | 2 +- .../sql/api/java/{types => }/DoubleType.java | 2 +- .../sql/api/java/{types => }/FloatType.java | 2 +- .../sql/api/java/{types => }/IntegerType.java | 2 +- .../sql/api/java/{types => }/LongType.java | 2 +- .../sql/api/java/{types => }/MapType.java | 6 +- .../sql/api/java/{types => }/ShortType.java | 2 +- .../sql/api/java/{types => }/StringType.java | 2 +- .../sql/api/java/{types => }/StructField.java | 4 +- .../sql/api/java/{types => }/StructType.java | 7 +-- .../api/java/{types => }/TimestampType.java | 2 +- .../spark/sql/api/java}/package-info.java | 2 +- .../sql/api/java/types/package-info.java | 22 ------- .../spark/sql/api/java/JavaSQLContext.scala | 60 ++++++++++++------- .../spark/sql/api/java/JavaSchemaRDD.scala | 1 - .../sql/types/util/DataTypeConversions.scala | 30 +++++----- .../sql/api/java/JavaApplySchemaSuite.java | 3 - .../java/JavaSideDataTypeConversionSuite.java | 2 - .../ScalaSideDataTypeConversionSuite.scala | 59 +++++++++--------- 24 files changed, 108 insertions(+), 118 deletions(-) rename sql/core/src/main/java/org/apache/spark/sql/api/java/{types => }/ArrayType.java (90%) rename sql/core/src/main/java/org/apache/spark/sql/api/java/{types => }/BinaryType.java (95%) rename sql/core/src/main/java/org/apache/spark/sql/api/java/{types => }/BooleanType.java (95%) rename sql/core/src/main/java/org/apache/spark/sql/api/java/{types => }/ByteType.java (95%) rename sql/core/src/main/java/org/apache/spark/sql/api/java/{types => }/DataType.java (99%) rename sql/core/src/main/java/org/apache/spark/sql/api/java/{types => }/DecimalType.java (95%) rename sql/core/src/main/java/org/apache/spark/sql/api/java/{types => }/DoubleType.java (95%) rename sql/core/src/main/java/org/apache/spark/sql/api/java/{types => }/FloatType.java (95%) rename sql/core/src/main/java/org/apache/spark/sql/api/java/{types => }/IntegerType.java (95%) rename sql/core/src/main/java/org/apache/spark/sql/api/java/{types => }/LongType.java (95%) rename sql/core/src/main/java/org/apache/spark/sql/api/java/{types => }/MapType.java (91%) rename sql/core/src/main/java/org/apache/spark/sql/api/java/{types => }/ShortType.java (95%) rename sql/core/src/main/java/org/apache/spark/sql/api/java/{types => }/StringType.java (95%) rename sql/core/src/main/java/org/apache/spark/sql/api/java/{types => }/StructField.java (94%) rename sql/core/src/main/java/org/apache/spark/sql/api/java/{types => }/StructType.java (86%) rename sql/core/src/main/java/org/apache/spark/sql/api/java/{types => }/TimestampType.java (95%) rename sql/core/src/main/{scala/org/apache/spark/sql => java/org/apache/spark/sql/api/java}/package-info.java (95%) delete mode 100644 sql/core/src/main/java/org/apache/spark/sql/api/java/types/package-info.java diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/types/ArrayType.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/ArrayType.java similarity index 90% rename from sql/core/src/main/java/org/apache/spark/sql/api/java/types/ArrayType.java rename to sql/core/src/main/java/org/apache/spark/sql/api/java/ArrayType.java index 17334ca31b2b7..b73a371e93001 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/types/ArrayType.java +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/ArrayType.java @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.api.java.types; +package org.apache.spark.sql.api.java; /** * The data type representing Lists. @@ -25,8 +25,8 @@ * {@code null} values. * * To create an {@link ArrayType}, - * {@link org.apache.spark.sql.api.java.types.DataType#createArrayType(DataType)} or - * {@link org.apache.spark.sql.api.java.types.DataType#createArrayType(DataType, boolean)} + * {@link DataType#createArrayType(DataType)} or + * {@link DataType#createArrayType(DataType, boolean)} * should be used. */ public class ArrayType extends DataType { diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/types/BinaryType.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/BinaryType.java similarity index 95% rename from sql/core/src/main/java/org/apache/spark/sql/api/java/types/BinaryType.java rename to sql/core/src/main/java/org/apache/spark/sql/api/java/BinaryType.java index 61703179850e9..7daad60f62a0b 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/types/BinaryType.java +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/BinaryType.java @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.api.java.types; +package org.apache.spark.sql.api.java; /** * The data type representing byte[] values. diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/types/BooleanType.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/BooleanType.java similarity index 95% rename from sql/core/src/main/java/org/apache/spark/sql/api/java/types/BooleanType.java rename to sql/core/src/main/java/org/apache/spark/sql/api/java/BooleanType.java index 8fa24d85d1238..5a1f52725631b 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/types/BooleanType.java +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/BooleanType.java @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.api.java.types; +package org.apache.spark.sql.api.java; /** * The data type representing boolean and Boolean values. diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/types/ByteType.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/ByteType.java similarity index 95% rename from sql/core/src/main/java/org/apache/spark/sql/api/java/types/ByteType.java rename to sql/core/src/main/java/org/apache/spark/sql/api/java/ByteType.java index 2de32978e2705..e5cdf06b21bbe 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/types/ByteType.java +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/ByteType.java @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.api.java.types; +package org.apache.spark.sql.api.java; /** * The data type representing byte and Byte values. diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/types/DataType.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/DataType.java similarity index 99% rename from sql/core/src/main/java/org/apache/spark/sql/api/java/types/DataType.java rename to sql/core/src/main/java/org/apache/spark/sql/api/java/DataType.java index f84e5a490a905..3eccddef88134 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/types/DataType.java +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/DataType.java @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.api.java.types; +package org.apache.spark.sql.api.java; import java.util.HashSet; import java.util.List; diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/types/DecimalType.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/DecimalType.java similarity index 95% rename from sql/core/src/main/java/org/apache/spark/sql/api/java/types/DecimalType.java rename to sql/core/src/main/java/org/apache/spark/sql/api/java/DecimalType.java index 9250491a2d2ca..bc54c078d7a4e 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/types/DecimalType.java +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/DecimalType.java @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.api.java.types; +package org.apache.spark.sql.api.java; /** * The data type representing java.math.BigDecimal values. diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/types/DoubleType.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/DoubleType.java similarity index 95% rename from sql/core/src/main/java/org/apache/spark/sql/api/java/types/DoubleType.java rename to sql/core/src/main/java/org/apache/spark/sql/api/java/DoubleType.java index 3e86917fddc4b..f0060d0bcf9f5 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/types/DoubleType.java +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/DoubleType.java @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.api.java.types; +package org.apache.spark.sql.api.java; /** * The data type representing double and Double values. diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/types/FloatType.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/FloatType.java similarity index 95% rename from sql/core/src/main/java/org/apache/spark/sql/api/java/types/FloatType.java rename to sql/core/src/main/java/org/apache/spark/sql/api/java/FloatType.java index fa860d40176ef..4a6a37f69176a 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/types/FloatType.java +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/FloatType.java @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.api.java.types; +package org.apache.spark.sql.api.java; /** * The data type representing float and Float values. diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/types/IntegerType.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/IntegerType.java similarity index 95% rename from sql/core/src/main/java/org/apache/spark/sql/api/java/types/IntegerType.java rename to sql/core/src/main/java/org/apache/spark/sql/api/java/IntegerType.java index bd973eca2c3ce..bfd70490bbbbb 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/types/IntegerType.java +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/IntegerType.java @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.api.java.types; +package org.apache.spark.sql.api.java; /** * The data type representing int and Integer values. diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/types/LongType.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/LongType.java similarity index 95% rename from sql/core/src/main/java/org/apache/spark/sql/api/java/types/LongType.java rename to sql/core/src/main/java/org/apache/spark/sql/api/java/LongType.java index e00233304cefa..af13a46eb165c 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/types/LongType.java +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/LongType.java @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.api.java.types; +package org.apache.spark.sql.api.java; /** * The data type representing long and Long values. diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/types/MapType.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/MapType.java similarity index 91% rename from sql/core/src/main/java/org/apache/spark/sql/api/java/types/MapType.java rename to sql/core/src/main/java/org/apache/spark/sql/api/java/MapType.java index 94936e2e4ee7a..063e6b34abc48 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/types/MapType.java +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/MapType.java @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.api.java.types; +package org.apache.spark.sql.api.java; /** * The data type representing Maps. A MapType object comprises two fields, @@ -27,8 +27,8 @@ * For values of a MapType column, keys are not allowed to have {@code null} values. * * To create a {@link MapType}, - * {@link org.apache.spark.sql.api.java.types.DataType#createMapType(DataType, DataType)} or - * {@link org.apache.spark.sql.api.java.types.DataType#createMapType(DataType, DataType, boolean)} + * {@link DataType#createMapType(DataType, DataType)} or + * {@link DataType#createMapType(DataType, DataType, boolean)} * should be used. */ public class MapType extends DataType { diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/types/ShortType.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/ShortType.java similarity index 95% rename from sql/core/src/main/java/org/apache/spark/sql/api/java/types/ShortType.java rename to sql/core/src/main/java/org/apache/spark/sql/api/java/ShortType.java index 98f9507acf121..7d7604b4e3d2d 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/types/ShortType.java +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/ShortType.java @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.api.java.types; +package org.apache.spark.sql.api.java; /** * The data type representing short and Short values. diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/types/StringType.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/StringType.java similarity index 95% rename from sql/core/src/main/java/org/apache/spark/sql/api/java/types/StringType.java rename to sql/core/src/main/java/org/apache/spark/sql/api/java/StringType.java index b8e7dbe646071..f4ba0c07c9c6e 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/types/StringType.java +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/StringType.java @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.api.java.types; +package org.apache.spark.sql.api.java; /** * The data type representing String values. diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/types/StructField.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/StructField.java similarity index 94% rename from sql/core/src/main/java/org/apache/spark/sql/api/java/types/StructField.java rename to sql/core/src/main/java/org/apache/spark/sql/api/java/StructField.java index 54e9c11ea415e..b48e2a2c5f953 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/types/StructField.java +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/StructField.java @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.api.java.types; +package org.apache.spark.sql.api.java; /** * A StructField object represents a field in a StructType object. @@ -26,7 +26,7 @@ * values. * * To create a {@link StructField}, - * {@link org.apache.spark.sql.api.java.types.DataType#createStructField(String, DataType, boolean)} + * {@link DataType#createStructField(String, DataType, boolean)} * should be used. */ public class StructField { diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/types/StructType.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/StructType.java similarity index 86% rename from sql/core/src/main/java/org/apache/spark/sql/api/java/types/StructType.java rename to sql/core/src/main/java/org/apache/spark/sql/api/java/StructType.java index 33a42f4b16265..a4b501efd9a10 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/types/StructType.java +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/StructType.java @@ -15,18 +15,17 @@ * limitations under the License. */ -package org.apache.spark.sql.api.java.types; +package org.apache.spark.sql.api.java; import java.util.Arrays; -import java.util.List; /** * The data type representing Rows. * A StructType object comprises an array of StructFields. * * To create an {@link StructType}, - * {@link org.apache.spark.sql.api.java.types.DataType#createStructType(java.util.List)} or - * {@link org.apache.spark.sql.api.java.types.DataType#createStructType(StructField[])} + * {@link DataType#createStructType(java.util.List)} or + * {@link DataType#createStructType(StructField[])} * should be used. */ public class StructType extends DataType { diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/types/TimestampType.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/TimestampType.java similarity index 95% rename from sql/core/src/main/java/org/apache/spark/sql/api/java/types/TimestampType.java rename to sql/core/src/main/java/org/apache/spark/sql/api/java/TimestampType.java index 65295779f71ec..06d44c731cdfe 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/types/TimestampType.java +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/TimestampType.java @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.api.java.types; +package org.apache.spark.sql.api.java; /** * The data type representing java.sql.Timestamp values. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/package-info.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/package-info.java similarity index 95% rename from sql/core/src/main/scala/org/apache/spark/sql/package-info.java rename to sql/core/src/main/java/org/apache/spark/sql/api/java/package-info.java index 53603614518f5..67007a9f0d1a3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/package-info.java +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/package-info.java @@ -18,4 +18,4 @@ /** * Allows the execution of relational queries, including those expressed in SQL using Spark. */ -package org.apache.spark.sql; \ No newline at end of file +package org.apache.spark.sql.api.java; diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/types/package-info.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/types/package-info.java deleted file mode 100644 index f169ac65e226f..0000000000000 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/types/package-info.java +++ /dev/null @@ -1,22 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - - -/** - * Allows users to get and create Spark SQL data types. - */ -package org.apache.spark.sql.api.java.types; diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala index c1c18a0cd0ed6..809dd038f94aa 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala @@ -23,9 +23,8 @@ import org.apache.hadoop.conf.Configuration import org.apache.spark.annotation.{DeveloperApi, Experimental} import org.apache.spark.api.java.{JavaRDD, JavaSparkContext} -import org.apache.spark.sql.api.java.types.{StructType => JStructType} import org.apache.spark.sql.json.JsonRDD -import org.apache.spark.sql._ +import org.apache.spark.sql.{SQLContext, StructType => SStructType} import org.apache.spark.sql.catalyst.expressions.{AttributeReference, GenericRow, Row => ScalaRow} import org.apache.spark.sql.parquet.ParquetRelation import org.apache.spark.sql.execution.{ExistingRdd, SparkLogicalPlan} @@ -104,9 +103,9 @@ class JavaSQLContext(val sqlContext: SQLContext) { * provided schema. Otherwise, there will be runtime exception. */ @DeveloperApi - def applySchema(rowRDD: JavaRDD[Row], schema: JStructType): JavaSchemaRDD = { + def applySchema(rowRDD: JavaRDD[Row], schema: StructType): JavaSchemaRDD = { val scalaRowRDD = rowRDD.rdd.map(r => r.row) - val scalaSchema = asScalaDataType(schema).asInstanceOf[StructType] + val scalaSchema = asScalaDataType(schema).asInstanceOf[SStructType] val logicalPlan = SparkLogicalPlan(ExistingRdd(scalaSchema.toAttributes, scalaRowRDD))(sqlContext) new JavaSchemaRDD(sqlContext, logicalPlan) @@ -133,7 +132,7 @@ class JavaSQLContext(val sqlContext: SQLContext) { * returning the result as a JavaSchemaRDD. */ @Experimental - def jsonFile(path: String, schema: JStructType): JavaSchemaRDD = + def jsonFile(path: String, schema: StructType): JavaSchemaRDD = jsonRDD(sqlContext.sparkContext.textFile(path), schema) /** @@ -155,10 +154,10 @@ class JavaSQLContext(val sqlContext: SQLContext) { * returning the result as a JavaSchemaRDD. */ @Experimental - def jsonRDD(json: JavaRDD[String], schema: JStructType): JavaSchemaRDD = { + def jsonRDD(json: JavaRDD[String], schema: StructType): JavaSchemaRDD = { val appliedScalaSchema = Option(asScalaDataType(schema)).getOrElse( - JsonRDD.nullTypeToStringType(JsonRDD.inferSchema(json.rdd, 1.0))).asInstanceOf[StructType] + JsonRDD.nullTypeToStringType(JsonRDD.inferSchema(json.rdd, 1.0))).asInstanceOf[SStructType] val scalaRowRDD = JsonRDD.jsonStringToRow(json.rdd, appliedScalaSchema) val logicalPlan = SparkLogicalPlan(ExistingRdd(appliedScalaSchema.toAttributes, scalaRowRDD))(sqlContext) @@ -181,22 +180,37 @@ class JavaSQLContext(val sqlContext: SQLContext) { val fields = beanInfo.getPropertyDescriptors.filterNot(_.getName == "class") fields.map { property => val (dataType, nullable) = property.getPropertyType match { - case c: Class[_] if c == classOf[java.lang.String] => (StringType, true) - case c: Class[_] if c == java.lang.Short.TYPE => (ShortType, false) - case c: Class[_] if c == java.lang.Integer.TYPE => (IntegerType, false) - case c: Class[_] if c == java.lang.Long.TYPE => (LongType, false) - case c: Class[_] if c == java.lang.Double.TYPE => (DoubleType, false) - case c: Class[_] if c == java.lang.Byte.TYPE => (ByteType, false) - case c: Class[_] if c == java.lang.Float.TYPE => (FloatType, false) - case c: Class[_] if c == java.lang.Boolean.TYPE => (BooleanType, false) - - case c: Class[_] if c == classOf[java.lang.Short] => (ShortType, true) - case c: Class[_] if c == classOf[java.lang.Integer] => (IntegerType, true) - case c: Class[_] if c == classOf[java.lang.Long] => (LongType, true) - case c: Class[_] if c == classOf[java.lang.Double] => (DoubleType, true) - case c: Class[_] if c == classOf[java.lang.Byte] => (ByteType, true) - case c: Class[_] if c == classOf[java.lang.Float] => (FloatType, true) - case c: Class[_] if c == classOf[java.lang.Boolean] => (BooleanType, true) + case c: Class[_] if c == classOf[java.lang.String] => + (org.apache.spark.sql.StringType, true) + case c: Class[_] if c == java.lang.Short.TYPE => + (org.apache.spark.sql.ShortType, false) + case c: Class[_] if c == java.lang.Integer.TYPE => + (org.apache.spark.sql.IntegerType, false) + case c: Class[_] if c == java.lang.Long.TYPE => + (org.apache.spark.sql.LongType, false) + case c: Class[_] if c == java.lang.Double.TYPE => + (org.apache.spark.sql.DoubleType, false) + case c: Class[_] if c == java.lang.Byte.TYPE => + (org.apache.spark.sql.ByteType, false) + case c: Class[_] if c == java.lang.Float.TYPE => + (org.apache.spark.sql.FloatType, false) + case c: Class[_] if c == java.lang.Boolean.TYPE => + (org.apache.spark.sql.BooleanType, false) + + case c: Class[_] if c == classOf[java.lang.Short] => + (org.apache.spark.sql.ShortType, true) + case c: Class[_] if c == classOf[java.lang.Integer] => + (org.apache.spark.sql.IntegerType, true) + case c: Class[_] if c == classOf[java.lang.Long] => + (org.apache.spark.sql.LongType, true) + case c: Class[_] if c == classOf[java.lang.Double] => + (org.apache.spark.sql.DoubleType, true) + case c: Class[_] if c == classOf[java.lang.Byte] => + (org.apache.spark.sql.ByteType, true) + case c: Class[_] if c == classOf[java.lang.Float] => + (org.apache.spark.sql.FloatType, true) + case c: Class[_] if c == classOf[java.lang.Boolean] => + (org.apache.spark.sql.BooleanType, true) } AttributeReference(property.getName, dataType, nullable)() } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSchemaRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSchemaRDD.scala index 824574149858c..4d799b4038fdd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSchemaRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSchemaRDD.scala @@ -22,7 +22,6 @@ import java.util.{List => JList} import org.apache.spark.Partitioner import org.apache.spark.api.java.{JavaRDDLike, JavaRDD} import org.apache.spark.api.java.function.{Function => JFunction} -import org.apache.spark.sql.api.java.types.StructType import org.apache.spark.sql.types.util.DataTypeConversions import org.apache.spark.sql.{SQLContext, SchemaRDD, SchemaRDDLike} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan diff --git a/sql/core/src/main/scala/org/apache/spark/sql/types/util/DataTypeConversions.scala b/sql/core/src/main/scala/org/apache/spark/sql/types/util/DataTypeConversions.scala index d1aa3c8d53757..77353f4eb0227 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/types/util/DataTypeConversions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/types/util/DataTypeConversions.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.types.util import org.apache.spark.sql._ -import org.apache.spark.sql.api.java.types.{DataType => JDataType, StructField => JStructField} +import org.apache.spark.sql.api.java.{DataType => JDataType, StructField => JStructField} import scala.collection.JavaConverters._ @@ -74,37 +74,37 @@ protected[sql] object DataTypeConversions { * Returns the equivalent DataType in Scala for the given DataType in Java. */ def asScalaDataType(javaDataType: JDataType): DataType = javaDataType match { - case stringType: org.apache.spark.sql.api.java.types.StringType => + case stringType: org.apache.spark.sql.api.java.StringType => StringType - case binaryType: org.apache.spark.sql.api.java.types.BinaryType => + case binaryType: org.apache.spark.sql.api.java.BinaryType => BinaryType - case booleanType: org.apache.spark.sql.api.java.types.BooleanType => + case booleanType: org.apache.spark.sql.api.java.BooleanType => BooleanType - case timestampType: org.apache.spark.sql.api.java.types.TimestampType => + case timestampType: org.apache.spark.sql.api.java.TimestampType => TimestampType - case decimalType: org.apache.spark.sql.api.java.types.DecimalType => + case decimalType: org.apache.spark.sql.api.java.DecimalType => DecimalType - case doubleType: org.apache.spark.sql.api.java.types.DoubleType => + case doubleType: org.apache.spark.sql.api.java.DoubleType => DoubleType - case floatType: org.apache.spark.sql.api.java.types.FloatType => + case floatType: org.apache.spark.sql.api.java.FloatType => FloatType - case byteType: org.apache.spark.sql.api.java.types.ByteType => + case byteType: org.apache.spark.sql.api.java.ByteType => ByteType - case integerType: org.apache.spark.sql.api.java.types.IntegerType => + case integerType: org.apache.spark.sql.api.java.IntegerType => IntegerType - case longType: org.apache.spark.sql.api.java.types.LongType => + case longType: org.apache.spark.sql.api.java.LongType => LongType - case shortType: org.apache.spark.sql.api.java.types.ShortType => + case shortType: org.apache.spark.sql.api.java.ShortType => ShortType - case arrayType: org.apache.spark.sql.api.java.types.ArrayType => + case arrayType: org.apache.spark.sql.api.java.ArrayType => ArrayType(asScalaDataType(arrayType.getElementType), arrayType.isContainsNull) - case mapType: org.apache.spark.sql.api.java.types.MapType => + case mapType: org.apache.spark.sql.api.java.MapType => MapType( asScalaDataType(mapType.getKeyType), asScalaDataType(mapType.getValueType), mapType.isValueContainsNull) - case structType: org.apache.spark.sql.api.java.types.StructType => + case structType: org.apache.spark.sql.api.java.StructType => StructType(structType.getFields.map(asScalaStructField)) } } diff --git a/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaApplySchemaSuite.java b/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaApplySchemaSuite.java index 8ee4591105010..3c92906d82864 100644 --- a/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaApplySchemaSuite.java +++ b/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaApplySchemaSuite.java @@ -28,9 +28,6 @@ import org.junit.Before; import org.junit.Test; -import org.apache.spark.sql.api.java.types.DataType; -import org.apache.spark.sql.api.java.types.StructField; -import org.apache.spark.sql.api.java.types.StructType; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.api.java.function.Function; diff --git a/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaSideDataTypeConversionSuite.java b/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaSideDataTypeConversionSuite.java index 96a503962f7d1..d099a48a1f4b6 100644 --- a/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaSideDataTypeConversionSuite.java +++ b/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaSideDataTypeConversionSuite.java @@ -24,8 +24,6 @@ import org.junit.Test; import org.apache.spark.sql.types.util.DataTypeConversions; -import org.apache.spark.sql.api.java.types.DataType; -import org.apache.spark.sql.api.java.types.StructField; public class JavaSideDataTypeConversionSuite { public void checkDataType(DataType javaDataType) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/api/java/ScalaSideDataTypeConversionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/api/java/ScalaSideDataTypeConversionSuite.scala index 46de6fe239228..ff1debff0f8c1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/api/java/ScalaSideDataTypeConversionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/api/java/ScalaSideDataTypeConversionSuite.scala @@ -20,12 +20,13 @@ package org.apache.spark.sql.api.java import org.apache.spark.sql.types.util.DataTypeConversions import org.scalatest.FunSuite -import org.apache.spark.sql._ +import org.apache.spark.sql.{DataType => SDataType, StructField => SStructField} +import org.apache.spark.sql.{StructType => SStructType} import DataTypeConversions._ class ScalaSideDataTypeConversionSuite extends FunSuite { - def checkDataType(scalaDataType: DataType) { + def checkDataType(scalaDataType: SDataType) { val javaDataType = asJavaDataType(scalaDataType) val actual = asScalaDataType(javaDataType) assert(scalaDataType === actual, s"Converted data type ${actual} " + @@ -34,48 +35,52 @@ class ScalaSideDataTypeConversionSuite extends FunSuite { test("convert data types") { // Simple DataTypes. - checkDataType(StringType) - checkDataType(BinaryType) - checkDataType(BooleanType) - checkDataType(TimestampType) - checkDataType(DecimalType) - checkDataType(DoubleType) - checkDataType(FloatType) - checkDataType(ByteType) - checkDataType(IntegerType) - checkDataType(LongType) - checkDataType(ShortType) + checkDataType(org.apache.spark.sql.StringType) + checkDataType(org.apache.spark.sql.BinaryType) + checkDataType(org.apache.spark.sql.BooleanType) + checkDataType(org.apache.spark.sql.TimestampType) + checkDataType(org.apache.spark.sql.DecimalType) + checkDataType(org.apache.spark.sql.DoubleType) + checkDataType(org.apache.spark.sql.FloatType) + checkDataType(org.apache.spark.sql.ByteType) + checkDataType(org.apache.spark.sql.IntegerType) + checkDataType(org.apache.spark.sql.LongType) + checkDataType(org.apache.spark.sql.ShortType) // Simple ArrayType. - val simpleScalaArrayType = ArrayType(StringType, true) + val simpleScalaArrayType = + org.apache.spark.sql.ArrayType(org.apache.spark.sql.StringType, true) checkDataType(simpleScalaArrayType) // Simple MapType. - val simpleScalaMapType = MapType(StringType, LongType) + val simpleScalaMapType = + org.apache.spark.sql.MapType(org.apache.spark.sql.StringType, org.apache.spark.sql.LongType) checkDataType(simpleScalaMapType) // Simple StructType. - val simpleScalaStructType = StructType( - StructField("a", DecimalType, false) :: - StructField("b", BooleanType, true) :: - StructField("c", LongType, true) :: - StructField("d", BinaryType, false) :: Nil) + val simpleScalaStructType = SStructType( + SStructField("a", org.apache.spark.sql.DecimalType, false) :: + SStructField("b", org.apache.spark.sql.BooleanType, true) :: + SStructField("c", org.apache.spark.sql.LongType, true) :: + SStructField("d", org.apache.spark.sql.BinaryType, false) :: Nil) checkDataType(simpleScalaStructType) // Complex StructType. - val complexScalaStructType = StructType( - StructField("simpleArray", simpleScalaArrayType, true) :: - StructField("simpleMap", simpleScalaMapType, true) :: - StructField("simpleStruct", simpleScalaStructType, true) :: - StructField("boolean", BooleanType, false) :: Nil) + val complexScalaStructType = SStructType( + SStructField("simpleArray", simpleScalaArrayType, true) :: + SStructField("simpleMap", simpleScalaMapType, true) :: + SStructField("simpleStruct", simpleScalaStructType, true) :: + SStructField("boolean", org.apache.spark.sql.BooleanType, false) :: Nil) checkDataType(complexScalaStructType) // Complex ArrayType. - val complexScalaArrayType = ArrayType(complexScalaStructType, true) + val complexScalaArrayType = + org.apache.spark.sql.ArrayType(complexScalaStructType, true) checkDataType(complexScalaArrayType) // Complex MapType. - val complexScalaMapType = MapType(complexScalaStructType, complexScalaArrayType, false) + val complexScalaMapType = + org.apache.spark.sql.MapType(complexScalaStructType, complexScalaArrayType, false) checkDataType(complexScalaMapType) } } From 4415722e9199d04c2c18bfbd29113ebc40f732f5 Mon Sep 17 00:00:00 2001 From: Cheng Hao Date: Fri, 1 Aug 2014 11:27:12 -0700 Subject: [PATCH 087/170] [SQL][SPARK-2212]Hash Outer Join This patch is to support the hash based outer join. Currently, outer join for big relations are resort to `BoradcastNestedLoopJoin`, which is super slow. This PR will create 2 hash tables for both relations in the same partition, which greatly reduce the table scans. Here is the testing code that I used: ``` package org.apache.spark.sql.hive import org.apache.spark.SparkContext import org.apache.spark.SparkConf import org.apache.spark.sql._ case class Record(key: String, value: String) object JoinTablePrepare extends App { import TestHive2._ val rdd = sparkContext.parallelize((1 to 3000000).map(i => Record(s"${i % 828193}", s"val_$i"))) runSqlHive("SHOW TABLES") runSqlHive("DROP TABLE if exists a") runSqlHive("DROP TABLE if exists b") runSqlHive("DROP TABLE if exists result") rdd.registerAsTable("records") runSqlHive("""CREATE TABLE a (key STRING, value STRING) | ROW FORMAT SERDE | 'org.apache.hadoop.hive.serde2.columnar.LazyBinaryColumnarSerDe' | STORED AS RCFILE """.stripMargin) runSqlHive("""CREATE TABLE b (key STRING, value STRING) | ROW FORMAT SERDE | 'org.apache.hadoop.hive.serde2.columnar.LazyBinaryColumnarSerDe' | STORED AS RCFILE """.stripMargin) runSqlHive("""CREATE TABLE result (key STRING, value STRING) | ROW FORMAT SERDE | 'org.apache.hadoop.hive.serde2.columnar.LazyBinaryColumnarSerDe' | STORED AS RCFILE """.stripMargin) hql(s"""from records | insert into table a | select key, value """.stripMargin) hql(s"""from records | insert into table b select key + 100000, value """.stripMargin) } object JoinTablePerformanceTest extends App { import TestHive2._ hql("SHOW TABLES") hql("set spark.sql.shuffle.partitions=20") val leftOuterJoin = "insert overwrite table result select a.key, b.value from a left outer join b on a.key=b.key" val rightOuterJoin = "insert overwrite table result select a.key, b.value from a right outer join b on a.key=b.key" val fullOuterJoin = "insert overwrite table result select a.key, b.value from a full outer join b on a.key=b.key" val results = ("LeftOuterJoin", benchmark(leftOuterJoin)) :: ("LeftOuterJoin", benchmark(leftOuterJoin)) :: ("RightOuterJoin", benchmark(rightOuterJoin)) :: ("RightOuterJoin", benchmark(rightOuterJoin)) :: ("FullOuterJoin", benchmark(fullOuterJoin)) :: ("FullOuterJoin", benchmark(fullOuterJoin)) :: Nil val explains = hql(s"explain $leftOuterJoin").collect ++ hql(s"explain $rightOuterJoin").collect ++ hql(s"explain $fullOuterJoin").collect println(explains.mkString(",\n")) results.foreach { case (prompt, result) => { println(s"$prompt: took ${result._1} ms (${result._2} records)") } } def benchmark(cmd: String) = { val begin = System.currentTimeMillis() val result = hql(cmd) val end = System.currentTimeMillis() val count = hql("select count(1) from result").collect.mkString("") ((end - begin), count) } } ``` And the result as shown below: ``` [Physical execution plan:], [InsertIntoHiveTable (MetastoreRelation default, result, None), Map(), true], [ Project [key#95,value#98]], [ HashOuterJoin [key#95], [key#97], LeftOuter, None], [ Exchange (HashPartitioning [key#95], 20)], [ HiveTableScan [key#95], (MetastoreRelation default, a, None), None], [ Exchange (HashPartitioning [key#97], 20)], [ HiveTableScan [key#97,value#98], (MetastoreRelation default, b, None), None], [Physical execution plan:], [InsertIntoHiveTable (MetastoreRelation default, result, None), Map(), true], [ Project [key#102,value#105]], [ HashOuterJoin [key#102], [key#104], RightOuter, None], [ Exchange (HashPartitioning [key#102], 20)], [ HiveTableScan [key#102], (MetastoreRelation default, a, None), None], [ Exchange (HashPartitioning [key#104], 20)], [ HiveTableScan [key#104,value#105], (MetastoreRelation default, b, None), None], [Physical execution plan:], [InsertIntoHiveTable (MetastoreRelation default, result, None), Map(), true], [ Project [key#109,value#112]], [ HashOuterJoin [key#109], [key#111], FullOuter, None], [ Exchange (HashPartitioning [key#109], 20)], [ HiveTableScan [key#109], (MetastoreRelation default, a, None), None], [ Exchange (HashPartitioning [key#111], 20)], [ HiveTableScan [key#111,value#112], (MetastoreRelation default, b, None), None] LeftOuterJoin: took 16072 ms ([3000000] records) LeftOuterJoin: took 14394 ms ([3000000] records) RightOuterJoin: took 14802 ms ([3000000] records) RightOuterJoin: took 14747 ms ([3000000] records) FullOuterJoin: took 17715 ms ([6000000] records) FullOuterJoin: took 17629 ms ([6000000] records) ``` Without this PR, the benchmark will run seems never end. Author: Cheng Hao Closes #1147 from chenghao-intel/hash_based_outer_join and squashes the following commits: 65c599e [Cheng Hao] Fix issues with the community comments 72b1394 [Cheng Hao] Fix bug of stale value in joinedRow 55baef7 [Cheng Hao] Add HashOuterJoin --- .../spark/sql/execution/SparkStrategies.scala | 4 + .../apache/spark/sql/execution/joins.scala | 183 +++++++++++++++++- .../org/apache/spark/sql/JoinSuite.scala | 138 ++++++++++++- 3 files changed, 319 insertions(+), 6 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index d57b6eaf40b09..8bec015c7b465 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -94,6 +94,10 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { leftKeys, rightKeys, buildSide, planLater(left), planLater(right)) condition.map(Filter(_, hashJoin)).getOrElse(hashJoin) :: Nil + case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right) => + execution.HashOuterJoin( + leftKeys, rightKeys, joinType, condition, planLater(left), planLater(right)) :: Nil + case _ => Nil } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala index b068579db75cd..82f0a74b630bf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala @@ -72,7 +72,7 @@ trait HashJoin { while (buildIter.hasNext) { currentRow = buildIter.next() val rowKey = buildSideKeyGenerator(currentRow) - if(!rowKey.anyNull) { + if (!rowKey.anyNull) { val existingMatchList = hashTable.get(rowKey) val matchList = if (existingMatchList == null) { val newMatchList = new ArrayBuffer[Row]() @@ -136,6 +136,185 @@ trait HashJoin { } } +/** + * Constant Value for Binary Join Node + */ +object HashOuterJoin { + val DUMMY_LIST = Seq[Row](null) + val EMPTY_LIST = Seq[Row]() +} + +/** + * :: DeveloperApi :: + * Performs a hash based outer join for two child relations by shuffling the data using + * the join keys. This operator requires loading the associated partition in both side into memory. + */ +@DeveloperApi +case class HashOuterJoin( + leftKeys: Seq[Expression], + rightKeys: Seq[Expression], + joinType: JoinType, + condition: Option[Expression], + left: SparkPlan, + right: SparkPlan) extends BinaryNode { + + override def outputPartitioning: Partitioning = left.outputPartitioning + + override def requiredChildDistribution = + ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil + + def output = left.output ++ right.output + + // TODO we need to rewrite all of the iterators with our own implementation instead of the Scala + // iterator for performance purpose. + + private[this] def leftOuterIterator( + key: Row, leftIter: Iterable[Row], rightIter: Iterable[Row]): Iterator[Row] = { + val joinedRow = new JoinedRow() + val rightNullRow = new GenericRow(right.output.length) + val boundCondition = + condition.map(newPredicate(_, left.output ++ right.output)).getOrElse((row: Row) => true) + + leftIter.iterator.flatMap { l => + joinedRow.withLeft(l) + var matched = false + (if (!key.anyNull) rightIter.collect { case r if (boundCondition(joinedRow.withRight(r))) => + matched = true + joinedRow.copy + } else { + Nil + }) ++ HashOuterJoin.DUMMY_LIST.filter(_ => !matched).map( _ => { + // HashOuterJoin.DUMMY_LIST.filter(_ => !matched) is a tricky way to add additional row, + // as we don't know whether we need to append it until finish iterating all of the + // records in right side. + // If we didn't get any proper row, then append a single row with empty right + joinedRow.withRight(rightNullRow).copy + }) + } + } + + private[this] def rightOuterIterator( + key: Row, leftIter: Iterable[Row], rightIter: Iterable[Row]): Iterator[Row] = { + val joinedRow = new JoinedRow() + val leftNullRow = new GenericRow(left.output.length) + val boundCondition = + condition.map(newPredicate(_, left.output ++ right.output)).getOrElse((row: Row) => true) + + rightIter.iterator.flatMap { r => + joinedRow.withRight(r) + var matched = false + (if (!key.anyNull) leftIter.collect { case l if (boundCondition(joinedRow.withLeft(l))) => + matched = true + joinedRow.copy + } else { + Nil + }) ++ HashOuterJoin.DUMMY_LIST.filter(_ => !matched).map( _ => { + // HashOuterJoin.DUMMY_LIST.filter(_ => !matched) is a tricky way to add additional row, + // as we don't know whether we need to append it until finish iterating all of the + // records in left side. + // If we didn't get any proper row, then append a single row with empty left. + joinedRow.withLeft(leftNullRow).copy + }) + } + } + + private[this] def fullOuterIterator( + key: Row, leftIter: Iterable[Row], rightIter: Iterable[Row]): Iterator[Row] = { + val joinedRow = new JoinedRow() + val leftNullRow = new GenericRow(left.output.length) + val rightNullRow = new GenericRow(right.output.length) + val boundCondition = + condition.map(newPredicate(_, left.output ++ right.output)).getOrElse((row: Row) => true) + + if (!key.anyNull) { + // Store the positions of records in right, if one of its associated row satisfy + // the join condition. + val rightMatchedSet = scala.collection.mutable.Set[Int]() + leftIter.iterator.flatMap[Row] { l => + joinedRow.withLeft(l) + var matched = false + rightIter.zipWithIndex.collect { + // 1. For those matched (satisfy the join condition) records with both sides filled, + // append them directly + + case (r, idx) if (boundCondition(joinedRow.withRight(r)))=> { + matched = true + // if the row satisfy the join condition, add its index into the matched set + rightMatchedSet.add(idx) + joinedRow.copy + } + } ++ HashOuterJoin.DUMMY_LIST.filter(_ => !matched).map( _ => { + // 2. For those unmatched records in left, append additional records with empty right. + + // HashOuterJoin.DUMMY_LIST.filter(_ => !matched) is a tricky way to add additional row, + // as we don't know whether we need to append it until finish iterating all + // of the records in right side. + // If we didn't get any proper row, then append a single row with empty right. + joinedRow.withRight(rightNullRow).copy + }) + } ++ rightIter.zipWithIndex.collect { + // 3. For those unmatched records in right, append additional records with empty left. + + // Re-visiting the records in right, and append additional row with empty left, if its not + // in the matched set. + case (r, idx) if (!rightMatchedSet.contains(idx)) => { + joinedRow(leftNullRow, r).copy + } + } + } else { + leftIter.iterator.map[Row] { l => + joinedRow(l, rightNullRow).copy + } ++ rightIter.iterator.map[Row] { r => + joinedRow(leftNullRow, r).copy + } + } + } + + private[this] def buildHashTable( + iter: Iterator[Row], keyGenerator: Projection): Map[Row, ArrayBuffer[Row]] = { + // TODO: Use Spark's HashMap implementation. + val hashTable = scala.collection.mutable.Map[Row, ArrayBuffer[Row]]() + while (iter.hasNext) { + val currentRow = iter.next() + val rowKey = keyGenerator(currentRow) + + val existingMatchList = hashTable.getOrElseUpdate(rowKey, {new ArrayBuffer[Row]()}) + existingMatchList += currentRow.copy() + } + + hashTable.toMap[Row, ArrayBuffer[Row]] + } + + def execute() = { + left.execute().zipPartitions(right.execute()) { (leftIter, rightIter) => + // TODO this probably can be replaced by external sort (sort merged join?) + // Build HashMap for current partition in left relation + val leftHashTable = buildHashTable(leftIter, newProjection(leftKeys, left.output)) + // Build HashMap for current partition in right relation + val rightHashTable = buildHashTable(rightIter, newProjection(rightKeys, right.output)) + + val boundCondition = + condition.map(newPredicate(_, left.output ++ right.output)).getOrElse((row: Row) => true) + joinType match { + case LeftOuter => leftHashTable.keysIterator.flatMap { key => + leftOuterIterator(key, leftHashTable.getOrElse(key, HashOuterJoin.EMPTY_LIST), + rightHashTable.getOrElse(key, HashOuterJoin.EMPTY_LIST)) + } + case RightOuter => rightHashTable.keysIterator.flatMap { key => + rightOuterIterator(key, leftHashTable.getOrElse(key, HashOuterJoin.EMPTY_LIST), + rightHashTable.getOrElse(key, HashOuterJoin.EMPTY_LIST)) + } + case FullOuter => (leftHashTable.keySet ++ rightHashTable.keySet).iterator.flatMap { key => + fullOuterIterator(key, + leftHashTable.getOrElse(key, HashOuterJoin.EMPTY_LIST), + rightHashTable.getOrElse(key, HashOuterJoin.EMPTY_LIST)) + } + case x => throw new Exception(s"Need to add implementation for $x") + } + } + } +} + /** * :: DeveloperApi :: * Performs an inner hash join of two child relations by first shuffling the data using the join @@ -189,7 +368,7 @@ case class LeftSemiJoinHash( while (buildIter.hasNext) { currentRow = buildIter.next() val rowKey = buildSideKeyGenerator(currentRow) - if(!rowKey.anyNull) { + if (!rowKey.anyNull) { val keyExists = hashSet.contains(rowKey) if (!keyExists) { hashSet.add(rowKey) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index 025c396ef0629..037890682f7b1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -17,11 +17,17 @@ package org.apache.spark.sql +import org.scalatest.BeforeAndAfterEach + +import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation import org.apache.spark.sql.TestData._ -import org.apache.spark.sql.catalyst.plans.{LeftOuter, RightOuter, FullOuter, Inner} +import org.apache.spark.sql.catalyst.plans.JoinType +import org.apache.spark.sql.catalyst.plans.{LeftOuter, RightOuter, FullOuter, Inner, LeftSemi} +import org.apache.spark.sql.execution._ +import org.apache.spark.sql.test.TestSQLContext import org.apache.spark.sql.test.TestSQLContext._ -class JoinSuite extends QueryTest { +class JoinSuite extends QueryTest with BeforeAndAfterEach { // Ensures tables are loaded. TestData @@ -34,6 +40,56 @@ class JoinSuite extends QueryTest { assert(planned.size === 1) } + test("join operator selection") { + def assertJoin(sqlString: String, c: Class[_]): Any = { + val rdd = sql(sqlString) + val physical = rdd.queryExecution.sparkPlan + val operators = physical.collect { + case j: ShuffledHashJoin => j + case j: HashOuterJoin => j + case j: LeftSemiJoinHash => j + case j: BroadcastHashJoin => j + case j: LeftSemiJoinBNL => j + case j: CartesianProduct => j + case j: BroadcastNestedLoopJoin => j + } + + assert(operators.size === 1) + if (operators(0).getClass() != c) { + fail(s"$sqlString expected operator: $c, but got ${operators(0)}\n physical: \n$physical") + } + } + + val cases1 = Seq( + ("SELECT * FROM testData left semi join testData2 ON key = a", classOf[LeftSemiJoinHash]), + ("SELECT * FROM testData left semi join testData2", classOf[LeftSemiJoinBNL]), + ("SELECT * FROM testData join testData2", classOf[CartesianProduct]), + ("SELECT * FROM testData join testData2 where key=2", classOf[CartesianProduct]), + ("SELECT * FROM testData left join testData2", classOf[CartesianProduct]), + ("SELECT * FROM testData right join testData2", classOf[CartesianProduct]), + ("SELECT * FROM testData full outer join testData2", classOf[CartesianProduct]), + ("SELECT * FROM testData left join testData2 where key=2", classOf[CartesianProduct]), + ("SELECT * FROM testData right join testData2 where key=2", classOf[CartesianProduct]), + ("SELECT * FROM testData full outer join testData2 where key=2", classOf[CartesianProduct]), + ("SELECT * FROM testData join testData2 where key>a", classOf[CartesianProduct]), + ("SELECT * FROM testData full outer join testData2 where key>a", classOf[CartesianProduct]), + ("SELECT * FROM testData join testData2 ON key = a", classOf[ShuffledHashJoin]), + ("SELECT * FROM testData join testData2 ON key = a and key=2", classOf[ShuffledHashJoin]), + ("SELECT * FROM testData join testData2 ON key = a where key=2", classOf[ShuffledHashJoin]), + ("SELECT * FROM testData left join testData2 ON key = a", classOf[HashOuterJoin]), + ("SELECT * FROM testData right join testData2 ON key = a where key=2", + classOf[HashOuterJoin]), + ("SELECT * FROM testData right join testData2 ON key = a and key=2", + classOf[HashOuterJoin]), + ("SELECT * FROM testData full outer join testData2 ON key = a", classOf[HashOuterJoin]), + ("SELECT * FROM testData join testData2 ON key = a", classOf[ShuffledHashJoin]), + ("SELECT * FROM testData join testData2 ON key = a and key=2", classOf[ShuffledHashJoin]), + ("SELECT * FROM testData join testData2 ON key = a where key=2", classOf[ShuffledHashJoin]) + // TODO add BroadcastNestedLoopJoin + ) + cases1.foreach { c => assertJoin(c._1, c._2) } + } + test("multiple-key equi-join is hash-join") { val x = testData2.as('x) val y = testData2.as('y) @@ -114,6 +170,33 @@ class JoinSuite extends QueryTest { (4, "D", 4, "d") :: (5, "E", null, null) :: (6, "F", null, null) :: Nil) + + checkAnswer( + upperCaseData.join(lowerCaseData, LeftOuter, Some('n === 'N && 'n > 1)), + (1, "A", null, null) :: + (2, "B", 2, "b") :: + (3, "C", 3, "c") :: + (4, "D", 4, "d") :: + (5, "E", null, null) :: + (6, "F", null, null) :: Nil) + + checkAnswer( + upperCaseData.join(lowerCaseData, LeftOuter, Some('n === 'N && 'N > 1)), + (1, "A", null, null) :: + (2, "B", 2, "b") :: + (3, "C", 3, "c") :: + (4, "D", 4, "d") :: + (5, "E", null, null) :: + (6, "F", null, null) :: Nil) + + checkAnswer( + upperCaseData.join(lowerCaseData, LeftOuter, Some('n === 'N && 'l > 'L)), + (1, "A", 1, "a") :: + (2, "B", 2, "b") :: + (3, "C", 3, "c") :: + (4, "D", 4, "d") :: + (5, "E", null, null) :: + (6, "F", null, null) :: Nil) } test("right outer join") { @@ -125,11 +208,38 @@ class JoinSuite extends QueryTest { (4, "d", 4, "D") :: (null, null, 5, "E") :: (null, null, 6, "F") :: Nil) + checkAnswer( + lowerCaseData.join(upperCaseData, RightOuter, Some('n === 'N && 'n > 1)), + (null, null, 1, "A") :: + (2, "b", 2, "B") :: + (3, "c", 3, "C") :: + (4, "d", 4, "D") :: + (null, null, 5, "E") :: + (null, null, 6, "F") :: Nil) + checkAnswer( + lowerCaseData.join(upperCaseData, RightOuter, Some('n === 'N && 'N > 1)), + (null, null, 1, "A") :: + (2, "b", 2, "B") :: + (3, "c", 3, "C") :: + (4, "d", 4, "D") :: + (null, null, 5, "E") :: + (null, null, 6, "F") :: Nil) + checkAnswer( + lowerCaseData.join(upperCaseData, RightOuter, Some('n === 'N && 'l > 'L)), + (1, "a", 1, "A") :: + (2, "b", 2, "B") :: + (3, "c", 3, "C") :: + (4, "d", 4, "D") :: + (null, null, 5, "E") :: + (null, null, 6, "F") :: Nil) } test("full outer join") { - val left = upperCaseData.where('N <= 4).as('left) - val right = upperCaseData.where('N >= 3).as('right) + upperCaseData.where('N <= 4).registerAsTable("left") + upperCaseData.where('N >= 3).registerAsTable("right") + + val left = UnresolvedRelation(None, "left", None) + val right = UnresolvedRelation(None, "right", None) checkAnswer( left.join(right, FullOuter, Some("left.N".attr === "right.N".attr)), @@ -139,5 +249,25 @@ class JoinSuite extends QueryTest { (4, "D", 4, "D") :: (null, null, 5, "E") :: (null, null, 6, "F") :: Nil) + + checkAnswer( + left.join(right, FullOuter, Some(("left.N".attr === "right.N".attr) && ("left.N".attr !== 3))), + (1, "A", null, null) :: + (2, "B", null, null) :: + (3, "C", null, null) :: + (null, null, 3, "C") :: + (4, "D", 4, "D") :: + (null, null, 5, "E") :: + (null, null, 6, "F") :: Nil) + + checkAnswer( + left.join(right, FullOuter, Some(("left.N".attr === "right.N".attr) && ("right.N".attr !== 3))), + (1, "A", null, null) :: + (2, "B", null, null) :: + (3, "C", null, null) :: + (null, null, 3, "C") :: + (4, "D", 4, "D") :: + (null, null, 5, "E") :: + (null, null, 6, "F") :: Nil) } } From 580c7011cab6bc93294b6486e778557216bedb10 Mon Sep 17 00:00:00 2001 From: chutium Date: Fri, 1 Aug 2014 11:31:44 -0700 Subject: [PATCH 088/170] [SPARK-2729] [SQL] Forgot to match Timestamp type in ColumnBuilder just a match forgot, found after SPARK-2710 , TimestampType can be used by a SchemaRDD generated from JDBC ResultSet Author: chutium Closes #1636 from chutium/SPARK-2729 and squashes the following commits: 71af77a [chutium] [SPARK-2729] [SQL] added Timestamp in NullableColumnAccessorSuite 39cf9f8 [chutium] [SPARK-2729] add Timestamp Type into ColumnBuilder TestSuite, ref. #1636 ab6ff97 [chutium] [SPARK-2729] Forgot to match Timestamp type in ColumnBuilder --- .../scala/org/apache/spark/sql/columnar/ColumnBuilder.scala | 1 + .../apache/spark/sql/columnar/NullableColumnAccessorSuite.scala | 2 +- .../apache/spark/sql/columnar/NullableColumnBuilderSuite.scala | 2 +- 3 files changed, 3 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala index 74f5630fbddf1..c416a745739b3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala @@ -154,6 +154,7 @@ private[sql] object ColumnBuilder { case STRING.typeId => new StringColumnBuilder case BINARY.typeId => new BinaryColumnBuilder case GENERIC.typeId => new GenericColumnBuilder + case TIMESTAMP.typeId => new TimestampColumnBuilder }).asInstanceOf[ColumnBuilder] builder.initialize(initialSize, columnName, useCompression) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala index 35ab14cbc353d..3baa6f8ec0c83 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala @@ -41,7 +41,7 @@ object TestNullableColumnAccessor { class NullableColumnAccessorSuite extends FunSuite { import ColumnarTestUtils._ - Seq(INT, LONG, SHORT, BOOLEAN, BYTE, STRING, DOUBLE, FLOAT, BINARY, GENERIC).foreach { + Seq(INT, LONG, SHORT, BOOLEAN, BYTE, STRING, DOUBLE, FLOAT, BINARY, GENERIC, TIMESTAMP).foreach { testNullableColumnAccessor(_) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala index d8898527baa39..dc813fe146c47 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala @@ -37,7 +37,7 @@ object TestNullableColumnBuilder { class NullableColumnBuilderSuite extends FunSuite { import ColumnarTestUtils._ - Seq(INT, LONG, SHORT, BOOLEAN, BYTE, STRING, DOUBLE, FLOAT, BINARY, GENERIC).foreach { + Seq(INT, LONG, SHORT, BOOLEAN, BYTE, STRING, DOUBLE, FLOAT, BINARY, GENERIC, TIMESTAMP).foreach { testNullableColumnBuilder(_) } From c0b47bada3c9f0e9e0f14ab41ffb91012a357211 Mon Sep 17 00:00:00 2001 From: Cheng Hao Date: Fri, 1 Aug 2014 11:42:05 -0700 Subject: [PATCH 089/170] [SPARK-2767] [SQL] SparkSQL CLI doens't output error message if query failed. Author: Cheng Hao Closes #1686 from chenghao-intel/spark_sql_cli and squashes the following commits: eb664cc [Cheng Hao] Output detailed failure message in console 93b0382 [Cheng Hao] Fix Bug of no output in cli if exception thrown internally --- .../spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala | 4 +++- .../spark/sql/hive/thriftserver/SparkSQLDriver.scala | 3 +-- .../main/scala/org/apache/spark/sql/hive/HiveContext.scala | 7 ++++--- 3 files changed, 8 insertions(+), 6 deletions(-) diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala index 27268ecb923e9..cb17d7ce58ea0 100755 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala @@ -288,8 +288,10 @@ private[hive] class SparkSQLCLIDriver extends CliDriver with Logging { out.println(cmd) } - ret = driver.run(cmd).getResponseCode + val rc = driver.run(cmd) + ret = rc.getResponseCode if (ret != 0) { + console.printError(rc.getErrorMessage()) driver.close() return ret } diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala index 5202aa9903e03..a56b19a4bcda0 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala @@ -53,10 +53,9 @@ private[hive] class SparkSQLDriver(val context: HiveContext = SparkSQLEnv.hiveCo } override def run(command: String): CommandProcessorResponse = { - val execution = context.executePlan(context.hql(command).logicalPlan) - // TODO unify the error code try { + val execution = context.executePlan(context.hql(command).logicalPlan) hiveResponse = execution.stringResult() tableSchema = getResultSetSchema(execution) new CommandProcessorResponse(0) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index 27b444daba2d4..7e3b8727bebed 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -131,12 +131,13 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { @transient protected[hive] lazy val sessionState = { val ss = new SessionState(hiveconf) set(hiveconf.getAllProperties) // Have SQLConf pick up the initial set of HiveConf. + + ss.err = new PrintStream(outputBuffer, true, "UTF-8") + ss.out = new PrintStream(outputBuffer, true, "UTF-8") + ss } - sessionState.err = new PrintStream(outputBuffer, true, "UTF-8") - sessionState.out = new PrintStream(outputBuffer, true, "UTF-8") - override def set(key: String, value: String): Unit = { super.set(key, value) runSqlHive(s"SET $key=$value") From c82fe4781cd0356bcfdd25c7eadf1da624bb2228 Mon Sep 17 00:00:00 2001 From: CrazyJvm Date: Fri, 1 Aug 2014 11:46:13 -0700 Subject: [PATCH 090/170] [SQL] Documentation: Explain cacheTable command add the `cacheTable` specification Author: CrazyJvm Closes #1681 from CrazyJvm/sql-programming-guide-cache and squashes the following commits: 0a231e0 [CrazyJvm] grammar fixes a04020e [CrazyJvm] modify title to Cached tables 18b6594 [CrazyJvm] fix format 2cbbf58 [CrazyJvm] add cacheTable guide --- docs/sql-programming-guide.md | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index a047d32b6ee6c..7261badd411a9 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -769,3 +769,13 @@ To start the Spark SQL CLI, run the following in the Spark directory: Configuration of Hive is done by placing your `hive-site.xml` file in `conf/`. You may run `./bin/spark-sql --help` for a complete list of all available options. + +# Cached tables + +Spark SQL can cache tables using an in-memory columnar format by calling `cacheTable("tableName")`. +Then Spark SQL will scan only required columns and will automatically tune compression to minimize +memory usage and GC pressure. You can call `uncacheTable("tableName")` to remove the table from memory. + +Note that if you just call `cache` rather than `cacheTable`, tables will _not_ be cached in +in-memory columnar format. So we strongly recommend using `cacheTable` whenever you want to +cache tables. From eb5bdcaf6c7834558cb76b7132f68b8d94230356 Mon Sep 17 00:00:00 2001 From: Aaron Staple Date: Fri, 1 Aug 2014 12:04:04 -0700 Subject: [PATCH 091/170] [SPARK-695] In DAGScheduler's getPreferredLocs, track set of visited partitions. getPreferredLocs traverses a dependency graph of partitions using depth first search. Given a complex dependency graph, the old implementation may explore a set of paths in the graph that is exponential in the number of nodes. By maintaining a set of visited nodes the new implementation avoids revisiting nodes, preventing exponential blowup. Some comment and whitespace cleanups are also included. Author: Aaron Staple Closes #1362 from staple/SPARK-695 and squashes the following commits: ecea0f3 [Aaron Staple] address review comments 751c661 [Aaron Staple] [SPARK-695] Add a unit test. 5adf326 [Aaron Staple] Replace getPreferredLocsInternal's HashMap argument with a simpler HashSet. 58e37d0 [Aaron Staple] Replace comment documenting NarrowDependency. 6751ced [Aaron Staple] Revert "Remove unused variable." 04c7097 [Aaron Staple] Fix indentation. 0030884 [Aaron Staple] Remove unused variable. 33f67c6 [Aaron Staple] Clarify comment. 4e42b46 [Aaron Staple] Remove apparently incorrect comment describing NarrowDependency. 65c2d3d [Aaron Staple] [SPARK-695] In DAGScheduler's getPreferredLocs, track set of visited partitions. --- .../scala/org/apache/spark/Dependency.scala | 4 ++-- .../scala/org/apache/spark/SparkContext.scala | 2 +- .../org/apache/spark/rdd/CoalescedRDD.scala | 4 ++-- .../apache/spark/scheduler/DAGScheduler.scala | 18 +++++++++++++++++- .../spark/scheduler/DAGSchedulerSuite.scala | 16 +++++++++++++++- 5 files changed, 37 insertions(+), 7 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/Dependency.scala b/core/src/main/scala/org/apache/spark/Dependency.scala index 3935c8772252e..ab2594cfc02eb 100644 --- a/core/src/main/scala/org/apache/spark/Dependency.scala +++ b/core/src/main/scala/org/apache/spark/Dependency.scala @@ -34,8 +34,8 @@ abstract class Dependency[T] extends Serializable { /** * :: DeveloperApi :: - * Base class for dependencies where each partition of the parent RDD is used by at most one - * partition of the child RDD. Narrow dependencies allow for pipelined execution. + * Base class for dependencies where each partition of the child RDD depends on a small number + * of partitions of the parent RDD. Narrow dependencies allow for pipelined execution. */ @DeveloperApi abstract class NarrowDependency[T](_rdd: RDD[T]) extends Dependency[T] { diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 5f75c1dd2cb68..368835a867493 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -458,7 +458,7 @@ class SparkContext(config: SparkConf) extends Logging { /** Distribute a local Scala collection to form an RDD, with one or more * location preferences (hostnames of Spark nodes) for each object. * Create a new partition for each collection item. */ - def makeRDD[T: ClassTag](seq: Seq[(T, Seq[String])]): RDD[T] = { + def makeRDD[T: ClassTag](seq: Seq[(T, Seq[String])]): RDD[T] = { val indexToPrefs = seq.zipWithIndex.map(t => (t._2, t._1._2)).toMap new ParallelCollectionRDD[T](this, seq.map(_._1), seq.size, indexToPrefs) } diff --git a/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala index e7221e3032c11..11ebafbf6d457 100644 --- a/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala @@ -49,8 +49,8 @@ private[spark] case class CoalescedRDDPartition( } /** - * Computes how many of the parents partitions have getPreferredLocation - * as one of their preferredLocations + * Computes the fraction of the parents' partitions containing preferredLocation within + * their getPreferredLocs. * @return locality of this coalesced partition between 0 and 1 */ def localFraction: Double = { diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index c7e3d7c5f8530..5110785de357c 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -1148,6 +1148,22 @@ class DAGScheduler( */ private[spark] def getPreferredLocs(rdd: RDD[_], partition: Int): Seq[TaskLocation] = synchronized { + getPreferredLocsInternal(rdd, partition, new HashSet) + } + + /** Recursive implementation for getPreferredLocs. */ + private def getPreferredLocsInternal( + rdd: RDD[_], + partition: Int, + visited: HashSet[(RDD[_],Int)]) + : Seq[TaskLocation] = + { + // If the partition has already been visited, no need to re-visit. + // This avoids exponential path exploration. SPARK-695 + if (!visited.add((rdd,partition))) { + // Nil has already been returned for previously visited partitions. + return Nil + } // If the partition is cached, return the cache locations val cached = getCacheLocs(rdd)(partition) if (!cached.isEmpty) { @@ -1164,7 +1180,7 @@ class DAGScheduler( rdd.dependencies.foreach { case n: NarrowDependency[_] => for (inPart <- n.getParents(partition)) { - val locs = getPreferredLocs(n.rdd, inPart) + val locs = getPreferredLocsInternal(n.rdd, inPart, visited) if (locs != Nil) { return locs } diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index 0ce13d015df05..36e238b4c9434 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -23,6 +23,8 @@ import scala.language.reflectiveCalls import akka.actor._ import akka.testkit.{ImplicitSender, TestKit, TestActorRef} import org.scalatest.{BeforeAndAfter, FunSuiteLike} +import org.scalatest.concurrent.Timeouts +import org.scalatest.time.SpanSugar._ import org.apache.spark._ import org.apache.spark.rdd.RDD @@ -64,7 +66,7 @@ class MyRDD( class DAGSchedulerSuiteDummyException extends Exception class DAGSchedulerSuite extends TestKit(ActorSystem("DAGSchedulerSuite")) with FunSuiteLike - with ImplicitSender with BeforeAndAfter with LocalSparkContext { + with ImplicitSender with BeforeAndAfter with LocalSparkContext with Timeouts { val conf = new SparkConf /** Set of TaskSets the DAGScheduler has requested executed. */ @@ -294,6 +296,18 @@ class DAGSchedulerSuite extends TestKit(ActorSystem("DAGSchedulerSuite")) with F assertDataStructuresEmpty } + test("avoid exponential blowup when getting preferred locs list") { + // Build up a complex dependency graph with repeated zip operations, without preferred locations. + var rdd: RDD[_] = new MyRDD(sc, 1, Nil) + (1 to 30).foreach(_ => rdd = rdd.zip(rdd)) + // getPreferredLocs runs quickly, indicating that exponential graph traversal is avoided. + failAfter(10 seconds) { + val preferredLocs = scheduler.getPreferredLocs(rdd,0) + // No preferred locations are returned. + assert(preferredLocs.length === 0) + } + } + test("unserializable task") { val unserializableRdd = new MyRDD(sc, 1, Nil) { class UnserializableClass From baf9ce1a4ecb7acf5accf7a7029f29604b4360c2 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Fri, 1 Aug 2014 12:12:30 -0700 Subject: [PATCH 092/170] [SPARK-2490] Change recursive visiting on RDD dependencies to iterative approach When performing some transformations on RDDs after many iterations, the dependencies of RDDs could be very long. It can easily cause StackOverflowError when recursively visiting these dependencies in Spark core. For example: var rdd = sc.makeRDD(Array(1)) for (i <- 1 to 1000) { rdd = rdd.coalesce(1).cache() rdd.collect() } This PR changes recursive visiting on rdd's dependencies to iterative approach to avoid StackOverflowError. In addition to the recursive visiting, since the Java serializer has a known [bug](http://bugs.java.com/bugdatabase/view_bug.do?bug_id=4152790) that causes StackOverflowError too when serializing/deserializing a large graph of objects. So applying this PR only solves part of the problem. Using KryoSerializer to replace Java serializer might be helpful. However, since KryoSerializer is not supported for `spark.closure.serializer` now, I can not test if KryoSerializer can solve Java serializer's problem completely. Author: Liang-Chi Hsieh Closes #1418 from viirya/remove_recursive_visit and squashes the following commits: 6b2c615 [Liang-Chi Hsieh] change function name; comply with code style. 5f072a7 [Liang-Chi Hsieh] add comments to explain Stack usage. 8742dbb [Liang-Chi Hsieh] comply with code style. 900538b [Liang-Chi Hsieh] change recursive visiting on rdd's dependencies to iterative approach to avoid stackoverflowerror. --- .../apache/spark/scheduler/DAGScheduler.scala | 83 +++++++++++++++++-- 1 file changed, 75 insertions(+), 8 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 5110785de357c..d87c3048985fc 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -21,7 +21,7 @@ import java.io.NotSerializableException import java.util.Properties import java.util.concurrent.atomic.AtomicInteger -import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Map} +import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Map, Stack} import scala.concurrent.Await import scala.concurrent.duration._ import scala.language.postfixOps @@ -211,11 +211,15 @@ class DAGScheduler( shuffleToMapStage.get(shuffleDep.shuffleId) match { case Some(stage) => stage case None => + // We are going to register ancestor shuffle dependencies + registerShuffleDependencies(shuffleDep, jobId) + // Then register current shuffleDep val stage = newOrUsedStage( shuffleDep.rdd, shuffleDep.rdd.partitions.size, shuffleDep, jobId, shuffleDep.rdd.creationSite) shuffleToMapStage(shuffleDep.shuffleId) = stage + stage } } @@ -280,6 +284,9 @@ class DAGScheduler( private def getParentStages(rdd: RDD[_], jobId: Int): List[Stage] = { val parents = new HashSet[Stage] val visited = new HashSet[RDD[_]] + // We are manually maintaining a stack here to prevent StackOverflowError + // caused by recursively visiting + val waitingForVisit = new Stack[RDD[_]] def visit(r: RDD[_]) { if (!visited(r)) { visited += r @@ -290,18 +297,69 @@ class DAGScheduler( case shufDep: ShuffleDependency[_, _, _] => parents += getShuffleMapStage(shufDep, jobId) case _ => - visit(dep.rdd) + waitingForVisit.push(dep.rdd) } } } } - visit(rdd) + waitingForVisit.push(rdd) + while (!waitingForVisit.isEmpty) { + visit(waitingForVisit.pop()) + } parents.toList } + // Find ancestor missing shuffle dependencies and register into shuffleToMapStage + private def registerShuffleDependencies(shuffleDep: ShuffleDependency[_, _, _], jobId: Int) = { + val parentsWithNoMapStage = getAncestorShuffleDependencies(shuffleDep.rdd) + while (!parentsWithNoMapStage.isEmpty) { + val currentShufDep = parentsWithNoMapStage.pop() + val stage = + newOrUsedStage( + currentShufDep.rdd, currentShufDep.rdd.partitions.size, currentShufDep, jobId, + currentShufDep.rdd.creationSite) + shuffleToMapStage(currentShufDep.shuffleId) = stage + } + } + + // Find ancestor shuffle dependencies that are not registered in shuffleToMapStage yet + private def getAncestorShuffleDependencies(rdd: RDD[_]): Stack[ShuffleDependency[_, _, _]] = { + val parents = new Stack[ShuffleDependency[_, _, _]] + val visited = new HashSet[RDD[_]] + // We are manually maintaining a stack here to prevent StackOverflowError + // caused by recursively visiting + val waitingForVisit = new Stack[RDD[_]] + def visit(r: RDD[_]) { + if (!visited(r)) { + visited += r + for (dep <- r.dependencies) { + dep match { + case shufDep: ShuffleDependency[_, _, _] => + if (!shuffleToMapStage.contains(shufDep.shuffleId)) { + parents.push(shufDep) + } + + waitingForVisit.push(shufDep.rdd) + case _ => + waitingForVisit.push(dep.rdd) + } + } + } + } + + waitingForVisit.push(rdd) + while (!waitingForVisit.isEmpty) { + visit(waitingForVisit.pop()) + } + parents + } + private def getMissingParentStages(stage: Stage): List[Stage] = { val missing = new HashSet[Stage] val visited = new HashSet[RDD[_]] + // We are manually maintaining a stack here to prevent StackOverflowError + // caused by recursively visiting + val waitingForVisit = new Stack[RDD[_]] def visit(rdd: RDD[_]) { if (!visited(rdd)) { visited += rdd @@ -314,13 +372,16 @@ class DAGScheduler( missing += mapStage } case narrowDep: NarrowDependency[_] => - visit(narrowDep.rdd) + waitingForVisit.push(narrowDep.rdd) } } } } } - visit(stage.rdd) + waitingForVisit.push(stage.rdd) + while (!waitingForVisit.isEmpty) { + visit(waitingForVisit.pop()) + } missing.toList } @@ -1119,6 +1180,9 @@ class DAGScheduler( } val visitedRdds = new HashSet[RDD[_]] val visitedStages = new HashSet[Stage] + // We are manually maintaining a stack here to prevent StackOverflowError + // caused by recursively visiting + val waitingForVisit = new Stack[RDD[_]] def visit(rdd: RDD[_]) { if (!visitedRdds(rdd)) { visitedRdds += rdd @@ -1128,15 +1192,18 @@ class DAGScheduler( val mapStage = getShuffleMapStage(shufDep, stage.jobId) if (!mapStage.isAvailable) { visitedStages += mapStage - visit(mapStage.rdd) + waitingForVisit.push(mapStage.rdd) } // Otherwise there's no need to follow the dependency back case narrowDep: NarrowDependency[_] => - visit(narrowDep.rdd) + waitingForVisit.push(narrowDep.rdd) } } } } - visit(stage.rdd) + waitingForVisit.push(stage.rdd) + while (!waitingForVisit.isEmpty) { + visit(waitingForVisit.pop()) + } visitedRdds.contains(target.rdd) } From f5d9bea20e0db22c09c1191ca44a6471de765739 Mon Sep 17 00:00:00 2001 From: zsxwing Date: Fri, 1 Aug 2014 13:25:04 -0700 Subject: [PATCH 093/170] SPARK-1612: Fix potential resource leaks JIRA: https://issues.apache.org/jira/browse/SPARK-1612 Move the "close" statements into a "finally" block. Author: zsxwing Closes #535 from zsxwing/SPARK-1612 and squashes the following commits: ae52f50 [zsxwing] Update to follow the code style 549ba13 [zsxwing] SPARK-1612: Fix potential resource leaks --- .../scala/org/apache/spark/util/Utils.scala | 35 ++++++++++++------- 1 file changed, 22 insertions(+), 13 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index f8fbb3ad6d4a1..30073a82857d2 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -286,17 +286,23 @@ private[spark] object Utils extends Logging { out: OutputStream, closeStreams: Boolean = false) { - val buf = new Array[Byte](8192) - var n = 0 - while (n != -1) { - n = in.read(buf) - if (n != -1) { - out.write(buf, 0, n) + try { + val buf = new Array[Byte](8192) + var n = 0 + while (n != -1) { + n = in.read(buf) + if (n != -1) { + out.write(buf, 0, n) + } + } + } finally { + if (closeStreams) { + try { + in.close() + } finally { + out.close() + } } - } - if (closeStreams) { - in.close() - out.close() } } @@ -868,9 +874,12 @@ private[spark] object Utils extends Logging { val buff = new Array[Byte]((effectiveEnd-effectiveStart).toInt) val stream = new FileInputStream(file) - stream.skip(effectiveStart) - stream.read(buff) - stream.close() + try { + stream.skip(effectiveStart) + stream.read(buff) + } finally { + stream.close() + } Source.fromBytes(buff).mkString } From b270309d7608fb749e402cd5afd36087446be398 Mon Sep 17 00:00:00 2001 From: joyyoj Date: Fri, 1 Aug 2014 13:41:55 -0700 Subject: [PATCH 094/170] [SPARK-2379] Fix the bug that streaming's receiver may fall into a dead loop Author: joyyoj Closes #1694 from joyyoj/SPARK-2379 and squashes the following commits: d73790d [joyyoj] SPARK-2379 Fix the bug that streaming's receiver may fall into a dead loop 22e7821 [joyyoj] Merge remote-tracking branch 'apache/master' 3f4a602 [joyyoj] Merge remote-tracking branch 'remotes/apache/master' f4660c5 [joyyoj] [SPARK-1998] SparkFlumeEvent with body bigger than 1020 bytes are not read properly --- .../apache/spark/streaming/receiver/ReceiverSupervisor.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisor.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisor.scala index 09be3a50d2dfa..1f0244c251eba 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisor.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisor.scala @@ -138,7 +138,7 @@ private[streaming] abstract class ReceiverSupervisor( onReceiverStop(message, error) } catch { case t: Throwable => - stop("Error stopping receiver " + streamId, Some(t)) + logError("Error stopping receiver " + streamId + t.getStackTraceString) } } From 78f2af582286b81e6dc9fa9d455ed2b369d933bd Mon Sep 17 00:00:00 2001 From: Aaron Davidson Date: Fri, 1 Aug 2014 13:57:19 -0700 Subject: [PATCH 095/170] SPARK-2791: Fix committing, reverting and state tracking in shuffle file consolidation All changes from this PR are by mridulm and are drawn from his work in #1609. This patch is intended to fix all major issues related to shuffle file consolidation that mridulm found, while minimizing changes to the code, with the hope that it may be more easily merged into 1.1. This patch is **not** intended as a replacement for #1609, which provides many additional benefits, including fixes to ExternalAppendOnlyMap, improvements to DiskBlockObjectWriter's API, and several new unit tests. If it is feasible to merge #1609 for the 1.1 deadline, that is a preferable option. Author: Aaron Davidson Closes #1678 from aarondav/consol and squashes the following commits: 53b3f6d [Aaron Davidson] Correct behavior when writing unopened file 701d045 [Aaron Davidson] Rebase with sort-based shuffle 9160149 [Aaron Davidson] SPARK-2532: Minimal shuffle consolidation fixes --- .../shuffle/hash/HashShuffleWriter.scala | 14 +-- .../shuffle/sort/SortShuffleWriter.scala | 3 +- .../spark/storage/BlockObjectWriter.scala | 53 ++++++----- .../spark/storage/ShuffleBlockManager.scala | 28 +++--- .../collection/ExternalAppendOnlyMap.scala | 2 +- .../util/collection/ExternalSorter.scala | 6 +- .../spark/storage/DiskBlockManagerSuite.scala | 87 ++++++++++++++++++- .../spark/tools/StoragePerfTester.scala | 5 +- 8 files changed, 146 insertions(+), 52 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala index 1923f7c71a48f..45d3b8b9b8725 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala @@ -65,7 +65,8 @@ private[spark] class HashShuffleWriter[K, V]( } /** Close this writer, passing along whether the map completed */ - override def stop(success: Boolean): Option[MapStatus] = { + override def stop(initiallySuccess: Boolean): Option[MapStatus] = { + var success = initiallySuccess try { if (stopping) { return None @@ -73,15 +74,16 @@ private[spark] class HashShuffleWriter[K, V]( stopping = true if (success) { try { - return Some(commitWritesAndBuildStatus()) + Some(commitWritesAndBuildStatus()) } catch { case e: Exception => + success = false revertWrites() throw e } } else { revertWrites() - return None + None } } finally { // Release the writers back to the shuffle block manager. @@ -100,8 +102,7 @@ private[spark] class HashShuffleWriter[K, V]( var totalBytes = 0L var totalTime = 0L val compressedSizes = shuffle.writers.map { writer: BlockObjectWriter => - writer.commit() - writer.close() + writer.commitAndClose() val size = writer.fileSegment().length totalBytes += size totalTime += writer.timeWriting() @@ -120,8 +121,7 @@ private[spark] class HashShuffleWriter[K, V]( private def revertWrites(): Unit = { if (shuffle != null && shuffle.writers != null) { for (writer <- shuffle.writers) { - writer.revertPartialWrites() - writer.close() + writer.revertPartialWritesAndClose() } } } diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala index 42fcd07fa18bc..9a356d0dbaf17 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala @@ -94,8 +94,7 @@ private[spark] class SortShuffleWriter[K, V, C]( for (elem <- elements) { writer.write(elem) } - writer.commit() - writer.close() + writer.commitAndClose() val segment = writer.fileSegment() offsets(id + 1) = segment.offset + segment.length lengths(id) = segment.length diff --git a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala b/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala index a2687e6be4e34..01d46e1ffc960 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala @@ -39,16 +39,16 @@ private[spark] abstract class BlockObjectWriter(val blockId: BlockId) { def isOpen: Boolean /** - * Flush the partial writes and commit them as a single atomic block. Return the - * number of bytes written for this commit. + * Flush the partial writes and commit them as a single atomic block. */ - def commit(): Long + def commitAndClose(): Unit /** * Reverts writes that haven't been flushed yet. Callers should invoke this function - * when there are runtime exceptions. + * when there are runtime exceptions. This method will not throw, though it may be + * unsuccessful in truncating written data. */ - def revertPartialWrites() + def revertPartialWritesAndClose() /** * Writes an object. @@ -57,6 +57,7 @@ private[spark] abstract class BlockObjectWriter(val blockId: BlockId) { /** * Returns the file segment of committed data that this Writer has written. + * This is only valid after commitAndClose() has been called. */ def fileSegment(): FileSegment @@ -108,7 +109,7 @@ private[spark] class DiskBlockObjectWriter( private var ts: TimeTrackingOutputStream = null private var objOut: SerializationStream = null private val initialPosition = file.length() - private var lastValidPosition = initialPosition + private var finalPosition: Long = -1 private var initialized = false private var _timeWriting = 0L @@ -116,7 +117,6 @@ private[spark] class DiskBlockObjectWriter( fos = new FileOutputStream(file, true) ts = new TimeTrackingOutputStream(fos) channel = fos.getChannel() - lastValidPosition = initialPosition bs = compressStream(new BufferedOutputStream(ts, bufferSize)) objOut = serializer.newInstance().serializeStream(bs) initialized = true @@ -147,28 +147,36 @@ private[spark] class DiskBlockObjectWriter( override def isOpen: Boolean = objOut != null - override def commit(): Long = { + override def commitAndClose(): Unit = { if (initialized) { // NOTE: Because Kryo doesn't flush the underlying stream we explicitly flush both the // serializer stream and the lower level stream. objOut.flush() bs.flush() - val prevPos = lastValidPosition - lastValidPosition = channel.position() - lastValidPosition - prevPos - } else { - // lastValidPosition is zero if stream is uninitialized - lastValidPosition + close() } + finalPosition = file.length() } - override def revertPartialWrites() { - if (initialized) { - // Discard current writes. We do this by flushing the outstanding writes and - // truncate the file to the last valid position. - objOut.flush() - bs.flush() - channel.truncate(lastValidPosition) + // Discard current writes. We do this by flushing the outstanding writes and then + // truncating the file to its initial position. + override def revertPartialWritesAndClose() { + try { + if (initialized) { + objOut.flush() + bs.flush() + close() + } + + val truncateStream = new FileOutputStream(file, true) + try { + truncateStream.getChannel.truncate(initialPosition) + } finally { + truncateStream.close() + } + } catch { + case e: Exception => + logError("Uncaught exception while reverting partial writes to file " + file, e) } } @@ -188,6 +196,7 @@ private[spark] class DiskBlockObjectWriter( // Only valid if called after commit() override def bytesWritten: Long = { - lastValidPosition - initialPosition + assert(finalPosition != -1, "bytesWritten is only valid after successful commit()") + finalPosition - initialPosition } } diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala index 7beb55c411e71..28aa35bc7e147 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala @@ -144,7 +144,8 @@ class ShuffleBlockManager(blockManager: BlockManager) extends Logging { if (consolidateShuffleFiles) { if (success) { val offsets = writers.map(_.fileSegment().offset) - fileGroup.recordMapOutput(mapId, offsets) + val lengths = writers.map(_.fileSegment().length) + fileGroup.recordMapOutput(mapId, offsets, lengths) } recycleFileGroup(fileGroup) } else { @@ -247,6 +248,8 @@ object ShuffleBlockManager { * A particular mapper will be assigned a single ShuffleFileGroup to write its output to. */ private class ShuffleFileGroup(val shuffleId: Int, val fileId: Int, val files: Array[File]) { + private var numBlocks: Int = 0 + /** * Stores the absolute index of each mapId in the files of this group. For instance, * if mapId 5 is the first block in each file, mapIdToIndex(5) = 0. @@ -254,23 +257,27 @@ object ShuffleBlockManager { private val mapIdToIndex = new PrimitiveKeyOpenHashMap[Int, Int]() /** - * Stores consecutive offsets of blocks into each reducer file, ordered by position in the file. - * This ordering allows us to compute block lengths by examining the following block offset. + * Stores consecutive offsets and lengths of blocks into each reducer file, ordered by + * position in the file. * Note: mapIdToIndex(mapId) returns the index of the mapper into the vector for every * reducer. */ private val blockOffsetsByReducer = Array.fill[PrimitiveVector[Long]](files.length) { new PrimitiveVector[Long]() } - - def numBlocks = mapIdToIndex.size + private val blockLengthsByReducer = Array.fill[PrimitiveVector[Long]](files.length) { + new PrimitiveVector[Long]() + } def apply(bucketId: Int) = files(bucketId) - def recordMapOutput(mapId: Int, offsets: Array[Long]) { + def recordMapOutput(mapId: Int, offsets: Array[Long], lengths: Array[Long]) { + assert(offsets.length == lengths.length) mapIdToIndex(mapId) = numBlocks + numBlocks += 1 for (i <- 0 until offsets.length) { blockOffsetsByReducer(i) += offsets(i) + blockLengthsByReducer(i) += lengths(i) } } @@ -278,16 +285,11 @@ object ShuffleBlockManager { def getFileSegmentFor(mapId: Int, reducerId: Int): Option[FileSegment] = { val file = files(reducerId) val blockOffsets = blockOffsetsByReducer(reducerId) + val blockLengths = blockLengthsByReducer(reducerId) val index = mapIdToIndex.getOrElse(mapId, -1) if (index >= 0) { val offset = blockOffsets(index) - val length = - if (index + 1 < numBlocks) { - blockOffsets(index + 1) - offset - } else { - file.length() - offset - } - assert(length >= 0) + val length = blockLengths(index) Some(new FileSegment(file, offset, length)) } else { None diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala index b34512ef9eb60..cb67a1c039f20 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala @@ -199,7 +199,7 @@ class ExternalAppendOnlyMap[K, V, C]( // Flush the disk writer's contents to disk, and update relevant variables def flush() = { - writer.commit() + writer.commitAndClose() val bytesWritten = writer.bytesWritten batchSizes.append(bytesWritten) _diskBytesSpilled += bytesWritten diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala index 54c3310744136..6e415a2bd8ce2 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala @@ -270,9 +270,10 @@ private[spark] class ExternalSorter[K, V, C]( // How many elements we have in each partition val elementsPerPartition = new Array[Long](numPartitions) - // Flush the disk writer's contents to disk, and update relevant variables + // Flush the disk writer's contents to disk, and update relevant variables. + // The writer is closed at the end of this process, and cannot be reused. def flush() = { - writer.commit() + writer.commitAndClose() val bytesWritten = writer.bytesWritten batchSizes.append(bytesWritten) _diskBytesSpilled += bytesWritten @@ -293,7 +294,6 @@ private[spark] class ExternalSorter[K, V, C]( if (objectsWritten == serializerBatchSize) { flush() - writer.close() writer = blockManager.getDiskWriter(blockId, file, ser, fileBufferSize) } } diff --git a/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala index aaa7714049732..985ac9394738c 100644 --- a/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala @@ -22,11 +22,14 @@ import java.io.{File, FileWriter} import scala.collection.mutable import scala.language.reflectiveCalls +import akka.actor.Props import com.google.common.io.Files import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach, FunSuite} import org.apache.spark.SparkConf -import org.apache.spark.util.Utils +import org.apache.spark.scheduler.LiveListenerBus +import org.apache.spark.serializer.JavaSerializer +import org.apache.spark.util.{AkkaUtils, Utils} class DiskBlockManagerSuite extends FunSuite with BeforeAndAfterEach with BeforeAndAfterAll { private val testConf = new SparkConf(false) @@ -121,6 +124,88 @@ class DiskBlockManagerSuite extends FunSuite with BeforeAndAfterEach with Before newFile.delete() } + private def checkSegments(segment1: FileSegment, segment2: FileSegment) { + assert (segment1.file.getCanonicalPath === segment2.file.getCanonicalPath) + assert (segment1.offset === segment2.offset) + assert (segment1.length === segment2.length) + } + + test("consolidated shuffle can write to shuffle group without messing existing offsets/lengths") { + + val serializer = new JavaSerializer(testConf) + val confCopy = testConf.clone + // reset after EACH object write. This is to ensure that there are bytes appended after + // an object is written. So if the codepaths assume writeObject is end of data, this should + // flush those bugs out. This was common bug in ExternalAppendOnlyMap, etc. + confCopy.set("spark.serializer.objectStreamReset", "1") + + val securityManager = new org.apache.spark.SecurityManager(confCopy) + // Do not use the shuffleBlockManager above ! + val (actorSystem, boundPort) = AkkaUtils.createActorSystem("test", "localhost", 0, confCopy, + securityManager) + val master = new BlockManagerMaster( + actorSystem.actorOf(Props(new BlockManagerMasterActor(true, confCopy, new LiveListenerBus))), + confCopy) + val store = new BlockManager("", actorSystem, master , serializer, confCopy, + securityManager, null) + + try { + + val shuffleManager = store.shuffleBlockManager + + val shuffle1 = shuffleManager.forMapTask(1, 1, 1, serializer) + for (writer <- shuffle1.writers) { + writer.write("test1") + writer.write("test2") + } + for (writer <- shuffle1.writers) { + writer.commitAndClose() + } + + val shuffle1Segment = shuffle1.writers(0).fileSegment() + shuffle1.releaseWriters(success = true) + + val shuffle2 = shuffleManager.forMapTask(1, 2, 1, new JavaSerializer(testConf)) + + for (writer <- shuffle2.writers) { + writer.write("test3") + writer.write("test4") + } + for (writer <- shuffle2.writers) { + writer.commitAndClose() + } + val shuffle2Segment = shuffle2.writers(0).fileSegment() + shuffle2.releaseWriters(success = true) + + // Now comes the test : + // Write to shuffle 3; and close it, but before registering it, check if the file lengths for + // previous task (forof shuffle1) is the same as 'segments'. Earlier, we were inferring length + // of block based on remaining data in file : which could mess things up when there is concurrent read + // and writes happening to the same shuffle group. + + val shuffle3 = shuffleManager.forMapTask(1, 3, 1, new JavaSerializer(testConf)) + for (writer <- shuffle3.writers) { + writer.write("test3") + writer.write("test4") + } + for (writer <- shuffle3.writers) { + writer.commitAndClose() + } + // check before we register. + checkSegments(shuffle2Segment, shuffleManager.getBlockLocation(ShuffleBlockId(1, 2, 0))) + shuffle3.releaseWriters(success = true) + checkSegments(shuffle2Segment, shuffleManager.getBlockLocation(ShuffleBlockId(1, 2, 0))) + shuffleManager.removeShuffle(1) + } finally { + + if (store != null) { + store.stop() + } + actorSystem.shutdown() + actorSystem.awaitTermination() + } + } + def assertSegmentEquals(blockId: BlockId, filename: String, offset: Int, length: Int) { val segment = diskBlockManager.getBlockLocation(blockId) assert(segment.file.getName === filename) diff --git a/tools/src/main/scala/org/apache/spark/tools/StoragePerfTester.scala b/tools/src/main/scala/org/apache/spark/tools/StoragePerfTester.scala index 8e8c35615a711..8a05fcb449aa6 100644 --- a/tools/src/main/scala/org/apache/spark/tools/StoragePerfTester.scala +++ b/tools/src/main/scala/org/apache/spark/tools/StoragePerfTester.scala @@ -61,10 +61,9 @@ object StoragePerfTester { for (i <- 1 to recordsPerMap) { writers(i % numOutputSplits).write(writeData) } - writers.map {w => - w.commit() + writers.map { w => + w.commitAndClose() total.addAndGet(w.fileSegment().length) - w.close() } shuffle.releaseWriters(true) From d88e69561367d65e1a2b94527b80a1f65a2cba90 Mon Sep 17 00:00:00 2001 From: Doris Xin Date: Fri, 1 Aug 2014 15:02:17 -0700 Subject: [PATCH 096/170] [SPARK-2786][mllib] Python correlations Author: Doris Xin Closes #1713 from dorx/pythonCorrelation and squashes the following commits: 5f1e60c [Doris Xin] reviewer comments. 46ff6eb [Doris Xin] reviewer comments. ad44085 [Doris Xin] style fix e69d446 [Doris Xin] fixed missed conflicts. eb5bf56 [Doris Xin] merge master cc9f725 [Doris Xin] units passed. 9141a63 [Doris Xin] WIP2 d199f1f [Doris Xin] Moved correlation names into a public object cd163d6 [Doris Xin] WIP --- .../mllib/api/python/PythonMLLibAPI.scala | 39 ++++++- .../apache/spark/mllib/stat/Statistics.scala | 10 +- .../mllib/stat/correlation/Correlation.scala | 49 +++++---- .../api/python/PythonMLLibAPISuite.scala | 21 +++- python/pyspark/mllib/_common.py | 6 +- python/pyspark/mllib/stat.py | 104 ++++++++++++++++++ 6 files changed, 199 insertions(+), 30 deletions(-) create mode 100644 python/pyspark/mllib/stat.py diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala index d2e8ccf208970..122925d096e98 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala @@ -20,13 +20,15 @@ package org.apache.spark.mllib.api.python import java.nio.{ByteBuffer, ByteOrder} import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.api.java.{JavaSparkContext, JavaRDD} +import org.apache.spark.api.java.{JavaRDD, JavaSparkContext} import org.apache.spark.mllib.classification._ import org.apache.spark.mllib.clustering._ -import org.apache.spark.mllib.linalg.{SparseVector, Vector, Vectors} +import org.apache.spark.mllib.linalg.{Matrix, SparseVector, Vector, Vectors} import org.apache.spark.mllib.random.{RandomRDDGenerators => RG} import org.apache.spark.mllib.recommendation._ import org.apache.spark.mllib.regression._ +import org.apache.spark.mllib.stat.Statistics +import org.apache.spark.mllib.stat.correlation.CorrelationNames import org.apache.spark.mllib.util.MLUtils import org.apache.spark.rdd.RDD import org.apache.spark.util.Utils @@ -227,7 +229,7 @@ class PythonMLLibAPI extends Serializable { jsc: JavaSparkContext, path: String, minPartitions: Int): JavaRDD[Array[Byte]] = - MLUtils.loadLabeledPoints(jsc.sc, path, minPartitions).map(serializeLabeledPoint).toJavaRDD() + MLUtils.loadLabeledPoints(jsc.sc, path, minPartitions).map(serializeLabeledPoint) private def trainRegressionModel( trainFunc: (RDD[LabeledPoint], Vector) => GeneralizedLinearModel, @@ -456,6 +458,37 @@ class PythonMLLibAPI extends Serializable { ALS.trainImplicit(ratings, rank, iterations, lambda, blocks, alpha) } + /** + * Java stub for mllib Statistics.corr(X: RDD[Vector], method: String). + * Returns the correlation matrix serialized into a byte array understood by deserializers in + * pyspark. + */ + def corr(X: JavaRDD[Array[Byte]], method: String): Array[Byte] = { + val inputMatrix = X.rdd.map(deserializeDoubleVector(_)) + val result = Statistics.corr(inputMatrix, getCorrNameOrDefault(method)) + serializeDoubleMatrix(to2dArray(result)) + } + + /** + * Java stub for mllib Statistics.corr(x: RDD[Double], y: RDD[Double], method: String). + */ + def corr(x: JavaRDD[Array[Byte]], y: JavaRDD[Array[Byte]], method: String): Double = { + val xDeser = x.rdd.map(deserializeDouble(_)) + val yDeser = y.rdd.map(deserializeDouble(_)) + Statistics.corr(xDeser, yDeser, getCorrNameOrDefault(method)) + } + + // used by the corr methods to retrieve the name of the correlation method passed in via pyspark + private def getCorrNameOrDefault(method: String) = { + if (method == null) CorrelationNames.defaultCorrName else method + } + + // Reformat a Matrix into Array[Array[Double]] for serialization + private[python] def to2dArray(matrix: Matrix): Array[Array[Double]] = { + val values = matrix.toArray + Array.tabulate(matrix.numRows, matrix.numCols)((i, j) => values(i + j * matrix.numRows)) + } + // Used by the *RDD methods to get default seed if not passed in from pyspark private def getSeedOrDefault(seed: java.lang.Long): Long = { if (seed == null) Utils.random.nextLong else seed diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala index 9d6de9b6e1f60..f416a9fbb323d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala @@ -23,21 +23,24 @@ import org.apache.spark.mllib.stat.correlation.Correlations import org.apache.spark.rdd.RDD /** - * API for statistical functions in MLlib + * API for statistical functions in MLlib. */ @Experimental object Statistics { /** + * :: Experimental :: * Compute the Pearson correlation matrix for the input RDD of Vectors. * Columns with 0 covariance produce NaN entries in the correlation matrix. * * @param X an RDD[Vector] for which the correlation matrix is to be computed. * @return Pearson correlation matrix comparing columns in X. */ + @Experimental def corr(X: RDD[Vector]): Matrix = Correlations.corrMatrix(X) /** + * :: Experimental :: * Compute the correlation matrix for the input RDD of Vectors using the specified method. * Methods currently supported: `pearson` (default), `spearman`. * @@ -51,9 +54,11 @@ object Statistics { * Supported: `pearson` (default), `spearman` * @return Correlation matrix comparing columns in X. */ + @Experimental def corr(X: RDD[Vector], method: String): Matrix = Correlations.corrMatrix(X, method) /** + * :: Experimental :: * Compute the Pearson correlation for the input RDDs. * Returns NaN if either vector has 0 variance. * @@ -64,9 +69,11 @@ object Statistics { * @param y RDD[Double] of the same cardinality as x. * @return A Double containing the Pearson correlation between the two input RDD[Double]s */ + @Experimental def corr(x: RDD[Double], y: RDD[Double]): Double = Correlations.corr(x, y) /** + * :: Experimental :: * Compute the correlation for the input RDDs using the specified method. * Methods currently supported: `pearson` (default), `spearman`. * @@ -80,5 +87,6 @@ object Statistics { *@return A Double containing the correlation between the two input RDD[Double]s using the * specified method. */ + @Experimental def corr(x: RDD[Double], y: RDD[Double], method: String): Double = Correlations.corr(x, y, method) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/correlation/Correlation.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/correlation/Correlation.scala index f23393d3da257..1fb8d7b3d4f32 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/correlation/Correlation.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/correlation/Correlation.scala @@ -49,43 +49,48 @@ private[stat] trait Correlation { } /** - * Delegates computation to the specific correlation object based on the input method name - * - * Currently supported correlations: pearson, spearman. - * After new correlation algorithms are added, please update the documentation here and in - * Statistics.scala for the correlation APIs. - * - * Maintains the default correlation type, pearson + * Delegates computation to the specific correlation object based on the input method name. */ private[stat] object Correlations { - // Note: after new types of correlations are implemented, please update this map - val nameToObjectMap = Map(("pearson", PearsonCorrelation), ("spearman", SpearmanCorrelation)) - val defaultCorrName: String = "pearson" - val defaultCorr: Correlation = nameToObjectMap(defaultCorrName) - - def corr(x: RDD[Double], y: RDD[Double], method: String = defaultCorrName): Double = { + def corr(x: RDD[Double], + y: RDD[Double], + method: String = CorrelationNames.defaultCorrName): Double = { val correlation = getCorrelationFromName(method) correlation.computeCorrelation(x, y) } - def corrMatrix(X: RDD[Vector], method: String = defaultCorrName): Matrix = { + def corrMatrix(X: RDD[Vector], + method: String = CorrelationNames.defaultCorrName): Matrix = { val correlation = getCorrelationFromName(method) correlation.computeCorrelationMatrix(X) } - /** - * Match input correlation name with a known name via simple string matching - * - * private to stat for ease of unit testing - */ - private[stat] def getCorrelationFromName(method: String): Correlation = { + // Match input correlation name with a known name via simple string matching. + def getCorrelationFromName(method: String): Correlation = { try { - nameToObjectMap(method) + CorrelationNames.nameToObjectMap(method) } catch { case nse: NoSuchElementException => throw new IllegalArgumentException("Unrecognized method name. Supported correlations: " - + nameToObjectMap.keys.mkString(", ")) + + CorrelationNames.nameToObjectMap.keys.mkString(", ")) } } } + +/** + * Maintains supported and default correlation names. + * + * Currently supported correlations: `pearson`, `spearman`. + * Current default correlation: `pearson`. + * + * After new correlation algorithms are added, please update the documentation here and in + * Statistics.scala for the correlation APIs. + */ +private[mllib] object CorrelationNames { + + // Note: after new types of correlations are implemented, please update this map. + val nameToObjectMap = Map(("pearson", PearsonCorrelation), ("spearman", SpearmanCorrelation)) + val defaultCorrName: String = "pearson" + +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/api/python/PythonMLLibAPISuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/api/python/PythonMLLibAPISuite.scala index d94cfa2fcec81..bd413a80f5107 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/api/python/PythonMLLibAPISuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/api/python/PythonMLLibAPISuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.mllib.api.python import org.scalatest.FunSuite -import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.linalg.{Matrices, Vectors} import org.apache.spark.mllib.regression.LabeledPoint class PythonMLLibAPISuite extends FunSuite { @@ -59,10 +59,25 @@ class PythonMLLibAPISuite extends FunSuite { } test("double serialization") { - for (x <- List(123.0, -10.0, 0.0, Double.MaxValue, Double.MinValue)) { + for (x <- List(123.0, -10.0, 0.0, Double.MaxValue, Double.MinValue, Double.NaN)) { val bytes = py.serializeDouble(x) val deser = py.deserializeDouble(bytes) - assert(x === deser) + // We use `equals` here for comparison because we cannot use `==` for NaN + assert(x.equals(deser)) } } + + test("matrix to 2D array") { + val values = Array[Double](0, 1.2, 3, 4.56, 7, 8) + val matrix = Matrices.dense(2, 3, values) + val arr = py.to2dArray(matrix) + val expected = Array(Array[Double](0, 3, 7), Array[Double](1.2, 4.56, 8)) + assert(arr === expected) + + // Test conversion for empty matrix + val empty = Array[Double]() + val emptyMatrix = Matrices.dense(0, 0, empty) + val empty2D = py.to2dArray(emptyMatrix) + assert(empty2D === Array[Array[Double]]()) + } } diff --git a/python/pyspark/mllib/_common.py b/python/pyspark/mllib/_common.py index 8e3ad6b783b6c..c6ca6a75df746 100644 --- a/python/pyspark/mllib/_common.py +++ b/python/pyspark/mllib/_common.py @@ -101,7 +101,7 @@ def _serialize_double(d): """ Serialize a double (float or numpy.float64) into a mutually understood format. """ - if type(d) == float or type(d) == float64: + if type(d) == float or type(d) == float64 or type(d) == int or type(d) == long: d = float64(d) ba = bytearray(8) _copyto(d, buffer=ba, offset=0, shape=[1], dtype=float64) @@ -176,6 +176,10 @@ def _deserialize_double(ba, offset=0): True >>> _deserialize_double(_serialize_double(float64(0.0))) == 0.0 True + >>> _deserialize_double(_serialize_double(1)) == 1.0 + True + >>> _deserialize_double(_serialize_double(1L)) == 1.0 + True >>> x = sys.float_info.max >>> _deserialize_double(_serialize_double(sys.float_info.max)) == x True diff --git a/python/pyspark/mllib/stat.py b/python/pyspark/mllib/stat.py new file mode 100644 index 0000000000000..0a08a562d1f1f --- /dev/null +++ b/python/pyspark/mllib/stat.py @@ -0,0 +1,104 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +Python package for statistical functions in MLlib. +""" + +from pyspark.mllib._common import \ + _get_unmangled_double_vector_rdd, _get_unmangled_rdd, \ + _serialize_double, _serialize_double_vector, \ + _deserialize_double, _deserialize_double_matrix + +class Statistics(object): + + @staticmethod + def corr(x, y=None, method=None): + """ + Compute the correlation (matrix) for the input RDD(s) using the + specified method. + Methods currently supported: I{pearson (default), spearman}. + + If a single RDD of Vectors is passed in, a correlation matrix + comparing the columns in the input RDD is returned. Use C{method=} + to specify the method to be used for single RDD inout. + If two RDDs of floats are passed in, a single float is returned. + + >>> x = sc.parallelize([1.0, 0.0, -2.0], 2) + >>> y = sc.parallelize([4.0, 5.0, 3.0], 2) + >>> zeros = sc.parallelize([0.0, 0.0, 0.0], 2) + >>> abs(Statistics.corr(x, y) - 0.6546537) < 1e-7 + True + >>> Statistics.corr(x, y) == Statistics.corr(x, y, "pearson") + True + >>> Statistics.corr(x, y, "spearman") + 0.5 + >>> from math import isnan + >>> isnan(Statistics.corr(x, zeros)) + True + >>> from linalg import Vectors + >>> rdd = sc.parallelize([Vectors.dense([1, 0, 0, -2]), Vectors.dense([4, 5, 0, 3]), + ... Vectors.dense([6, 7, 0, 8]), Vectors.dense([9, 0, 0, 1])]) + >>> Statistics.corr(rdd) + array([[ 1. , 0.05564149, nan, 0.40047142], + [ 0.05564149, 1. , nan, 0.91359586], + [ nan, nan, 1. , nan], + [ 0.40047142, 0.91359586, nan, 1. ]]) + >>> Statistics.corr(rdd, method="spearman") + array([[ 1. , 0.10540926, nan, 0.4 ], + [ 0.10540926, 1. , nan, 0.9486833 ], + [ nan, nan, 1. , nan], + [ 0.4 , 0.9486833 , nan, 1. ]]) + >>> try: + ... Statistics.corr(rdd, "spearman") + ... print "Method name as second argument without 'method=' shouldn't be allowed." + ... except TypeError: + ... pass + """ + sc = x.ctx + # Check inputs to determine whether a single value or a matrix is needed for output. + # Since it's legal for users to use the method name as the second argument, we need to + # check if y is used to specify the method name instead. + if type(y) == str: + raise TypeError("Use 'method=' to specify method name.") + if not y: + try: + Xser = _get_unmangled_double_vector_rdd(x) + except TypeError: + raise TypeError("corr called on a single RDD not consisted of Vectors.") + resultMat = sc._jvm.PythonMLLibAPI().corr(Xser._jrdd, method) + return _deserialize_double_matrix(resultMat) + else: + xSer = _get_unmangled_rdd(x, _serialize_double) + ySer = _get_unmangled_rdd(y, _serialize_double) + result = sc._jvm.PythonMLLibAPI().corr(xSer._jrdd, ySer._jrdd, method) + return result + + +def _test(): + import doctest + from pyspark import SparkContext + globs = globals().copy() + globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2) + (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) + globs['sc'].stop() + if failure_count: + exit(-1) + + +if __name__ == "__main__": + _test() From 7058a5393bccc2f917189fa9b4cf7f314410b0de Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Fri, 1 Aug 2014 15:52:21 -0700 Subject: [PATCH 097/170] [SPARK-2796] [mllib] DecisionTree bug fix: ordered categorical features Bug: In DecisionTree, the method sequentialBinSearchForOrderedCategoricalFeatureInClassification() indexed bins from 0 to (math.pow(2, featureCategories.toInt - 1) - 1). This upper bound is the bound for unordered categorical features, not ordered ones. The upper bound should be the arity (i.e., max value) of the feature. Added new test to DecisionTreeSuite to catch this: "regression stump with categorical variables of arity 2" Bug fix: Modified upper bound discussed above. Also: Small improvements to coding style in DecisionTree. CC mengxr manishamde Author: Joseph K. Bradley Closes #1720 from jkbradley/decisiontree-bugfix2 and squashes the following commits: 225822f [Joseph K. Bradley] Bug: In DecisionTree, the method sequentialBinSearchForOrderedCategoricalFeatureInClassification() indexed bins from 0 to (math.pow(2, featureCategories.toInt - 1) - 1). This upper bound is the bound for unordered categorical features, not ordered ones. The upper bound should be the arity (i.e., max value) of the feature. --- .../spark/mllib/tree/DecisionTree.scala | 45 +++++++++++-------- .../spark/mllib/tree/DecisionTreeSuite.scala | 29 ++++++++++++ 2 files changed, 56 insertions(+), 18 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index 7d123dd6ae996..382e76a9b7cba 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -498,7 +498,7 @@ object DecisionTree extends Serializable with Logging { val bin = binForFeatures(mid) val lowThreshold = bin.lowSplit.threshold val highThreshold = bin.highSplit.threshold - if ((lowThreshold < feature) && (highThreshold >= feature)){ + if ((lowThreshold < feature) && (highThreshold >= feature)) { return mid } else if (lowThreshold >= feature) { @@ -522,28 +522,36 @@ object DecisionTree extends Serializable with Logging { } /** - * Sequential search helper method to find bin for categorical feature. + * Sequential search helper method to find bin for categorical feature + * (for classification and regression). */ - def sequentialBinSearchForOrderedCategoricalFeatureInClassification(): Int = { + def sequentialBinSearchForOrderedCategoricalFeature(): Int = { val featureCategories = strategy.categoricalFeaturesInfo(featureIndex) - val numCategoricalBins = math.pow(2.0, featureCategories - 1).toInt - 1 + val featureValue = labeledPoint.features(featureIndex) var binIndex = 0 - while (binIndex < numCategoricalBins) { + while (binIndex < featureCategories) { val bin = bins(featureIndex)(binIndex) val categories = bin.highSplit.categories - val features = labeledPoint.features - if (categories.contains(features(featureIndex))) { + if (categories.contains(featureValue)) { return binIndex } binIndex += 1 } + if (featureValue < 0 || featureValue >= featureCategories) { + throw new IllegalArgumentException( + s"DecisionTree given invalid data:" + + s" Feature $featureIndex is categorical with values in" + + s" {0,...,${featureCategories - 1}," + + s" but a data point gives it value $featureValue.\n" + + " Bad data point: " + labeledPoint.toString) + } -1 } if (isFeatureContinuous) { // Perform binary search for finding bin for continuous features. val binIndex = binarySearchForBins() - if (binIndex == -1){ + if (binIndex == -1) { throw new UnknownError("no bin was found for continuous variable.") } binIndex @@ -555,10 +563,10 @@ object DecisionTree extends Serializable with Logging { if (isUnorderedFeature) { sequentialBinSearchForUnorderedCategoricalFeatureInClassification() } else { - sequentialBinSearchForOrderedCategoricalFeatureInClassification() + sequentialBinSearchForOrderedCategoricalFeature() } } - if (binIndex == -1){ + if (binIndex == -1) { throw new UnknownError("no bin was found for categorical variable.") } binIndex @@ -642,11 +650,12 @@ object DecisionTree extends Serializable with Logging { val arrShift = 1 + numFeatures * nodeIndex val arrIndex = arrShift + featureIndex // Update the left or right count for one bin. - val aggShift = numClasses * numBins * numFeatures * nodeIndex - val aggIndex - = aggShift + numClasses * featureIndex * numBins + arr(arrIndex).toInt * numClasses - val labelInt = label.toInt - agg(aggIndex + labelInt) = agg(aggIndex + labelInt) + 1 + val aggIndex = + numClasses * numBins * numFeatures * nodeIndex + + numClasses * numBins * featureIndex + + numClasses * arr(arrIndex).toInt + + label.toInt + agg(aggIndex) += 1 } /** @@ -1127,7 +1136,7 @@ object DecisionTree extends Serializable with Logging { val rightNodeAgg = Array.ofDim[Double](numFeatures, numBins - 1, numClasses) var featureIndex = 0 while (featureIndex < numFeatures) { - if (isMulticlassClassificationWithCategoricalFeatures){ + if (isMulticlassClassificationWithCategoricalFeatures) { val isFeatureContinuous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty if (isFeatureContinuous) { findAggForOrderedFeatureClassification(leftNodeAgg, rightNodeAgg, featureIndex) @@ -1393,7 +1402,7 @@ object DecisionTree extends Serializable with Logging { // Iterate over all features. var featureIndex = 0 - while (featureIndex < numFeatures){ + while (featureIndex < numFeatures) { // Check whether the feature is continuous. val isFeatureContinuous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty if (isFeatureContinuous) { @@ -1513,7 +1522,7 @@ object DecisionTree extends Serializable with Logging { if (isFeatureContinuous) { // Bins for categorical variables are already assigned. bins(featureIndex)(0) = new Bin(new DummyLowSplit(featureIndex, Continuous), splits(featureIndex)(0), Continuous, Double.MinValue) - for (index <- 1 until numBins - 1){ + for (index <- 1 until numBins - 1) { val bin = new Bin(splits(featureIndex)(index-1), splits(featureIndex)(index), Continuous, Double.MinValue) bins(featureIndex)(index) = bin diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala index 10462db700628..546a132559326 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala @@ -42,6 +42,18 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(accuracy >= requiredAccuracy) } + def validateRegressor( + model: DecisionTreeModel, + input: Seq[LabeledPoint], + requiredMSE: Double) { + val predictions = input.map(x => model.predict(x.features)) + val squaredError = predictions.zip(input).map { case (prediction, expected) => + (prediction - expected.label) * (prediction - expected.label) + }.sum + val mse = squaredError / input.length + assert(mse <= requiredMSE) + } + test("split and bin calculation") { val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel1() assert(arr.length === 1000) @@ -454,6 +466,23 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(stats.impurity > 0.2) } + test("regression stump with categorical variables of arity 2") { + val arr = DecisionTreeSuite.generateCategoricalDataPoints() + assert(arr.length === 1000) + val rdd = sc.parallelize(arr) + val strategy = new Strategy( + Regression, + Variance, + maxDepth = 2, + maxBins = 100, + categoricalFeaturesInfo = Map(0 -> 2, 1-> 2)) + + val model = DecisionTree.train(rdd, strategy) + validateRegressor(model, arr, 0.0) + assert(model.numNodes === 3) + assert(model.depth === 1) + } + test("stump with fixed label 0 for Gini") { val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel0() assert(arr.length === 1000) From 880eabec37c69ce4e9594d7babfac291b0f93f50 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Fri, 1 Aug 2014 18:47:41 -0700 Subject: [PATCH 098/170] [SPARK-2010] [PySpark] [SQL] support nested structure in SchemaRDD Convert Row in JavaSchemaRDD into Array[Any] and unpickle them as tuple in Python, then convert them into namedtuple, so use can access fields just like attributes. This will let nested structure can be accessed as object, also it will reduce the size of serialized data and better performance. root |-- field1: integer (nullable = true) |-- field2: string (nullable = true) |-- field3: struct (nullable = true) | |-- field4: integer (nullable = true) | |-- field5: array (nullable = true) | | |-- element: integer (containsNull = false) |-- field6: array (nullable = true) | |-- element: struct (containsNull = false) | | |-- field7: string (nullable = true) Then we can access them by row.field3.field5[0] or row.field6[5].field7 It also will infer the schema in Python, convert Row/dict/namedtuple/objects into tuple before serialization, then call applySchema in JVM. During inferSchema(), the top level of dict in row will be StructType, but any nested dictionary will be MapType. You can use pyspark.sql.Row to convert unnamed structure into Row object, make the RDD can be inferable. Such as: ctx.inferSchema(rdd.map(lambda x: Row(a=x[0], b=x[1])) Or you could use Row to create a class just like namedtuple, for example: Person = Row("name", "age") ctx.inferSchema(rdd.map(lambda x: Person(*x))) Also, you can call applySchema to apply an schema to a RDD of tuple/list and turn it into a SchemaRDD. The `schema` should be StructType, see the API docs for details. schema = StructType([StructField("name, StringType, True), StructType("age", IntegerType, True)]) ctx.applySchema(rdd, schema) PS: In order to use namedtuple to inferSchema, you should make namedtuple picklable. Author: Davies Liu Closes #1598 from davies/nested and squashes the following commits: f1d15b6 [Davies Liu] verify schema with the first few rows 8852aaf [Davies Liu] check type of schema abe9e6e [Davies Liu] address comments 61b2292 [Davies Liu] add @deprecated to pythonToJavaMap 1e5b801 [Davies Liu] improve cache of classes 51aa135 [Davies Liu] use Row to infer schema e9c0d5c [Davies Liu] remove string typed schema 353a3f2 [Davies Liu] fix code style 63de8f8 [Davies Liu] fix typo c79ca67 [Davies Liu] fix serialization of nested data 6b258b5 [Davies Liu] fix pep8 9d8447c [Davies Liu] apply schema provided by string of names f5df97f [Davies Liu] refactor, address comments 9d9af55 [Davies Liu] use arrry to applySchema and infer schema in Python 84679b3 [Davies Liu] Merge branch 'master' of github.com:apache/spark into nested 0eaaf56 [Davies Liu] fix doc tests b3559b4 [Davies Liu] use generated Row instead of namedtuple c4ddc30 [Davies Liu] fix conflict between name of fields and variables 7f6f251 [Davies Liu] address all comments d69d397 [Davies Liu] refactor 2cc2d45 [Davies Liu] refactor 182fb46 [Davies Liu] refactor bc6e9e1 [Davies Liu] switch to new Schema API 547bf3e [Davies Liu] Merge branch 'master' into nested a435b5a [Davies Liu] add docs and code refactor 2c8debc [Davies Liu] Merge branch 'master' into nested 644665a [Davies Liu] use tuple and namedtuple for schemardd --- .../apache/spark/api/python/PythonRDD.scala | 69 +- python/pyspark/rdd.py | 8 +- python/pyspark/sql.py | 1258 ++++++++++++----- .../org/apache/spark/sql/SQLContext.scala | 87 +- .../org/apache/spark/sql/SchemaRDD.scala | 18 +- 5 files changed, 996 insertions(+), 444 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index 94d666aa92025..fe9a9e50ef21d 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -25,7 +25,7 @@ import java.util.{List => JList, ArrayList => JArrayList, Map => JMap, Collectio import scala.collection.JavaConversions._ import scala.language.existentials import scala.reflect.ClassTag -import scala.util.Try +import scala.util.{Try, Success, Failure} import net.razorvine.pickle.{Pickler, Unpickler} @@ -536,25 +536,6 @@ private[spark] object PythonRDD extends Logging { file.close() } - /** - * Convert an RDD of serialized Python dictionaries to Scala Maps (no recursive conversions). - * It is only used by pyspark.sql. - * TODO: Support more Python types. - */ - def pythonToJavaMap(pyRDD: JavaRDD[Array[Byte]]): JavaRDD[Map[String, _]] = { - pyRDD.rdd.mapPartitions { iter => - val unpickle = new Unpickler - iter.flatMap { row => - unpickle.loads(row) match { - // in case of objects are pickled in batch mode - case objs: java.util.ArrayList[JMap[String, _] @unchecked] => objs.map(_.toMap) - // not in batch mode - case obj: JMap[String @unchecked, _] => Seq(obj.toMap) - } - } - } - } - private def getMergedConf(confAsMap: java.util.HashMap[String, String], baseConf: Configuration): Configuration = { val conf = PythonHadoopUtil.mapToConf(confAsMap) @@ -701,6 +682,54 @@ private[spark] object PythonRDD extends Logging { } } + + /** + * Convert an RDD of serialized Python dictionaries to Scala Maps (no recursive conversions). + * This function is outdated, PySpark does not use it anymore + */ + @deprecated + def pythonToJavaMap(pyRDD: JavaRDD[Array[Byte]]): JavaRDD[Map[String, _]] = { + pyRDD.rdd.mapPartitions { iter => + val unpickle = new Unpickler + iter.flatMap { row => + unpickle.loads(row) match { + // in case of objects are pickled in batch mode + case objs: JArrayList[JMap[String, _] @unchecked] => objs.map(_.toMap) + // not in batch mode + case obj: JMap[String @unchecked, _] => Seq(obj.toMap) + } + } + } + } + + /** + * Convert an RDD of serialized Python tuple to Array (no recursive conversions). + * It is only used by pyspark.sql. + */ + def pythonToJavaArray(pyRDD: JavaRDD[Array[Byte]], batched: Boolean): JavaRDD[Array[_]] = { + + def toArray(obj: Any): Array[_] = { + obj match { + case objs: JArrayList[_] => + objs.toArray + case obj if obj.getClass.isArray => + obj.asInstanceOf[Array[_]].toArray + } + } + + pyRDD.rdd.mapPartitions { iter => + val unpickle = new Unpickler + iter.flatMap { row => + val obj = unpickle.loads(row) + if (batched) { + obj.asInstanceOf[JArrayList[_]].map(toArray) + } else { + Seq(toArray(obj)) + } + } + }.toJavaRDD() + } + /** * Convert and RDD of Java objects to and RDD of serialized Python objects, that is usable by * PySpark. diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index e8fcc900efb24..309f5a9b6038d 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -318,9 +318,9 @@ def map(self, f, preservesPartitioning=False): >>> sorted(rdd.map(lambda x: (x, 1)).collect()) [('a', 1), ('b', 1), ('c', 1)] """ - def func(split, iterator): + def func(_, iterator): return imap(f, iterator) - return PipelinedRDD(self, func, preservesPartitioning) + return self.mapPartitionsWithIndex(func, preservesPartitioning) def flatMap(self, f, preservesPartitioning=False): """ @@ -1184,7 +1184,7 @@ def func(split, iterator): if not isinstance(x, basestring): x = unicode(x) yield x.encode("utf-8") - keyed = PipelinedRDD(self, func) + keyed = self.mapPartitionsWithIndex(func) keyed._bypass_serializer = True keyed._jrdd.map(self.ctx._jvm.BytesToString()).saveAsTextFile(path) @@ -1382,7 +1382,7 @@ def add_shuffle_key(split, iterator): yield pack_long(split) yield outputSerializer.dumps(items) - keyed = PipelinedRDD(self, add_shuffle_key) + keyed = self.mapPartitionsWithIndex(add_shuffle_key) keyed._bypass_serializer = True with _JavaStackTrace(self.context) as st: pairRDD = self.ctx._jvm.PairwiseRDD( diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py index 9388ead5eaad3..f840475ffaf70 100644 --- a/python/pyspark/sql.py +++ b/python/pyspark/sql.py @@ -15,7 +15,17 @@ # limitations under the License. # + +import sys +import types +import itertools +import warnings +import decimal +import datetime +import keyword import warnings +from array import array +from operator import itemgetter from pyspark.rdd import RDD, PipelinedRDD from pyspark.serializers import BatchedSerializer, PickleSerializer @@ -26,10 +36,30 @@ "StringType", "BinaryType", "BooleanType", "TimestampType", "DecimalType", "DoubleType", "FloatType", "ByteType", "IntegerType", "LongType", "ShortType", "ArrayType", "MapType", "StructField", "StructType", - "SQLContext", "HiveContext", "LocalHiveContext", "TestHiveContext", "SchemaRDD", "Row"] + "SQLContext", "HiveContext", "LocalHiveContext", "TestHiveContext", + "SchemaRDD", "Row"] + + +class DataType(object): + """Spark SQL DataType""" + + def __repr__(self): + return self.__class__.__name__ + + def __hash__(self): + return hash(str(self)) + + def __eq__(self, other): + return (isinstance(other, self.__class__) and + self.__dict__ == other.__dict__) + + def __ne__(self, other): + return not self.__eq__(other) class PrimitiveTypeSingleton(type): + """Metaclass for PrimitiveType""" + _instances = {} def __call__(cls): @@ -38,148 +68,105 @@ def __call__(cls): return cls._instances[cls] -class StringType(object): +class PrimitiveType(DataType): + """Spark SQL PrimitiveType""" + + __metaclass__ = PrimitiveTypeSingleton + + def __eq__(self, other): + # because they should be the same object + return self is other + + +class StringType(PrimitiveType): """Spark SQL StringType The data type representing string values. - """ - __metaclass__ = PrimitiveTypeSingleton - - def __repr__(self): - return "StringType" -class BinaryType(object): +class BinaryType(PrimitiveType): """Spark SQL BinaryType The data type representing bytearray values. - """ - __metaclass__ = PrimitiveTypeSingleton - - def __repr__(self): - return "BinaryType" -class BooleanType(object): +class BooleanType(PrimitiveType): """Spark SQL BooleanType The data type representing bool values. - """ - __metaclass__ = PrimitiveTypeSingleton - - def __repr__(self): - return "BooleanType" -class TimestampType(object): +class TimestampType(PrimitiveType): """Spark SQL TimestampType The data type representing datetime.datetime values. - """ - __metaclass__ = PrimitiveTypeSingleton - - def __repr__(self): - return "TimestampType" -class DecimalType(object): +class DecimalType(PrimitiveType): """Spark SQL DecimalType The data type representing decimal.Decimal values. - """ - __metaclass__ = PrimitiveTypeSingleton - - def __repr__(self): - return "DecimalType" -class DoubleType(object): +class DoubleType(PrimitiveType): """Spark SQL DoubleType The data type representing float values. - """ - __metaclass__ = PrimitiveTypeSingleton - - def __repr__(self): - return "DoubleType" -class FloatType(object): +class FloatType(PrimitiveType): """Spark SQL FloatType The data type representing single precision floating-point values. - """ - __metaclass__ = PrimitiveTypeSingleton - def __repr__(self): - return "FloatType" - -class ByteType(object): +class ByteType(PrimitiveType): """Spark SQL ByteType The data type representing int values with 1 singed byte. - """ - __metaclass__ = PrimitiveTypeSingleton - - def __repr__(self): - return "ByteType" -class IntegerType(object): +class IntegerType(PrimitiveType): """Spark SQL IntegerType The data type representing int values. - """ - __metaclass__ = PrimitiveTypeSingleton - def __repr__(self): - return "IntegerType" - -class LongType(object): +class LongType(PrimitiveType): """Spark SQL LongType - The data type representing long values. If the any value is beyond the range of - [-9223372036854775808, 9223372036854775807], please use DecimalType. - + The data type representing long values. If the any value is + beyond the range of [-9223372036854775808, 9223372036854775807], + please use DecimalType. """ - __metaclass__ = PrimitiveTypeSingleton - def __repr__(self): - return "LongType" - -class ShortType(object): +class ShortType(PrimitiveType): """Spark SQL ShortType The data type representing int values with 2 signed bytes. - """ - __metaclass__ = PrimitiveTypeSingleton - - def __repr__(self): - return "ShortType" -class ArrayType(object): +class ArrayType(DataType): """Spark SQL ArrayType - The data type representing list values. - An ArrayType object comprises two fields, elementType (a DataType) and containsNull (a bool). + The data type representing list values. An ArrayType object + comprises two fields, elementType (a DataType) and containsNull (a bool). The field of elementType is used to specify the type of array elements. The field of containsNull is used to specify if the array has None values. """ + def __init__(self, elementType, containsNull=False): """Creates an ArrayType @@ -194,40 +181,39 @@ def __init__(self, elementType, containsNull=False): self.elementType = elementType self.containsNull = containsNull - def __repr__(self): - return "ArrayType(" + self.elementType.__repr__() + "," + \ - str(self.containsNull).lower() + ")" - - def __eq__(self, other): - return (isinstance(other, self.__class__) and - self.elementType == other.elementType and - self.containsNull == other.containsNull) - - def __ne__(self, other): - return not self.__eq__(other) + def __str__(self): + return "ArrayType(%s,%s)" % (self.elementType, + str(self.containsNull).lower()) -class MapType(object): +class MapType(DataType): """Spark SQL MapType - The data type representing dict values. - A MapType object comprises three fields, - keyType (a DataType), valueType (a DataType) and valueContainsNull (a bool). + The data type representing dict values. A MapType object comprises + three fields, keyType (a DataType), valueType (a DataType) and + valueContainsNull (a bool). + The field of keyType is used to specify the type of keys in the map. The field of valueType is used to specify the type of values in the map. - The field of valueContainsNull is used to specify if values of this map has None values. + The field of valueContainsNull is used to specify if values of this + map has None values. + For values of a MapType column, keys are not allowed to have None values. """ + def __init__(self, keyType, valueType, valueContainsNull=True): """Creates a MapType :param keyType: the data type of keys. :param valueType: the data type of values. - :param valueContainsNull: indicates whether values contains null values. + :param valueContainsNull: indicates whether values contains + null values. - >>> MapType(StringType, IntegerType) == MapType(StringType, IntegerType, True) + >>> (MapType(StringType, IntegerType) + ... == MapType(StringType, IntegerType, True)) True - >>> MapType(StringType, IntegerType, False) == MapType(StringType, FloatType) + >>> (MapType(StringType, IntegerType, False) + ... == MapType(StringType, FloatType)) False """ self.keyType = keyType @@ -235,39 +221,36 @@ def __init__(self, keyType, valueType, valueContainsNull=True): self.valueContainsNull = valueContainsNull def __repr__(self): - return "MapType(" + self.keyType.__repr__() + "," + \ - self.valueType.__repr__() + "," + \ - str(self.valueContainsNull).lower() + ")" + return "MapType(%s,%s,%s)" % (self.keyType, self.valueType, + str(self.valueContainsNull).lower()) - def __eq__(self, other): - return (isinstance(other, self.__class__) and - self.keyType == other.keyType and - self.valueType == other.valueType and - self.valueContainsNull == other.valueContainsNull) - def __ne__(self, other): - return not self.__eq__(other) - - -class StructField(object): +class StructField(DataType): """Spark SQL StructField Represents a field in a StructType. - A StructField object comprises three fields, name (a string), dataType (a DataType), - and nullable (a bool). The field of name is the name of a StructField. The field of - dataType specifies the data type of a StructField. - The field of nullable specifies if values of a StructField can contain None values. + A StructField object comprises three fields, name (a string), + dataType (a DataType) and nullable (a bool). The field of name + is the name of a StructField. The field of dataType specifies + the data type of a StructField. + + The field of nullable specifies if values of a StructField can + contain None values. """ + def __init__(self, name, dataType, nullable): """Creates a StructField :param name: the name of this field. :param dataType: the data type of this field. - :param nullable: indicates whether values of this field can be null. + :param nullable: indicates whether values of this field + can be null. - >>> StructField("f1", StringType, True) == StructField("f1", StringType, True) + >>> (StructField("f1", StringType, True) + ... == StructField("f1", StringType, True)) True - >>> StructField("f1", StringType, True) == StructField("f2", StringType, True) + >>> (StructField("f1", StringType, True) + ... == StructField("f2", StringType, True)) False """ self.name = name @@ -275,27 +258,18 @@ def __init__(self, name, dataType, nullable): self.nullable = nullable def __repr__(self): - return "StructField(" + self.name + "," + \ - self.dataType.__repr__() + "," + \ - str(self.nullable).lower() + ")" + return "StructField(%s,%s,%s)" % (self.name, self.dataType, + str(self.nullable).lower()) - def __eq__(self, other): - return (isinstance(other, self.__class__) and - self.name == other.name and - self.dataType == other.dataType and - self.nullable == other.nullable) - def __ne__(self, other): - return not self.__eq__(other) - - -class StructType(object): +class StructType(DataType): """Spark SQL StructType - The data type representing namedtuple values. + The data type representing rows. A StructType object comprises a list of L{StructField}s. """ + def __init__(self, fields): """Creates a StructType @@ -312,15 +286,8 @@ def __init__(self, fields): self.fields = fields def __repr__(self): - return "StructType(List(" + \ - ",".join([field.__repr__() for field in self.fields]) + "))" - - def __eq__(self, other): - return (isinstance(other, self.__class__) and - self.fields == other.fields) - - def __ne__(self, other): - return not self.__eq__(other) + return ("StructType(List(%s))" % + ",".join(str(field) for field in self.fields)) def _parse_datatype_list(datatype_list_string): @@ -347,34 +314,19 @@ def _parse_datatype_list(datatype_list_string): return datatype_list +_all_primitive_types = dict((k, v) for k, v in globals().iteritems() + if type(v) is PrimitiveTypeSingleton and v.__base__ == PrimitiveType) + + def _parse_datatype_string(datatype_string): """Parses the given data type string. >>> def check_datatype(datatype): - ... scala_datatype = sqlCtx._ssql_ctx.parseDataType(datatype.__repr__()) - ... python_datatype = _parse_datatype_string(scala_datatype.toString()) + ... scala_datatype = sqlCtx._ssql_ctx.parseDataType(str(datatype)) + ... python_datatype = _parse_datatype_string( + ... scala_datatype.toString()) ... return datatype == python_datatype - >>> check_datatype(StringType()) - True - >>> check_datatype(BinaryType()) - True - >>> check_datatype(BooleanType()) - True - >>> check_datatype(TimestampType()) - True - >>> check_datatype(DecimalType()) - True - >>> check_datatype(DoubleType()) - True - >>> check_datatype(FloatType()) - True - >>> check_datatype(ByteType()) - True - >>> check_datatype(IntegerType()) - True - >>> check_datatype(LongType()) - True - >>> check_datatype(ShortType()) + >>> all(check_datatype(cls()) for cls in _all_primitive_types.values()) True >>> # Simple ArrayType. >>> simple_arraytype = ArrayType(StringType(), True) @@ -405,70 +357,525 @@ def _parse_datatype_string(datatype_string): >>> check_datatype(complex_arraytype) True >>> # Complex MapType. - >>> complex_maptype = MapType(complex_structtype, complex_arraytype, False) + >>> complex_maptype = MapType(complex_structtype, + ... complex_arraytype, False) >>> check_datatype(complex_maptype) True """ - left_bracket_index = datatype_string.find("(") - if left_bracket_index == -1: + index = datatype_string.find("(") + if index == -1: # It is a primitive type. - left_bracket_index = len(datatype_string) - type_or_field = datatype_string[:left_bracket_index] - rest_part = datatype_string[left_bracket_index+1:len(datatype_string)-1].strip() - if type_or_field == "StringType": - return StringType() - elif type_or_field == "BinaryType": - return BinaryType() - elif type_or_field == "BooleanType": - return BooleanType() - elif type_or_field == "TimestampType": - return TimestampType() - elif type_or_field == "DecimalType": - return DecimalType() - elif type_or_field == "DoubleType": - return DoubleType() - elif type_or_field == "FloatType": - return FloatType() - elif type_or_field == "ByteType": - return ByteType() - elif type_or_field == "IntegerType": - return IntegerType() - elif type_or_field == "LongType": - return LongType() - elif type_or_field == "ShortType": - return ShortType() + index = len(datatype_string) + type_or_field = datatype_string[:index] + rest_part = datatype_string[index + 1:len(datatype_string) - 1].strip() + + if type_or_field in _all_primitive_types: + return _all_primitive_types[type_or_field]() + elif type_or_field == "ArrayType": last_comma_index = rest_part.rfind(",") containsNull = True - if rest_part[last_comma_index+1:].strip().lower() == "false": + if rest_part[last_comma_index + 1:].strip().lower() == "false": containsNull = False - elementType = _parse_datatype_string(rest_part[:last_comma_index].strip()) + elementType = _parse_datatype_string( + rest_part[:last_comma_index].strip()) return ArrayType(elementType, containsNull) + elif type_or_field == "MapType": last_comma_index = rest_part.rfind(",") valueContainsNull = True - if rest_part[last_comma_index+1:].strip().lower() == "false": + if rest_part[last_comma_index + 1:].strip().lower() == "false": valueContainsNull = False - keyType, valueType = _parse_datatype_list(rest_part[:last_comma_index].strip()) + keyType, valueType = _parse_datatype_list( + rest_part[:last_comma_index].strip()) return MapType(keyType, valueType, valueContainsNull) + elif type_or_field == "StructField": first_comma_index = rest_part.find(",") name = rest_part[:first_comma_index].strip() last_comma_index = rest_part.rfind(",") nullable = True - if rest_part[last_comma_index+1:].strip().lower() == "false": + if rest_part[last_comma_index + 1:].strip().lower() == "false": nullable = False dataType = _parse_datatype_string( - rest_part[first_comma_index+1:last_comma_index].strip()) + rest_part[first_comma_index + 1:last_comma_index].strip()) return StructField(name, dataType, nullable) + elif type_or_field == "StructType": # rest_part should be in the format like # List(StructField(field1,IntegerType,false)). - field_list_string = rest_part[rest_part.find("(")+1:-1] + field_list_string = rest_part[rest_part.find("(") + 1:-1] fields = _parse_datatype_list(field_list_string) return StructType(fields) +# Mapping Python types to Spark SQL DateType +_type_mappings = { + bool: BooleanType, + int: IntegerType, + long: LongType, + float: DoubleType, + str: StringType, + unicode: StringType, + decimal.Decimal: DecimalType, + datetime.datetime: TimestampType, + datetime.date: TimestampType, + datetime.time: TimestampType, +} + + +def _infer_type(obj): + """Infer the DataType from obj""" + if obj is None: + raise ValueError("Can not infer type for None") + + dataType = _type_mappings.get(type(obj)) + if dataType is not None: + return dataType() + + if isinstance(obj, dict): + if not obj: + raise ValueError("Can not infer type for empty dict") + key, value = obj.iteritems().next() + return MapType(_infer_type(key), _infer_type(value), True) + elif isinstance(obj, (list, array)): + if not obj: + raise ValueError("Can not infer type for empty list/array") + return ArrayType(_infer_type(obj[0]), True) + else: + try: + return _infer_schema(obj) + except ValueError: + raise ValueError("not supported type: %s" % type(obj)) + + +def _infer_schema(row): + """Infer the schema from dict/namedtuple/object""" + if isinstance(row, dict): + items = sorted(row.items()) + + elif isinstance(row, tuple): + if hasattr(row, "_fields"): # namedtuple + items = zip(row._fields, tuple(row)) + elif hasattr(row, "__FIELDS__"): # Row + items = zip(row.__FIELDS__, tuple(row)) + elif all(isinstance(x, tuple) and len(x) == 2 for x in row): + items = row + else: + raise ValueError("Can't infer schema from tuple") + + elif hasattr(row, "__dict__"): # object + items = sorted(row.__dict__.items()) + + else: + raise ValueError("Can not infer schema for type: %s" % type(row)) + + fields = [StructField(k, _infer_type(v), True) for k, v in items] + return StructType(fields) + + +def _create_converter(obj, dataType): + """Create an converter to drop the names of fields in obj """ + if not _has_struct(dataType): + return lambda x: x + + elif isinstance(dataType, ArrayType): + conv = _create_converter(obj[0], dataType.elementType) + return lambda row: map(conv, row) + + elif isinstance(dataType, MapType): + value = obj.values()[0] + conv = _create_converter(value, dataType.valueType) + return lambda row: dict((k, conv(v)) for k, v in row.iteritems()) + + # dataType must be StructType + names = [f.name for f in dataType.fields] + + if isinstance(obj, dict): + conv = lambda o: tuple(o.get(n) for n in names) + + elif isinstance(obj, tuple): + if hasattr(obj, "_fields"): # namedtuple + conv = tuple + elif hasattr(obj, "__FIELDS__"): + conv = tuple + elif all(isinstance(x, tuple) and len(x) == 2 for x in obj): + conv = lambda o: tuple(v for k, v in o) + else: + raise ValueError("unexpected tuple") + + elif hasattr(obj, "__dict__"): # object + conv = lambda o: [o.__dict__.get(n, None) for n in names] + + nested = any(_has_struct(f.dataType) for f in dataType.fields) + if not nested: + return conv + + row = conv(obj) + convs = [_create_converter(v, f.dataType) + for v, f in zip(row, dataType.fields)] + + def nested_conv(row): + return tuple(f(v) for f, v in zip(convs, conv(row))) + + return nested_conv + + +def _drop_schema(rows, schema): + """ all the names of fields, becoming tuples""" + iterator = iter(rows) + row = iterator.next() + converter = _create_converter(row, schema) + yield converter(row) + for i in iterator: + yield converter(i) + + +_BRACKETS = {'(': ')', '[': ']', '{': '}'} + + +def _split_schema_abstract(s): + """ + split the schema abstract into fields + + >>> _split_schema_abstract("a b c") + ['a', 'b', 'c'] + >>> _split_schema_abstract("a(a b)") + ['a(a b)'] + >>> _split_schema_abstract("a b[] c{a b}") + ['a', 'b[]', 'c{a b}'] + >>> _split_schema_abstract(" ") + [] + """ + + r = [] + w = '' + brackets = [] + for c in s: + if c == ' ' and not brackets: + if w: + r.append(w) + w = '' + else: + w += c + if c in _BRACKETS: + brackets.append(c) + elif c in _BRACKETS.values(): + if not brackets or c != _BRACKETS[brackets.pop()]: + raise ValueError("unexpected " + c) + + if brackets: + raise ValueError("brackets not closed: %s" % brackets) + if w: + r.append(w) + return r + + +def _parse_field_abstract(s): + """ + Parse a field in schema abstract + + >>> _parse_field_abstract("a") + StructField(a,None,true) + >>> _parse_field_abstract("b(c d)") + StructField(b,StructType(...c,None,true),StructField(d... + >>> _parse_field_abstract("a[]") + StructField(a,ArrayType(None,true),true) + >>> _parse_field_abstract("a{[]}") + StructField(a,MapType(None,ArrayType(None,true),true),true) + """ + if set(_BRACKETS.keys()) & set(s): + idx = min((s.index(c) for c in _BRACKETS if c in s)) + name = s[:idx] + return StructField(name, _parse_schema_abstract(s[idx:]), True) + else: + return StructField(s, None, True) + + +def _parse_schema_abstract(s): + """ + parse abstract into schema + + >>> _parse_schema_abstract("a b c") + StructType...a...b...c... + >>> _parse_schema_abstract("a[b c] b{}") + StructType...a,ArrayType...b...c...b,MapType... + >>> _parse_schema_abstract("c{} d{a b}") + StructType...c,MapType...d,MapType...a...b... + >>> _parse_schema_abstract("a b(t)").fields[1] + StructField(b,StructType(List(StructField(t,None,true))),true) + """ + s = s.strip() + if not s: + return + + elif s.startswith('('): + return _parse_schema_abstract(s[1:-1]) + + elif s.startswith('['): + return ArrayType(_parse_schema_abstract(s[1:-1]), True) + + elif s.startswith('{'): + return MapType(None, _parse_schema_abstract(s[1:-1])) + + parts = _split_schema_abstract(s) + fields = [_parse_field_abstract(p) for p in parts] + return StructType(fields) + + +def _infer_schema_type(obj, dataType): + """ + Fill the dataType with types infered from obj + + >>> schema = _parse_schema_abstract("a b c") + >>> row = (1, 1.0, "str") + >>> _infer_schema_type(row, schema) + StructType...IntegerType...DoubleType...StringType... + >>> row = [[1], {"key": (1, 2.0)}] + >>> schema = _parse_schema_abstract("a[] b{c d}") + >>> _infer_schema_type(row, schema) + StructType...a,ArrayType...b,MapType(StringType,...c,IntegerType... + """ + if dataType is None: + return _infer_type(obj) + + if not obj: + raise ValueError("Can not infer type from empty value") + + if isinstance(dataType, ArrayType): + eType = _infer_schema_type(obj[0], dataType.elementType) + return ArrayType(eType, True) + + elif isinstance(dataType, MapType): + k, v = obj.iteritems().next() + return MapType(_infer_type(k), + _infer_schema_type(v, dataType.valueType)) + + elif isinstance(dataType, StructType): + fs = dataType.fields + assert len(fs) == len(obj), \ + "Obj(%s) have different length with fields(%s)" % (obj, fs) + fields = [StructField(f.name, _infer_schema_type(o, f.dataType), True) + for o, f in zip(obj, fs)] + return StructType(fields) + + else: + raise ValueError("Unexpected dataType: %s" % dataType) + + +_acceptable_types = { + BooleanType: (bool,), + ByteType: (int, long), + ShortType: (int, long), + IntegerType: (int, long), + LongType: (int, long), + FloatType: (float,), + DoubleType: (float,), + DecimalType: (decimal.Decimal,), + StringType: (str, unicode), + TimestampType: (datetime.datetime, datetime.time, datetime.date), + ArrayType: (list, tuple, array), + MapType: (dict,), + StructType: (tuple, list), +} + +def _verify_type(obj, dataType): + """ + Verify the type of obj against dataType, raise an exception if + they do not match. + + >>> _verify_type(None, StructType([])) + >>> _verify_type("", StringType()) + >>> _verify_type(0, IntegerType()) + >>> _verify_type(range(3), ArrayType(ShortType())) + >>> _verify_type(set(), ArrayType(StringType())) # doctest: +IGNORE_EXCEPTION_DETAIL + Traceback (most recent call last): + ... + TypeError:... + >>> _verify_type({}, MapType(StringType(), IntegerType())) + >>> _verify_type((), StructType([])) + >>> _verify_type([], StructType([])) + >>> _verify_type([1], StructType([])) # doctest: +IGNORE_EXCEPTION_DETAIL + Traceback (most recent call last): + ... + ValueError:... + """ + # all objects are nullable + if obj is None: + return + + _type = type(dataType) + if _type not in _acceptable_types: + return + + if type(obj) not in _acceptable_types[_type]: + raise TypeError("%s can not accept abject in type %s" + % (dataType, type(obj))) + + if isinstance(dataType, ArrayType): + for i in obj: + _verify_type(i, dataType.elementType) + + elif isinstance(dataType, MapType): + for k, v in obj.iteritems(): + _verify_type(k, dataType.keyType) + _verify_type(v, dataType.valueType) + + elif isinstance(dataType, StructType): + if len(obj) != len(dataType.fields): + raise ValueError("Length of object (%d) does not match with" + "length of fields (%d)" % (len(obj), len(dataType.fields))) + for v, f in zip(obj, dataType.fields): + _verify_type(v, f.dataType) + + +_cached_cls = {} + + +def _restore_object(dataType, obj): + """ Restore object during unpickling. """ + # use id(dataType) as key to speed up lookup in dict + # Because of batched pickling, dataType will be the + # same object in mose cases. + k = id(dataType) + cls = _cached_cls.get(k) + if cls is None: + # use dataType as key to avoid create multiple class + cls = _cached_cls.get(dataType) + if cls is None: + cls = _create_cls(dataType) + _cached_cls[dataType] = cls + _cached_cls[k] = cls + return cls(obj) + + +def _create_object(cls, v): + """ Create an customized object with class `cls`. """ + return cls(v) if v is not None else v + + +def _create_getter(dt, i): + """ Create a getter for item `i` with schema """ + cls = _create_cls(dt) + + def getter(self): + return _create_object(cls, self[i]) + + return getter + + +def _has_struct(dt): + """Return whether `dt` is or has StructType in it""" + if isinstance(dt, StructType): + return True + elif isinstance(dt, ArrayType): + return _has_struct(dt.elementType) + elif isinstance(dt, MapType): + return _has_struct(dt.valueType) + return False + + +def _create_properties(fields): + """Create properties according to fields""" + ps = {} + for i, f in enumerate(fields): + name = f.name + if (name.startswith("__") and name.endswith("__") + or keyword.iskeyword(name)): + warnings.warn("field name %s can not be accessed in Python," + "use position to access it instead" % name) + if _has_struct(f.dataType): + # delay creating object until accessing it + getter = _create_getter(f.dataType, i) + else: + getter = itemgetter(i) + ps[name] = property(getter) + return ps + + +def _create_cls(dataType): + """ + Create an class by dataType + + The created class is similar to namedtuple, but can have nested schema. + + >>> schema = _parse_schema_abstract("a b c") + >>> row = (1, 1.0, "str") + >>> schema = _infer_schema_type(row, schema) + >>> obj = _create_cls(schema)(row) + >>> import pickle + >>> pickle.loads(pickle.dumps(obj)) + Row(a=1, b=1.0, c='str') + + >>> row = [[1], {"key": (1, 2.0)}] + >>> schema = _parse_schema_abstract("a[] b{c d}") + >>> schema = _infer_schema_type(row, schema) + >>> obj = _create_cls(schema)(row) + >>> pickle.loads(pickle.dumps(obj)) + Row(a=[1], b={'key': Row(c=1, d=2.0)}) + """ + + if isinstance(dataType, ArrayType): + cls = _create_cls(dataType.elementType) + + class List(list): + + def __getitem__(self, i): + # create object with datetype + return _create_object(cls, list.__getitem__(self, i)) + + def __repr__(self): + # call collect __repr__ for nested objects + return "[%s]" % (", ".join(repr(self[i]) + for i in range(len(self)))) + + def __reduce__(self): + return list.__reduce__(self) + + return List + + elif isinstance(dataType, MapType): + vcls = _create_cls(dataType.valueType) + + class Dict(dict): + + def __getitem__(self, k): + # create object with datetype + return _create_object(vcls, dict.__getitem__(self, k)) + + def __repr__(self): + # call collect __repr__ for nested objects + return "{%s}" % (", ".join("%r: %r" % (k, self[k]) + for k in self)) + + def __reduce__(self): + return dict.__reduce__(self) + + return Dict + + elif not isinstance(dataType, StructType): + raise Exception("unexpected data type: %s" % dataType) + + class Row(tuple): + """ Row in SchemaRDD """ + __DATATYPE__ = dataType + __FIELDS__ = tuple(f.name for f in dataType.fields) + __slots__ = () + + # create property for fast access + locals().update(_create_properties(dataType.fields)) + + def __repr__(self): + # call collect __repr__ for nested objects + return ("Row(%s)" % ", ".join("%s=%r" % (n, getattr(self, n)) + for n in self.__FIELDS__)) + + def __reduce__(self): + return (_restore_object, (self.__DATATYPE__, tuple(self))) + + return Row + + class SQLContext: """Main entry point for SparkSQL functionality. @@ -485,7 +892,7 @@ def __init__(self, sparkContext, sqlContext=None): >>> sqlCtx.inferSchema(srdd) # doctest: +IGNORE_EXCEPTION_DETAIL Traceback (most recent call last): ... - ValueError:... + TypeError:... >>> bad_rdd = sc.parallelize([1,2,3]) >>> sqlCtx.inferSchema(bad_rdd) # doctest: +IGNORE_EXCEPTION_DETAIL @@ -494,18 +901,22 @@ def __init__(self, sparkContext, sqlContext=None): ValueError:... >>> from datetime import datetime - >>> allTypes = sc.parallelize([{"int": 1, "string": "string", "double": 1.0, "long": 1L, - ... "boolean": True, "time": datetime(2010, 1, 1, 1, 1, 1), "dict": {"a": 1}, - ... "list": [1, 2, 3]}]) - >>> srdd = sqlCtx.inferSchema(allTypes).map(lambda x: (x.int, x.string, x.double, x.long, - ... x.boolean, x.time, x.dict["a"], x.list)) - >>> srdd.collect()[0] - (1, u'string', 1.0, 1, True, datetime.datetime(2010, 1, 1, 1, 1, 1), 1, [1, 2, 3]) + >>> allTypes = sc.parallelize([Row(i=1, s="string", d=1.0, l=1L, + ... b=True, list=[1, 2, 3], dict={"s": 0}, row=Row(a=1), + ... time=datetime(2014, 8, 1, 14, 1, 5))]) + >>> srdd = sqlCtx.inferSchema(allTypes) + >>> srdd.registerAsTable("allTypes") + >>> sqlCtx.sql('select i+1, d+1, not b, list[1], dict["s"], time, row.a ' + ... 'from allTypes where b and i > 0').collect() + [Row(c0=2, c1=2.0, c2=False, c3=2, c4=0...8, 1, 14, 1, 5), a=1)] + >>> srdd.map(lambda x: (x.i, x.s, x.d, x.l, x.b, x.time, + ... x.row.a, x.list)).collect() + [(1, u'string', 1.0, 1, True, ...(2014, 8, 1, 14, 1, 5), 1, [1, 2, 3])] """ self._sc = sparkContext self._jsc = self._sc._jsc self._jvm = self._sc._jvm - self._pythonToJavaMap = self._jvm.PythonRDD.pythonToJavaMap + self._pythonToJava = self._jvm.PythonRDD.pythonToJavaArray if sqlContext: self._scala_SQLContext = sqlContext @@ -522,71 +933,123 @@ def _ssql_ctx(self): return self._scala_SQLContext def inferSchema(self, rdd): - """Infer and apply a schema to an RDD of L{dict}s. + """Infer and apply a schema to an RDD of L{Row}s. + + We peek at the first row of the RDD to determine the fields' names + and types. Nested collections are supported, which include array, + dict, list, Row, tuple, namedtuple, or object. - We peek at the first row of the RDD to determine the fields names - and types, and then use that to extract all the dictionaries. Nested - collections are supported, which include array, dict, list, set, and - tuple. + All the rows in `rdd` should have the same type with the first one, + or it will cause runtime exceptions. + Each row could be L{pyspark.sql.Row} object or namedtuple or objects, + using dict is deprecated. + + >>> rdd = sc.parallelize( + ... [Row(field1=1, field2="row1"), + ... Row(field1=2, field2="row2"), + ... Row(field1=3, field2="row3")]) >>> srdd = sqlCtx.inferSchema(rdd) - >>> srdd.collect() == [{"field1" : 1, "field2" : "row1"}, {"field1" : 2, "field2": "row2"}, - ... {"field1" : 3, "field2": "row3"}] - True + >>> srdd.collect()[0] + Row(field1=1, field2=u'row1') - >>> from array import array + >>> NestedRow = Row("f1", "f2") + >>> nestedRdd1 = sc.parallelize([ + ... NestedRow(array('i', [1, 2]), {"row1": 1.0}), + ... NestedRow(array('i', [2, 3]), {"row2": 2.0})]) >>> srdd = sqlCtx.inferSchema(nestedRdd1) - >>> srdd.collect() == [{"f1" : [1, 2], "f2" : {"row1" : 1.0}}, - ... {"f1" : [2, 3], "f2" : {"row2" : 2.0}}] - True + >>> srdd.collect() + [Row(f1=[1, 2], f2={u'row1': 1.0}), ..., f2={u'row2': 2.0})] + >>> nestedRdd2 = sc.parallelize([ + ... NestedRow([[1, 2], [2, 3]], [1, 2]), + ... NestedRow([[2, 3], [3, 4]], [2, 3])]) >>> srdd = sqlCtx.inferSchema(nestedRdd2) - >>> srdd.collect() == [{"f1" : [[1, 2], [2, 3]], "f2" : [1, 2]}, - ... {"f1" : [[2, 3], [3, 4]], "f2" : [2, 3]}] - True + >>> srdd.collect() + [Row(f1=[[1, 2], [2, 3]], f2=[1, 2]), ..., f2=[2, 3])] """ - if (rdd.__class__ is SchemaRDD): - raise ValueError("Cannot apply schema to %s" % SchemaRDD.__name__) - elif not isinstance(rdd.first(), dict): - raise ValueError("Only RDDs with dictionaries can be converted to %s: %s" % - (SchemaRDD.__name__, rdd.first())) - jrdd = self._pythonToJavaMap(rdd._jrdd) - srdd = self._ssql_ctx.inferSchema(jrdd.rdd()) - return SchemaRDD(srdd, self) + if isinstance(rdd, SchemaRDD): + raise TypeError("Cannot apply schema to SchemaRDD") + + first = rdd.first() + if not first: + raise ValueError("The first row in RDD is empty, " + "can not infer schema") + if type(first) is dict: + warnings.warn("Using RDD of dict to inferSchema is deprecated") + + schema = _infer_schema(first) + rdd = rdd.mapPartitions(lambda rows: _drop_schema(rows, schema)) + return self.applySchema(rdd, schema) def applySchema(self, rdd, schema): - """Applies the given schema to the given RDD of L{dict}s. + """ + Applies the given schema to the given RDD of L{tuple} or L{list}s. + + These tuples or lists can contain complex nested structures like + lists, maps or nested rows. + + The schema should be a StructType. + It is important that the schema matches the types of the objects + in each row or exceptions could be thrown at runtime. + + >>> rdd2 = sc.parallelize([(1, "row1"), (2, "row2"), (3, "row3")]) >>> schema = StructType([StructField("field1", IntegerType(), False), ... StructField("field2", StringType(), False)]) - >>> srdd = sqlCtx.applySchema(rdd, schema) + >>> srdd = sqlCtx.applySchema(rdd2, schema) >>> sqlCtx.registerRDDAsTable(srdd, "table1") >>> srdd2 = sqlCtx.sql("SELECT * from table1") - >>> srdd2.collect() == [{"field1" : 1, "field2" : "row1"}, {"field1" : 2, "field2": "row2"}, - ... {"field1" : 3, "field2": "row3"}] - True + >>> srdd2.collect() + [Row(field1=1, field2=u'row1'),..., Row(field1=3, field2=u'row3')] + >>> from datetime import datetime - >>> rdd = sc.parallelize([{"byte": 127, "short": -32768, "float": 1.0, - ... "time": datetime(2010, 1, 1, 1, 1, 1), "map": {"a": 1}, "struct": {"b": 2}, - ... "list": [1, 2, 3]}]) + >>> rdd = sc.parallelize([(127, -32768, 1.0, + ... datetime(2010, 1, 1, 1, 1, 1), + ... {"a": 1}, (2,), [1, 2, 3], None)]) >>> schema = StructType([ ... StructField("byte", ByteType(), False), ... StructField("short", ShortType(), False), ... StructField("float", FloatType(), False), ... StructField("time", TimestampType(), False), - ... StructField("map", MapType(StringType(), IntegerType(), False), False), - ... StructField("struct", StructType([StructField("b", ShortType(), False)]), False), + ... StructField("map", + ... MapType(StringType(), IntegerType(), False), False), + ... StructField("struct", + ... StructType([StructField("b", ShortType(), False)]), False), ... StructField("list", ArrayType(ByteType(), False), False), ... StructField("null", DoubleType(), True)]) >>> srdd = sqlCtx.applySchema(rdd, schema).map( - ... lambda x: ( - ... x.byte, x.short, x.float, x.time, x.map["a"], x.struct["b"], x.list, x.null)) + ... lambda x: (x.byte, x.short, x.float, x.time, + ... x.map["a"], x.struct.b, x.list, x.null)) >>> srdd.collect()[0] - (127, -32768, 1.0, datetime.datetime(2010, 1, 1, 1, 1, 1), 1, 2, [1, 2, 3], None) + (127, -32768, 1.0, ...(2010, 1, 1, 1, 1, 1), 1, 2, [1, 2, 3], None) + + >>> rdd = sc.parallelize([(127, -32768, 1.0, + ... datetime(2010, 1, 1, 1, 1, 1), + ... {"a": 1}, (2,), [1, 2, 3])]) + >>> abstract = "byte short float time map{} struct(b) list[]" + >>> schema = _parse_schema_abstract(abstract) + >>> typedSchema = _infer_schema_type(rdd.first(), schema) + >>> srdd = sqlCtx.applySchema(rdd, typedSchema) + >>> srdd.collect() + [Row(byte=127, short=-32768, float=1.0, time=..., list=[1, 2, 3])] """ - jrdd = self._pythonToJavaMap(rdd._jrdd) - srdd = self._ssql_ctx.applySchemaToPythonRDD(jrdd.rdd(), schema.__repr__()) + + if isinstance(rdd, SchemaRDD): + raise TypeError("Cannot apply schema to SchemaRDD") + + if not isinstance(schema, StructType): + raise TypeError("schema should be StructType") + + # take the first few rows to verify schema + rows = rdd.take(10) + for row in rows: + _verify_type(row, schema) + + batched = isinstance(rdd._jrdd_deserializer, BatchedSerializer) + jrdd = self._pythonToJava(rdd._jrdd, batched) + srdd = self._ssql_ctx.applySchemaToPythonRDD(jrdd.rdd(), str(schema)) return SchemaRDD(srdd, self) def registerRDDAsTable(self, rdd, tableName): @@ -620,10 +1083,15 @@ def parquetFile(self, path): return SchemaRDD(jschema_rdd, self) def jsonFile(self, path, schema=None): - """Loads a text file storing one JSON object per line as a L{SchemaRDD}. + """ + Loads a text file storing one JSON object per line as a + L{SchemaRDD}. - If the schema is provided, applies the given schema to this JSON dataset. - Otherwise, it goes through the entire dataset once to determine the schema. + If the schema is provided, applies the given schema to this + JSON dataset. + + Otherwise, it goes through the entire dataset once to determine + the schema. >>> import tempfile, shutil >>> jsonFile = tempfile.mkdtemp() @@ -635,94 +1103,100 @@ def jsonFile(self, path, schema=None): >>> srdd1 = sqlCtx.jsonFile(jsonFile) >>> sqlCtx.registerRDDAsTable(srdd1, "table1") >>> srdd2 = sqlCtx.sql( - ... "SELECT field1 AS f1, field2 as f2, field3 as f3, field6 as f4 from table1") - >>> srdd2.collect() == [ - ... {"f1":1, "f2":"row1", "f3":{"field4":11, "field5": None}, "f4":None}, - ... {"f1":2, "f2":None, "f3":{"field4":22, "field5": [10, 11]}, "f4":[{"field7": "row2"}]}, - ... {"f1":None, "f2":"row3", "f3":{"field4":33, "field5": []}, "f4":None}] - True + ... "SELECT field1 AS f1, field2 as f2, field3 as f3, " + ... "field6 as f4 from table1") + >>> for r in srdd2.collect(): + ... print r + Row(f1=1, f2=u'row1', f3=Row(field4=11, field5=None), f4=None) + Row(f1=2, f2=None, f3=Row(field4=22,..., f4=[Row(field7=u'row2')]) + Row(f1=None, f2=u'row3', f3=Row(field4=33, field5=[]), f4=None) >>> srdd3 = sqlCtx.jsonFile(jsonFile, srdd1.schema()) >>> sqlCtx.registerRDDAsTable(srdd3, "table2") >>> srdd4 = sqlCtx.sql( - ... "SELECT field1 AS f1, field2 as f2, field3 as f3, field6 as f4 from table2") - >>> srdd4.collect() == [ - ... {"f1":1, "f2":"row1", "f3":{"field4":11, "field5": None}, "f4":None}, - ... {"f1":2, "f2":None, "f3":{"field4":22, "field5": [10, 11]}, "f4":[{"field7": "row2"}]}, - ... {"f1":None, "f2":"row3", "f3":{"field4":33, "field5": []}, "f4":None}] - True + ... "SELECT field1 AS f1, field2 as f2, field3 as f3, " + ... "field6 as f4 from table2") + >>> for r in srdd4.collect(): + ... print r + Row(f1=1, f2=u'row1', f3=Row(field4=11, field5=None), f4=None) + Row(f1=2, f2=None, f3=Row(field4=22,..., f4=[Row(field7=u'row2')]) + Row(f1=None, f2=u'row3', f3=Row(field4=33, field5=[]), f4=None) >>> schema = StructType([ ... StructField("field2", StringType(), True), ... StructField("field3", ... StructType([ - ... StructField("field5", ArrayType(IntegerType(), False), True)]), False)]) + ... StructField("field5", + ... ArrayType(IntegerType(), False), True)]), False)]) >>> srdd5 = sqlCtx.jsonFile(jsonFile, schema) >>> sqlCtx.registerRDDAsTable(srdd5, "table3") >>> srdd6 = sqlCtx.sql( - ... "SELECT field2 AS f1, field3.field5 as f2, field3.field5[0] as f3 from table3") - >>> srdd6.collect() == [ - ... {"f1": "row1", "f2": None, "f3": None}, - ... {"f1": None, "f2": [10, 11], "f3": 10}, - ... {"f1": "row3", "f2": [], "f3": None}] - True + ... "SELECT field2 AS f1, field3.field5 as f2, " + ... "field3.field5[0] as f3 from table3") + >>> srdd6.collect() + [Row(f1=u'row1', f2=None, f3=None)...Row(f1=u'row3', f2=[], f3=None)] """ if schema is None: jschema_rdd = self._ssql_ctx.jsonFile(path) else: - scala_datatype = self._ssql_ctx.parseDataType(schema.__repr__()) + scala_datatype = self._ssql_ctx.parseDataType(str(schema)) jschema_rdd = self._ssql_ctx.jsonFile(path, scala_datatype) return SchemaRDD(jschema_rdd, self) def jsonRDD(self, rdd, schema=None): """Loads an RDD storing one JSON object per string as a L{SchemaRDD}. - If the schema is provided, applies the given schema to this JSON dataset. - Otherwise, it goes through the entire dataset once to determine the schema. + If the schema is provided, applies the given schema to this + JSON dataset. + + Otherwise, it goes through the entire dataset once to determine + the schema. >>> srdd1 = sqlCtx.jsonRDD(json) >>> sqlCtx.registerRDDAsTable(srdd1, "table1") >>> srdd2 = sqlCtx.sql( - ... "SELECT field1 AS f1, field2 as f2, field3 as f3, field6 as f4 from table1") - >>> srdd2.collect() == [ - ... {"f1":1, "f2":"row1", "f3":{"field4":11, "field5": None}, "f4":None}, - ... {"f1":2, "f2":None, "f3":{"field4":22, "field5": [10, 11]}, "f4":[{"field7": "row2"}]}, - ... {"f1":None, "f2":"row3", "f3":{"field4":33, "field5": []}, "f4":None}] - True + ... "SELECT field1 AS f1, field2 as f2, field3 as f3, " + ... "field6 as f4 from table1") + >>> for r in srdd2.collect(): + ... print r + Row(f1=1, f2=u'row1', f3=Row(field4=11, field5=None), f4=None) + Row(f1=2, f2=None, f3=Row(field4=22..., f4=[Row(field7=u'row2')]) + Row(f1=None, f2=u'row3', f3=Row(field4=33, field5=[]), f4=None) >>> srdd3 = sqlCtx.jsonRDD(json, srdd1.schema()) >>> sqlCtx.registerRDDAsTable(srdd3, "table2") >>> srdd4 = sqlCtx.sql( - ... "SELECT field1 AS f1, field2 as f2, field3 as f3, field6 as f4 from table2") - >>> srdd4.collect() == [ - ... {"f1":1, "f2":"row1", "f3":{"field4":11, "field5": None}, "f4":None}, - ... {"f1":2, "f2":None, "f3":{"field4":22, "field5": [10, 11]}, "f4":[{"field7": "row2"}]}, - ... {"f1":None, "f2":"row3", "f3":{"field4":33, "field5": []}, "f4":None}] - True + ... "SELECT field1 AS f1, field2 as f2, field3 as f3, " + ... "field6 as f4 from table2") + >>> for r in srdd4.collect(): + ... print r + Row(f1=1, f2=u'row1', f3=Row(field4=11, field5=None), f4=None) + Row(f1=2, f2=None, f3=Row(field4=22..., f4=[Row(field7=u'row2')]) + Row(f1=None, f2=u'row3', f3=Row(field4=33, field5=[]), f4=None) >>> schema = StructType([ ... StructField("field2", StringType(), True), ... StructField("field3", ... StructType([ - ... StructField("field5", ArrayType(IntegerType(), False), True)]), False)]) + ... StructField("field5", + ... ArrayType(IntegerType(), False), True)]), False)]) >>> srdd5 = sqlCtx.jsonRDD(json, schema) >>> sqlCtx.registerRDDAsTable(srdd5, "table3") >>> srdd6 = sqlCtx.sql( - ... "SELECT field2 AS f1, field3.field5 as f2, field3.field5[0] as f3 from table3") - >>> srdd6.collect() == [ - ... {"f1": "row1", "f2": None, "f3": None}, - ... {"f1": None, "f2": [10, 11], "f3": 10}, - ... {"f1": "row3", "f2": [], "f3": None}] - True + ... "SELECT field2 AS f1, field3.field5 as f2, " + ... "field3.field5[0] as f3 from table3") + >>> srdd6.collect() + [Row(f1=u'row1', f2=None,...Row(f1=u'row3', f2=[], f3=None)] """ - def func(split, iterator): + + def func(iterator): for x in iterator: if not isinstance(x, basestring): x = unicode(x) yield x.encode("utf-8") - keyed = PipelinedRDD(rdd, func) + keyed = rdd.mapPartitions(func) keyed._bypass_serializer = True jrdd = keyed._jrdd.map(self._jvm.BytesToString()) if schema is None: jschema_rdd = self._ssql_ctx.jsonRDD(jrdd.rdd()) else: - scala_datatype = self._ssql_ctx.parseDataType(schema.__repr__()) + scala_datatype = self._ssql_ctx.parseDataType(str(schema)) jschema_rdd = self._ssql_ctx.jsonRDD(jrdd.rdd(), scala_datatype) return SchemaRDD(jschema_rdd, self) @@ -732,9 +1206,8 @@ def sql(self, sqlQuery): >>> srdd = sqlCtx.inferSchema(rdd) >>> sqlCtx.registerRDDAsTable(srdd, "table1") >>> srdd2 = sqlCtx.sql("SELECT field1 AS f1, field2 as f2 from table1") - >>> srdd2.collect() == [{"f1" : 1, "f2" : "row1"}, {"f1" : 2, "f2": "row2"}, - ... {"f1" : 3, "f2": "row3"}] - True + >>> srdd2.collect() + [Row(f1=1, f2=u'row1'), Row(f1=2, f2=u'row2'), Row(f1=3, f2=u'row3')] """ return SchemaRDD(self._ssql_ctx.sql(sqlQuery), self) @@ -772,7 +1245,8 @@ def _ssql_ctx(self): self._scala_HiveContext = self._get_hive_ctx() return self._scala_HiveContext except Py4JError as e: - raise Exception("You must build Spark with Hive. Export 'SPARK_HIVE=true' and run " + raise Exception("You must build Spark with Hive. " + "Export 'SPARK_HIVE=true' and run " "sbt/sbt assembly", e) def _get_hive_ctx(self): @@ -780,13 +1254,15 @@ def _get_hive_ctx(self): def hiveql(self, hqlQuery): """ - Runs a query expressed in HiveQL, returning the result as a L{SchemaRDD}. + Runs a query expressed in HiveQL, returning the result as + a L{SchemaRDD}. """ return SchemaRDD(self._ssql_ctx.hiveql(hqlQuery), self) def hql(self, hqlQuery): """ - Runs a query expressed in HiveQL, returning the result as a L{SchemaRDD}. + Runs a query expressed in HiveQL, returning the result as + a L{SchemaRDD}. """ return self.hiveql(hqlQuery) @@ -803,10 +1279,14 @@ class LocalHiveContext(HiveContext): ... supress = hiveCtx.hql("DROP TABLE src") ... except Exception: ... pass - >>> kv1 = os.path.join(os.environ["SPARK_HOME"], 'examples/src/main/resources/kv1.txt') - >>> supress = hiveCtx.hql("CREATE TABLE IF NOT EXISTS src (key INT, value STRING)") - >>> supress = hiveCtx.hql("LOAD DATA LOCAL INPATH '%s' INTO TABLE src" % kv1) - >>> results = hiveCtx.hql("FROM src SELECT value").map(lambda r: int(r.value.split('_')[1])) + >>> kv1 = os.path.join(os.environ["SPARK_HOME"], + ... 'examples/src/main/resources/kv1.txt') + >>> supress = hiveCtx.hql( + ... "CREATE TABLE IF NOT EXISTS src (key INT, value STRING)") + >>> supress = hiveCtx.hql("LOAD DATA LOCAL INPATH '%s' INTO TABLE src" + ... % kv1) + >>> results = hiveCtx.hql("FROM src SELECT value" + ... ).map(lambda r: int(r.value.split('_')[1])) >>> num = results.count() >>> reduce_sum = results.reduce(lambda x, y: x + y) >>> num @@ -816,8 +1296,9 @@ class LocalHiveContext(HiveContext): """ def __init__(self, sparkContext, sqlContext=None): - HiveContext.__init__(self, sparkContext, sqlContext) - warnings.warn("LocalHiveContext is deprecated. Use HiveContext instead.", DeprecationWarning) + HiveContext.__init__(self, sparkContext, sqlContext) + warnings.warn("LocalHiveContext is deprecated. " + "Use HiveContext instead.", DeprecationWarning) def _get_hive_ctx(self): return self._jvm.LocalHiveContext(self._jsc.sc()) @@ -829,25 +1310,83 @@ def _get_hive_ctx(self): return self._jvm.TestHiveContext(self._jsc.sc()) -# TODO: Investigate if it is more efficient to use a namedtuple. One problem is that named tuples -# are custom classes that must be generated per Schema. -class Row(dict): - """A row in L{SchemaRDD}. +def _create_row(fields, values): + row = Row(*values) + row.__FIELDS__ = fields + return row + + +class Row(tuple): + """ + A row in L{SchemaRDD}. The fields in it can be accessed like attributes. + + Row can be used to create a row object by using named arguments, + the fields will be sorted by names. + + >>> row = Row(name="Alice", age=11) + >>> row + Row(age=11, name='Alice') + >>> row.name, row.age + ('Alice', 11) - An extended L{dict} that takes a L{dict} in its constructor, and - exposes those items as fields. + Row also can be used to create another Row like class, then it + could be used to create Row objects, such as - >>> r = Row({"hello" : "world", "foo" : "bar"}) - >>> r.hello - 'world' - >>> r.foo - 'bar' + >>> Person = Row("name", "age") + >>> Person + + >>> Person("Alice", 11) + Row(name='Alice', age=11) """ - def __init__(self, d): - d.update(self.__dict__) - self.__dict__ = d - dict.__init__(self, d) + def __new__(self, *args, **kwargs): + if args and kwargs: + raise ValueError("Can not use both args " + "and kwargs to create Row") + if args: + # create row class or objects + return tuple.__new__(self, args) + + elif kwargs: + # create row objects + names = sorted(kwargs.keys()) + values = tuple(kwargs[n] for n in names) + row = tuple.__new__(self, values) + row.__FIELDS__ = names + return row + + else: + raise ValueError("No args or kwargs") + + + # let obect acs like class + def __call__(self, *args): + """create new Row object""" + return _create_row(self, args) + + def __getattr__(self, item): + if item.startswith("__"): + raise AttributeError(item) + try: + # it will be slow when it has many fields, + # but this will not be used in normal cases + idx = self.__FIELDS__.index(item) + return self[idx] + except IndexError: + raise AttributeError(item) + + def __reduce__(self): + if hasattr(self, "__FIELDS__"): + return (_create_row, (self.__FIELDS__, tuple(self))) + else: + return tuple.__reduce__(self) + + def __repr__(self): + if hasattr(self, "__FIELDS__"): + return "Row(%s)" % ", ".join("%s=%r" % (k, v) + for k, v in zip(self.__FIELDS__, self)) + else: + return "" % ", ".join(self) class SchemaRDD(RDD): @@ -861,6 +1400,10 @@ class SchemaRDD(RDD): implementation is an RDD composed of Java objects. Instead it is converted to a PythonRDD in the JVM, on which Python operations can be done. + + This class receives raw tuples from Java but assigns a class to it in + all its data-collection methods (mapPartitionsWithIndex, collect, take, + etc) so that PySpark sees them as Row objects with named fields. """ def __init__(self, jschema_rdd, sql_ctx): @@ -871,7 +1414,8 @@ def __init__(self, jschema_rdd, sql_ctx): self.is_cached = False self.is_checkpointed = False self.ctx = self.sql_ctx._sc - self._jrdd_deserializer = self.ctx.serializer + # the _jrdd is created by javaToPython(), serialized by pickle + self._jrdd_deserializer = BatchedSerializer(PickleSerializer()) @property def _jrdd(self): @@ -881,7 +1425,7 @@ def _jrdd(self): L{pyspark.rdd.RDD} super class (map, filter, etc.). """ if not hasattr(self, '_lazy_jrdd'): - self._lazy_jrdd = self._toPython()._jrdd + self._lazy_jrdd = self._jschema_rdd.javaToPython() return self._lazy_jrdd @property @@ -931,7 +1475,8 @@ def saveAsTable(self, tableName): self._jschema_rdd.saveAsTable(tableName) def schema(self): - """Returns the schema of this SchemaRDD (represented by a L{StructType}).""" + """Returns the schema of this SchemaRDD (represented by + a L{StructType}).""" return _parse_datatype_string(self._jschema_rdd.schema().toString()) def schemaString(self): @@ -957,19 +1502,45 @@ def count(self): """ return self._jschema_rdd.count() - def _toPython(self): - # We have to import the Row class explicitly, so that the reference Pickler has is - # pyspark.sql.Row instead of __main__.Row - from pyspark.sql import Row - jrdd = self._jschema_rdd.javaToPython() - # TODO: This is inefficient, we should construct the Python Row object - # in Java land in the javaToPython function. May require a custom - # pickle serializer in Pyrolite - return RDD(jrdd, self._sc, BatchedSerializer( - PickleSerializer())).map(lambda d: Row(d)) - - # We override the default cache/persist/checkpoint behavior as we want to cache the underlying - # SchemaRDD object in the JVM, not the PythonRDD checkpointed by the super class + def collect(self): + """ + Return a list that contains all of the rows in this RDD. + + Each object in the list is on Row, the fields can be accessed as + attributes. + """ + rows = RDD.collect(self) + cls = _create_cls(self.schema()) + return map(cls, rows) + + # Convert each object in the RDD to a Row with the right class + # for this SchemaRDD, so that fields can be accessed as attributes. + def mapPartitionsWithIndex(self, f, preservesPartitioning=False): + """ + Return a new RDD by applying a function to each partition of this RDD, + while tracking the index of the original partition. + + >>> rdd = sc.parallelize([1, 2, 3, 4], 4) + >>> def f(splitIndex, iterator): yield splitIndex + >>> rdd.mapPartitionsWithIndex(f).sum() + 6 + """ + rdd = RDD(self._jrdd, self._sc, self._jrdd_deserializer) + + schema = self.schema() + import pickle + pickle.loads(pickle.dumps(schema)) + + def applySchema(_, it): + cls = _create_cls(schema) + return itertools.imap(cls, it) + + objrdd = rdd.mapPartitionsWithIndex(applySchema, preservesPartitioning) + return objrdd.mapPartitionsWithIndex(f, preservesPartitioning) + + # We override the default cache/persist/checkpoint behavior + # as we want to cache the underlying SchemaRDD object in the JVM, + # not the PythonRDD checkpointed by the super class def cache(self): self.is_cached = True self._jschema_rdd.cache() @@ -1024,7 +1595,8 @@ def subtract(self, other, numPartitions=None): if numPartitions is None: rdd = self._jschema_rdd.subtract(other._jschema_rdd) else: - rdd = self._jschema_rdd.subtract(other._jschema_rdd, numPartitions) + rdd = self._jschema_rdd.subtract(other._jschema_rdd, + numPartitions) return SchemaRDD(rdd, self.sql_ctx) else: raise ValueError("Can only subtract another SchemaRDD") @@ -1034,31 +1606,31 @@ def _test(): import doctest from array import array from pyspark.context import SparkContext - globs = globals().copy() + # let doctest run in pyspark.sql, so DataTypes can be picklable + import pyspark.sql + from pyspark.sql import Row, SQLContext + globs = pyspark.sql.__dict__.copy() # The small batch size here ensures that we see multiple batches, # even in these small test examples: sc = SparkContext('local[4]', 'PythonTest', batchSize=2) globs['sc'] = sc globs['sqlCtx'] = SQLContext(sc) globs['rdd'] = sc.parallelize( - [{"field1": 1, "field2": "row1"}, - {"field1": 2, "field2": "row2"}, - {"field1": 3, "field2": "row3"}] + [Row(field1=1, field2="row1"), + Row(field1=2, field2="row2"), + Row(field1=3, field2="row3")] ) jsonStrings = [ '{"field1": 1, "field2": "row1", "field3":{"field4":11}}', - '{"field1" : 2, "field3":{"field4":22, "field5": [10, 11]}, "field6":[{"field7": "row2"}]}', - '{"field1" : null, "field2": "row3", "field3":{"field4":33, "field5": []}}' + '{"field1" : 2, "field3":{"field4":22, "field5": [10, 11]},' + '"field6":[{"field7": "row2"}]}', + '{"field1" : null, "field2": "row3", ' + '"field3":{"field4":33, "field5": []}}' ] globs['jsonStrings'] = jsonStrings globs['json'] = sc.parallelize(jsonStrings) - globs['nestedRdd1'] = sc.parallelize([ - {"f1": array('i', [1, 2]), "f2": {"row1": 1.0}}, - {"f1": array('i', [2, 3]), "f2": {"row2": 2.0}}]) - globs['nestedRdd2'] = sc.parallelize([ - {"f1": [[1, 2], [2, 3]], "f2": [1, 2]}, - {"f1": [[2, 3], [3, 4]], "f2": [2, 3]}]) - (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) + (failure_count, test_count) = doctest.testmod( + pyspark.sql, globs=globs, optionflags=doctest.ELLIPSIS) globs['sc'].stop() if failure_count: exit(-1) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 86338752a21c1..dad71079c29b9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -411,35 +411,6 @@ class SQLContext(@transient val sparkContext: SparkContext) """.stripMargin.trim } - /** - * Peek at the first row of the RDD and infer its schema. - * It is only used by PySpark. - */ - private[sql] def inferSchema(rdd: RDD[Map[String, _]]): SchemaRDD = { - import scala.collection.JavaConversions._ - - def typeOfComplexValue: PartialFunction[Any, DataType] = { - case c: java.util.Calendar => TimestampType - case c: java.util.List[_] => - ArrayType(typeOfObject(c.head)) - case c: java.util.Map[_, _] => - val (key, value) = c.head - MapType(typeOfObject(key), typeOfObject(value)) - case c if c.getClass.isArray => - val elem = c.asInstanceOf[Array[_]].head - ArrayType(typeOfObject(elem)) - case c => throw new Exception(s"Object of type $c cannot be used") - } - def typeOfObject = ScalaReflection.typeOfObject orElse typeOfComplexValue - - val firstRow = rdd.first() - val fields = firstRow.map { - case (fieldName, obj) => StructField(fieldName, typeOfObject(obj), true) - }.toSeq - - applySchemaToPythonRDD(rdd, StructType(fields)) - } - /** * Parses the data type in our internal string representation. The data type string should * have the same format as the one generated by `toString` in scala. @@ -454,7 +425,7 @@ class SQLContext(@transient val sparkContext: SparkContext) * Apply a schema defined by the schemaString to an RDD. It is only used by PySpark. */ private[sql] def applySchemaToPythonRDD( - rdd: RDD[Map[String, _]], + rdd: RDD[Array[Any]], schemaString: String): SchemaRDD = { val schema = parseDataType(schemaString).asInstanceOf[StructType] applySchemaToPythonRDD(rdd, schema) @@ -464,10 +435,8 @@ class SQLContext(@transient val sparkContext: SparkContext) * Apply a schema defined by the schema to an RDD. It is only used by PySpark. */ private[sql] def applySchemaToPythonRDD( - rdd: RDD[Map[String, _]], + rdd: RDD[Array[Any]], schema: StructType): SchemaRDD = { - // TODO: We should have a better implementation once we do not turn a Python side record - // to a Map. import scala.collection.JavaConversions._ import scala.collection.convert.Wrappers.{JListWrapper, JMapWrapper} @@ -494,55 +463,39 @@ class SQLContext(@transient val sparkContext: SparkContext) val converted = c.map { e => convert(e, elementType)} JListWrapper(converted) - case (c: java.util.Map[_, _], struct: StructType) => - val row = new GenericMutableRow(struct.fields.length) - struct.fields.zipWithIndex.foreach { - case (field, i) => - val value = convert(c.get(field.name), field.dataType) - row.update(i, value) - } - row - - case (c: java.util.Map[_, _], MapType(keyType, valueType, _)) => - val converted = c.map { - case (key, value) => - (convert(key, keyType), convert(value, valueType)) - } - JMapWrapper(converted) - case (c, ArrayType(elementType, _)) if c.getClass.isArray => - val converted = c.asInstanceOf[Array[_]].map(e => convert(e, elementType)) - converted: Seq[Any] + c.asInstanceOf[Array[_]].map(e => convert(e, elementType)): Seq[Any] + + case (c: java.util.Map[_, _], MapType(keyType, valueType, _)) => c.map { + case (key, value) => (convert(key, keyType), convert(value, valueType)) + }.toMap + + case (c, StructType(fields)) if c.getClass.isArray => + new GenericRow(c.asInstanceOf[Array[_]].zip(fields).map { + case (e, f) => convert(e, f.dataType) + }): Row + + case (c: java.util.Calendar, TimestampType) => + new java.sql.Timestamp(c.getTime().getTime()) - case (c: java.util.Calendar, TimestampType) => new java.sql.Timestamp(c.getTime().getTime()) case (c: Int, ByteType) => c.toByte case (c: Int, ShortType) => c.toShort case (c: Double, FloatType) => c.toFloat + case (c, StringType) if !c.isInstanceOf[String] => c.toString case (c, _) => c } val convertedRdd = if (schema.fields.exists(f => needsConversion(f.dataType))) { - rdd.map(m => m.map { case (key, value) => (key, convert(value, schema(key).dataType)) }) + rdd.map(m => m.zip(schema.fields).map { + case (value, field) => convert(value, field.dataType) + }) } else { rdd } val rowRdd = convertedRdd.mapPartitions { iter => - val row = new GenericMutableRow(schema.fields.length) - val fieldsWithIndex = schema.fields.zipWithIndex - iter.map { m => - // We cannot use m.values because the order of values returned by m.values may not - // match fields order. - fieldsWithIndex.foreach { - case (field, i) => - val value = - m.get(field.name).flatMap(v => Option(v)).map(v => convert(v, field.dataType)).orNull - row.update(i, value) - } - - row: Row - } + iter.map { m => new GenericRow(m): Row} } new SchemaRDD(this, SparkLogicalPlan(ExistingRdd(schema.toAttributes, rowRdd))(self)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala index 420f21fb9c1ae..d34f62dc8865e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala @@ -383,7 +383,7 @@ class SchemaRDD( import scala.collection.Map def toJava(obj: Any, dataType: DataType): Any = dataType match { - case struct: StructType => rowToMap(obj.asInstanceOf[Row], struct) + case struct: StructType => rowToArray(obj.asInstanceOf[Row], struct) case array: ArrayType => obj match { case seq: Seq[Any] => seq.map(x => toJava(x, array.elementType)).asJava case list: JList[_] => list.map(x => toJava(x, array.elementType)).asJava @@ -397,21 +397,19 @@ class SchemaRDD( // Pyrolite can handle Timestamp case other => obj } - def rowToMap(row: Row, structType: StructType): JMap[String, Any] = { - val fields = structType.fields.map(field => (field.name, field.dataType)) - val map: JMap[String, Any] = new java.util.HashMap - row.zip(fields).foreach { - case (obj, (attrName, dataType)) => map.put(attrName, toJava(obj, dataType)) - } - map + def rowToArray(row: Row, structType: StructType): Array[Any] = { + val fields = structType.fields.map(field => field.dataType) + row.zip(fields).map { + case (obj, dataType) => toJava(obj, dataType) + }.toArray } val rowSchema = StructType.fromAttributes(this.queryExecution.analyzed.output) this.mapPartitions { iter => val pickle = new Pickler iter.map { row => - rowToMap(row, rowSchema) - }.grouped(10).map(batched => pickle.dumps(batched.toArray)) + rowToArray(row, rowSchema) + }.grouped(100).map(batched => pickle.dumps(batched.toArray)) } } From 3822f33f3ce1428703a4796d7a119b40a6b32259 Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Fri, 1 Aug 2014 18:52:01 -0700 Subject: [PATCH 099/170] [SPARK-2212][SQL] Hash Outer Join (follow-up bug fix). We need to carefully set the ouputPartitioning of the HashOuterJoin Operator. Otherwise, we may not correctly handle nulls. Author: Yin Huai Closes #1721 from yhuai/SPARK-2212-BugFix and squashes the following commits: ed5eef7 [Yin Huai] Correctly choosing outputPartitioning for the HashOuterJoin operator. --- .../apache/spark/sql/execution/joins.scala | 9 +- .../org/apache/spark/sql/JoinSuite.scala | 99 +++++++++++++++++++ .../scala/org/apache/spark/sql/TestData.scala | 8 ++ 3 files changed, 114 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala index 82f0a74b630bf..cc138c749949d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala @@ -158,7 +158,12 @@ case class HashOuterJoin( left: SparkPlan, right: SparkPlan) extends BinaryNode { - override def outputPartitioning: Partitioning = left.outputPartitioning + override def outputPartitioning: Partitioning = joinType match { + case LeftOuter => left.outputPartitioning + case RightOuter => right.outputPartitioning + case FullOuter => UnknownPartitioning(left.outputPartitioning.numPartitions) + case x => throw new Exception(s"HashOuterJoin should not take $x as the JoinType") + } override def requiredChildDistribution = ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil @@ -309,7 +314,7 @@ case class HashOuterJoin( leftHashTable.getOrElse(key, HashOuterJoin.EMPTY_LIST), rightHashTable.getOrElse(key, HashOuterJoin.EMPTY_LIST)) } - case x => throw new Exception(s"Need to add implementation for $x") + case x => throw new Exception(s"HashOuterJoin should not take $x as the JoinType") } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index 037890682f7b1..2fc80588182d9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -197,6 +197,31 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { (4, "D", 4, "d") :: (5, "E", null, null) :: (6, "F", null, null) :: Nil) + + // Make sure we are choosing left.outputPartitioning as the + // outputPartitioning for the outer join operator. + checkAnswer( + sql( + """ + |SELECT l.N, count(*) + |FROM upperCaseData l LEFT OUTER JOIN allNulls r ON (l.N = r.a) + |GROUP BY l.N + """.stripMargin), + (1, 1) :: + (2, 1) :: + (3, 1) :: + (4, 1) :: + (5, 1) :: + (6, 1) :: Nil) + + checkAnswer( + sql( + """ + |SELECT r.a, count(*) + |FROM upperCaseData l LEFT OUTER JOIN allNulls r ON (l.N = r.a) + |GROUP BY r.a + """.stripMargin), + (null, 6) :: Nil) } test("right outer join") { @@ -232,6 +257,31 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { (4, "d", 4, "D") :: (null, null, 5, "E") :: (null, null, 6, "F") :: Nil) + + // Make sure we are choosing right.outputPartitioning as the + // outputPartitioning for the outer join operator. + checkAnswer( + sql( + """ + |SELECT l.a, count(*) + |FROM allNulls l RIGHT OUTER JOIN upperCaseData r ON (l.a = r.N) + |GROUP BY l.a + """.stripMargin), + (null, 6) :: Nil) + + checkAnswer( + sql( + """ + |SELECT r.N, count(*) + |FROM allNulls l RIGHT OUTER JOIN upperCaseData r ON (l.a = r.N) + |GROUP BY r.N + """.stripMargin), + (1, 1) :: + (2, 1) :: + (3, 1) :: + (4, 1) :: + (5, 1) :: + (6, 1) :: Nil) } test("full outer join") { @@ -269,5 +319,54 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { (4, "D", 4, "D") :: (null, null, 5, "E") :: (null, null, 6, "F") :: Nil) + + // Make sure we are UnknownPartitioning as the outputPartitioning for the outer join operator. + checkAnswer( + sql( + """ + |SELECT l.a, count(*) + |FROM allNulls l FULL OUTER JOIN upperCaseData r ON (l.a = r.N) + |GROUP BY l.a + """.stripMargin), + (null, 10) :: Nil) + + checkAnswer( + sql( + """ + |SELECT r.N, count(*) + |FROM allNulls l FULL OUTER JOIN upperCaseData r ON (l.a = r.N) + |GROUP BY r.N + """.stripMargin), + (1, 1) :: + (2, 1) :: + (3, 1) :: + (4, 1) :: + (5, 1) :: + (6, 1) :: + (null, 4) :: Nil) + + checkAnswer( + sql( + """ + |SELECT l.N, count(*) + |FROM upperCaseData l FULL OUTER JOIN allNulls r ON (l.N = r.a) + |GROUP BY l.N + """.stripMargin), + (1, 1) :: + (2, 1) :: + (3, 1) :: + (4, 1) :: + (5, 1) :: + (6, 1) :: + (null, 4) :: Nil) + + checkAnswer( + sql( + """ + |SELECT r.a, count(*) + |FROM upperCaseData l FULL OUTER JOIN allNulls r ON (l.N = r.a) + |GROUP BY r.a + """.stripMargin), + (null, 10) :: Nil) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala index 213190e812026..58cee21e8ad4c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala @@ -118,6 +118,14 @@ object TestData { ) nullInts.registerAsTable("nullInts") + val allNulls = + TestSQLContext.sparkContext.parallelize( + NullInts(null) :: + NullInts(null) :: + NullInts(null) :: + NullInts(null) :: Nil) + allNulls.registerAsTable("allNulls") + case class NullStrings(n: Int, s: String) val nullStrings = TestSQLContext.sparkContext.parallelize( From 0da07da53e5466ec44c8050020cbc4b9957cb949 Mon Sep 17 00:00:00 2001 From: Albert Chu Date: Fri, 1 Aug 2014 19:00:38 -0700 Subject: [PATCH 100/170] [SPARK-2116] Load spark-defaults.conf from SPARK_CONF_DIR if set If SPARK_CONF_DIR environment variable is set, search it for spark-defaults.conf. Author: Albert Chu Closes #1059 from chu11/SPARK-2116 and squashes the following commits: 9f3ac94 [Albert Chu] SPARK-2116: If SPARK_CONF_DIR environment variable is set, search it for spark-defaults.conf. --- .../apache/spark/deploy/SparkSubmitArguments.scala | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala index dd044e6298760..9391f24e71ed7 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala @@ -85,6 +85,17 @@ private[spark] class SparkSubmitArguments(args: Seq[String]) { */ private def mergeSparkProperties(): Unit = { // Use common defaults file, if not specified by user + if (propertiesFile == null) { + sys.env.get("SPARK_CONF_DIR").foreach { sparkConfDir => + val sep = File.separator + val defaultPath = s"${sparkConfDir}${sep}spark-defaults.conf" + val file = new File(defaultPath) + if (file.exists()) { + propertiesFile = file.getAbsolutePath + } + } + } + if (propertiesFile == null) { sys.env.get("SPARK_HOME").foreach { sparkHome => val sep = File.separator From a38d3c9efcc0386b52ac4f041920985ae7300e28 Mon Sep 17 00:00:00 2001 From: GuoQiang Li Date: Fri, 1 Aug 2014 19:35:16 -0700 Subject: [PATCH 101/170] [SPARK-2800]: Exclude scalastyle-output.xml Apache RAT checks Author: GuoQiang Li Closes #1729 from witgo/SPARK-2800 and squashes the following commits: 13ca966 [GuoQiang Li] Add scalastyle-output.xml to .rat-excludes file --- .rat-excludes | 1 + 1 file changed, 1 insertion(+) diff --git a/.rat-excludes b/.rat-excludes index 372bc2587ccc3..bccb043c2bb55 100644 --- a/.rat-excludes +++ b/.rat-excludes @@ -55,3 +55,4 @@ dist/* .*ipr .*iws logs +.*scalastyle-output.xml From e8e0fd691a06a2887fdcffb2217b96805ace0cb0 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Fri, 1 Aug 2014 19:38:21 -0700 Subject: [PATCH 102/170] [SPARK-2764] Simplify daemon.py process structure MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Curently, daemon.py forks a pool of numProcessors subprocesses, and those processes fork themselves again to create the actual Python worker processes that handle data. I think that this extra layer of indirection is unnecessary and adds a lot of complexity. This commit attempts to remove this middle layer of subprocesses by launching the workers directly from daemon.py. See https://github.com/mesos/spark/pull/563 for the original PR that added daemon.py, where I raise some issues with the current design. Author: Josh Rosen Closes #1680 from JoshRosen/pyspark-daemon and squashes the following commits: 5abbcb9 [Josh Rosen] Replace magic number: 4 -> EINTR 5495dff [Josh Rosen] Throw IllegalStateException if worker launch fails. b79254d [Josh Rosen] Detect failed fork() calls; improve error logging. 282c2c4 [Josh Rosen] Remove daemon.py exit logging, since it caused problems: 8554536 [Josh Rosen] Fix daemon’s shutdown(); log shutdown reason. 4e0fab8 [Josh Rosen] Remove shared-memory exit_flag; don't die on worker death. e9892b4 [Josh Rosen] [WIP] [SPARK-2764] Simplify daemon.py process structure. --- .../api/python/PythonWorkerFactory.scala | 10 +- python/pyspark/daemon.py | 179 +++++++----------- 2 files changed, 79 insertions(+), 110 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala index 759cbe2c46c52..15fe8a9be6bfe 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala @@ -64,10 +64,16 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String // Attempt to connect, restart and retry once if it fails try { - new Socket(daemonHost, daemonPort) + val socket = new Socket(daemonHost, daemonPort) + val launchStatus = new DataInputStream(socket.getInputStream).readInt() + if (launchStatus != 0) { + throw new IllegalStateException("Python daemon failed to launch worker") + } + socket } catch { case exc: SocketException => - logWarning("Python daemon unexpectedly quit, attempting to restart") + logWarning("Failed to open socket to Python daemon:", exc) + logWarning("Assuming that daemon unexpectedly quit, attempting to restart") stopDaemon() startDaemon() new Socket(daemonHost, daemonPort) diff --git a/python/pyspark/daemon.py b/python/pyspark/daemon.py index 8a5873ded2b8b..9fde0dde0f4b4 100644 --- a/python/pyspark/daemon.py +++ b/python/pyspark/daemon.py @@ -15,64 +15,39 @@ # limitations under the License. # +import numbers import os import signal +import select import socket import sys import traceback -import multiprocessing -from ctypes import c_bool from errno import EINTR, ECHILD from socket import AF_INET, SOCK_STREAM, SOMAXCONN from signal import SIGHUP, SIGTERM, SIGCHLD, SIG_DFL, SIG_IGN from pyspark.worker import main as worker_main from pyspark.serializers import write_int -try: - POOLSIZE = multiprocessing.cpu_count() -except NotImplementedError: - POOLSIZE = 4 - -exit_flag = multiprocessing.Value(c_bool, False) - - -def should_exit(): - global exit_flag - return exit_flag.value - def compute_real_exit_code(exit_code): # SystemExit's code can be integer or string, but os._exit only accepts integers - import numbers if isinstance(exit_code, numbers.Integral): return exit_code else: return 1 -def worker(listen_sock): +def worker(sock): + """ + Called by a worker process after the fork(). + """ # Redirect stdout to stderr os.dup2(2, 1) sys.stdout = sys.stderr # The sys.stdout object is different from file descriptor 1 - # Manager sends SIGHUP to request termination of workers in the pool - def handle_sighup(*args): - assert should_exit() - signal.signal(SIGHUP, handle_sighup) - - # Cleanup zombie children - def handle_sigchld(*args): - pid = status = None - try: - while (pid, status) != (0, 0): - pid, status = os.waitpid(0, os.WNOHANG) - except EnvironmentError as err: - if err.errno == EINTR: - # retry - handle_sigchld() - elif err.errno != ECHILD: - raise - signal.signal(SIGCHLD, handle_sigchld) + signal.signal(SIGHUP, SIG_DFL) + signal.signal(SIGCHLD, SIG_DFL) + signal.signal(SIGTERM, SIG_DFL) # Blocks until the socket is closed by draining the input stream # until it raises an exception or returns EOF. @@ -85,55 +60,23 @@ def waitSocketClose(sock): except: pass - # Handle clients - while not should_exit(): - # Wait until a client arrives or we have to exit - sock = None - while not should_exit() and sock is None: - try: - sock, addr = listen_sock.accept() - except EnvironmentError as err: - if err.errno != EINTR: - raise - - if sock is not None: - # Fork a child to handle the client. - # The client is handled in the child so that the manager - # never receives SIGCHLD unless a worker crashes. - if os.fork() == 0: - # Leave the worker pool - signal.signal(SIGHUP, SIG_DFL) - signal.signal(SIGCHLD, SIG_DFL) - listen_sock.close() - # Read the socket using fdopen instead of socket.makefile() because the latter - # seems to be very slow; note that we need to dup() the file descriptor because - # otherwise writes also cause a seek that makes us miss data on the read side. - infile = os.fdopen(os.dup(sock.fileno()), "a+", 65536) - outfile = os.fdopen(os.dup(sock.fileno()), "a+", 65536) - exit_code = 0 - try: - worker_main(infile, outfile) - except SystemExit as exc: - exit_code = exc.code - finally: - outfile.flush() - # The Scala side will close the socket upon task completion. - waitSocketClose(sock) - os._exit(compute_real_exit_code(exit_code)) - else: - sock.close() - - -def launch_worker(listen_sock): - if os.fork() == 0: - try: - worker(listen_sock) - except Exception as err: - traceback.print_exc() - os._exit(1) - else: - assert should_exit() - os._exit(0) + # Read the socket using fdopen instead of socket.makefile() because the latter + # seems to be very slow; note that we need to dup() the file descriptor because + # otherwise writes also cause a seek that makes us miss data on the read side. + infile = os.fdopen(os.dup(sock.fileno()), "a+", 65536) + outfile = os.fdopen(os.dup(sock.fileno()), "a+", 65536) + exit_code = 0 + try: + write_int(0, outfile) # Acknowledge that the fork was successful + outfile.flush() + worker_main(infile, outfile) + except SystemExit as exc: + exit_code = exc.code + finally: + outfile.flush() + # The Scala side will close the socket upon task completion. + waitSocketClose(sock) + os._exit(compute_real_exit_code(exit_code)) def manager(): @@ -143,29 +86,28 @@ def manager(): # Create a listening socket on the AF_INET loopback interface listen_sock = socket.socket(AF_INET, SOCK_STREAM) listen_sock.bind(('127.0.0.1', 0)) - listen_sock.listen(max(1024, 2 * POOLSIZE, SOMAXCONN)) + listen_sock.listen(max(1024, SOMAXCONN)) listen_host, listen_port = listen_sock.getsockname() write_int(listen_port, sys.stdout) - # Launch initial worker pool - for idx in range(POOLSIZE): - launch_worker(listen_sock) - listen_sock.close() - - def shutdown(): - global exit_flag - exit_flag.value = True + def shutdown(code): + signal.signal(SIGTERM, SIG_DFL) + # Send SIGHUP to notify workers of shutdown + os.kill(0, SIGHUP) + exit(code) - # Gracefully exit on SIGTERM, don't die on SIGHUP - signal.signal(SIGTERM, lambda signum, frame: shutdown()) - signal.signal(SIGHUP, SIG_IGN) + def handle_sigterm(*args): + shutdown(1) + signal.signal(SIGTERM, handle_sigterm) # Gracefully exit on SIGTERM + signal.signal(SIGHUP, SIG_IGN) # Don't die on SIGHUP # Cleanup zombie children def handle_sigchld(*args): try: pid, status = os.waitpid(0, os.WNOHANG) - if status != 0 and not should_exit(): - raise RuntimeError("worker crashed: %s, %s" % (pid, status)) + if status != 0: + msg = "worker %s crashed abruptly with exit status %s" % (pid, status) + print >> sys.stderr, msg except EnvironmentError as err: if err.errno not in (ECHILD, EINTR): raise @@ -174,20 +116,41 @@ def handle_sigchld(*args): # Initialization complete sys.stdout.close() try: - while not should_exit(): + while True: try: - # Spark tells us to exit by closing stdin - if os.read(0, 512) == '': - shutdown() - except EnvironmentError as err: - if err.errno != EINTR: - shutdown() + ready_fds = select.select([0, listen_sock], [], [])[0] + except select.error as ex: + if ex[0] == EINTR: + continue + else: raise + if 0 in ready_fds: + # Spark told us to exit by closing stdin + shutdown(0) + if listen_sock in ready_fds: + sock, addr = listen_sock.accept() + # Launch a worker process + try: + fork_return_code = os.fork() + if fork_return_code == 0: + listen_sock.close() + try: + worker(sock) + except: + traceback.print_exc() + os._exit(1) + else: + os._exit(0) + else: + sock.close() + except OSError as e: + print >> sys.stderr, "Daemon failed to fork PySpark worker: %s" % e + outfile = os.fdopen(os.dup(sock.fileno()), "a+", 65536) + write_int(-1, outfile) # Signal that the fork failed + outfile.flush() + sock.close() finally: - signal.signal(SIGTERM, SIG_DFL) - exit_flag.value = True - # Send SIGHUP to notify workers of shutdown - os.kill(0, SIGHUP) + shutdown(1) if __name__ == '__main__': From f6a1899306c5ad766fea122d3ab4b83436d9f6fd Mon Sep 17 00:00:00 2001 From: Jeremy Freeman Date: Fri, 1 Aug 2014 20:10:26 -0700 Subject: [PATCH 103/170] Streaming mllib [SPARK-2438][MLLIB] This PR implements a streaming linear regression analysis, in which a linear regression model is trained online as new data arrive. The design is based on discussions with tdas and mengxr, in which we determined how to add this functionality in a general way, with minimal changes to existing libraries. __Summary of additions:__ _StreamingLinearAlgorithm_ - An abstract class for fitting generalized linear models online to streaming data, including training on (and updating) a model, and making predictions. _StreamingLinearRegressionWithSGD_ - Class and companion object for running streaming linear regression _StreamingLinearRegressionTestSuite_ - Unit tests _StreamingLinearRegression_ - Example use case: fitting a model online to data from one stream, and making predictions on other data __Notes__ - If this looks good, I can use the StreamingLinearAlgorithm class to easily implement other analyses that follow the same logic (Ridge, Lasso, Logistic, SVM). Author: Jeremy Freeman Author: freeman Closes #1361 from freeman-lab/streaming-mllib and squashes the following commits: 775ea29 [Jeremy Freeman] Throw error if user doesn't initialize weights 4086fee [Jeremy Freeman] Fixed current weight formatting 8b95b27 [Jeremy Freeman] Restored broadcasting 29f27ec [Jeremy Freeman] Formatting 8711c41 [Jeremy Freeman] Used return to avoid indentation 777b596 [Jeremy Freeman] Restored treeAggregate 74cf440 [Jeremy Freeman] Removed static methods d28cf9a [Jeremy Freeman] Added usage notes c3326e7 [Jeremy Freeman] Improved documentation 9541a41 [Jeremy Freeman] Merge remote-tracking branch 'upstream/master' into streaming-mllib 66eba5e [Jeremy Freeman] Fixed line lengths 2fe0720 [Jeremy Freeman] Minor cleanup 7d51378 [Jeremy Freeman] Moved streaming loader to MLUtils b9b69f6 [Jeremy Freeman] Added setter methods c3f8b5a [Jeremy Freeman] Modified logging 00aafdc [Jeremy Freeman] Add modifiers 14b801e [Jeremy Freeman] Name changes c7d38a3 [Jeremy Freeman] Move check for empty data to GradientDescent 4b0a5d3 [Jeremy Freeman] Cleaned up tests 74188d6 [Jeremy Freeman] Eliminate dependency on commons 50dd237 [Jeremy Freeman] Removed experimental tag 6bfe1e6 [Jeremy Freeman] Fixed imports a2a63ad [freeman] Makes convergence test more robust 86220bc [freeman] Streaming linear regression unit tests fb4683a [freeman] Minor changes for scalastyle consistency fd31e03 [freeman] Changed logging behavior 453974e [freeman] Fixed indentation c4b1143 [freeman] Streaming linear regression 604f4d7 [freeman] Expanded private class to include mllib d99aa85 [freeman] Helper methods for streaming MLlib apps 0898add [freeman] Added dependency on streaming --- .../mllib/StreamingLinearRegression.scala | 73 ++++++++++ mllib/pom.xml | 5 + .../mllib/optimization/GradientDescent.scala | 9 ++ .../mllib/regression/LinearRegression.scala | 4 +- .../regression/StreamingLinearAlgorithm.scala | 106 ++++++++++++++ .../StreamingLinearRegressionWithSGD.scala | 88 ++++++++++++ .../org/apache/spark/mllib/util/MLUtils.scala | 15 ++ .../StreamingLinearRegressionSuite.scala | 135 ++++++++++++++++++ 8 files changed, 433 insertions(+), 2 deletions(-) create mode 100644 examples/src/main/scala/org/apache/spark/examples/mllib/StreamingLinearRegression.scala create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionWithSGD.scala create mode 100644 mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingLinearRegression.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingLinearRegression.scala new file mode 100644 index 0000000000000..1fd37edfa7427 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingLinearRegression.scala @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.mllib + +import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.util.MLUtils +import org.apache.spark.mllib.regression.StreamingLinearRegressionWithSGD +import org.apache.spark.SparkConf +import org.apache.spark.streaming.{Seconds, StreamingContext} + +/** + * Train a linear regression model on one stream of data and make predictions + * on another stream, where the data streams arrive as text files + * into two different directories. + * + * The rows of the text files must be labeled data points in the form + * `(y,[x1,x2,x3,...,xn])` + * Where n is the number of features. n must be the same for train and test. + * + * Usage: StreamingLinearRegression + * + * To run on your local machine using the two directories `trainingDir` and `testDir`, + * with updates every 5 seconds, and 2 features per data point, call: + * $ bin/run-example \ + * org.apache.spark.examples.mllib.StreamingLinearRegression trainingDir testDir 5 2 + * + * As you add text files to `trainingDir` the model will continuously update. + * Anytime you add text files to `testDir`, you'll see predictions from the current model. + * + */ +object StreamingLinearRegression { + + def main(args: Array[String]) { + + if (args.length != 4) { + System.err.println( + "Usage: StreamingLinearRegression ") + System.exit(1) + } + + val conf = new SparkConf().setMaster("local").setAppName("StreamingLinearRegression") + val ssc = new StreamingContext(conf, Seconds(args(2).toLong)) + + val trainingData = MLUtils.loadStreamingLabeledPoints(ssc, args(0)) + val testData = MLUtils.loadStreamingLabeledPoints(ssc, args(1)) + + val model = new StreamingLinearRegressionWithSGD() + .setInitialWeights(Vectors.dense(Array.fill[Double](args(3).toInt)(0))) + + model.trainOn(trainingData) + model.predictOn(testData).print() + + ssc.start() + ssc.awaitTermination() + + } + +} diff --git a/mllib/pom.xml b/mllib/pom.xml index 45046eca5b18c..9a33bd1cf6ad1 100644 --- a/mllib/pom.xml +++ b/mllib/pom.xml @@ -40,6 +40,11 @@ spark-core_${scala.binary.version} ${project.version}
+ + org.apache.spark + spark-streaming_${scala.binary.version} + ${project.version} + org.eclipse.jetty jetty-server diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala index 356aa949afcf5..a6912056395d7 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala @@ -162,6 +162,14 @@ object GradientDescent extends Logging { val numExamples = data.count() val miniBatchSize = numExamples * miniBatchFraction + // if no data, return initial weights to avoid NaNs + if (numExamples == 0) { + + logInfo("GradientDescent.runMiniBatchSGD returning initial weights, no data found") + return (initialWeights, stochasticLossHistory.toArray) + + } + // Initialize weights as a column vector var weights = Vectors.dense(initialWeights.toArray) val n = weights.size @@ -202,5 +210,6 @@ object GradientDescent extends Logging { stochasticLossHistory.takeRight(10).mkString(", "))) (weights, stochasticLossHistory.toArray) + } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala index 8c078ec9f66e9..81b6598377ff5 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala @@ -49,7 +49,7 @@ class LinearRegressionModel ( * its corresponding right hand side label y. * See also the documentation for the precise formulation. */ -class LinearRegressionWithSGD private ( +class LinearRegressionWithSGD private[mllib] ( private var stepSize: Double, private var numIterations: Int, private var miniBatchFraction: Double) @@ -68,7 +68,7 @@ class LinearRegressionWithSGD private ( */ def this() = this(1.0, 100, 1.0) - override protected def createModel(weights: Vector, intercept: Double) = { + override protected[mllib] def createModel(weights: Vector, intercept: Double) = { new LinearRegressionModel(weights, intercept) } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala new file mode 100644 index 0000000000000..b8b0b42611775 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala @@ -0,0 +1,106 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.regression + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.Logging +import org.apache.spark.streaming.dstream.DStream + +/** + * :: DeveloperApi :: + * StreamingLinearAlgorithm implements methods for continuously + * training a generalized linear model model on streaming data, + * and using it for prediction on (possibly different) streaming data. + * + * This class takes as type parameters a GeneralizedLinearModel, + * and a GeneralizedLinearAlgorithm, making it easy to extend to construct + * streaming versions of any analyses using GLMs. + * Initial weights must be set before calling trainOn or predictOn. + * Only weights will be updated, not an intercept. If the model needs + * an intercept, it should be manually appended to the input data. + * + * For example usage, see `StreamingLinearRegressionWithSGD`. + * + * NOTE(Freeman): In some use cases, the order in which trainOn and predictOn + * are called in an application will affect the results. When called on + * the same DStream, if trainOn is called before predictOn, when new data + * arrive the model will update and the prediction will be based on the new + * model. Whereas if predictOn is called first, the prediction will use the model + * from the previous update. + * + * NOTE(Freeman): It is ok to call predictOn repeatedly on multiple streams; this + * will generate predictions for each one all using the current model. + * It is also ok to call trainOn on different streams; this will update + * the model using each of the different sources, in sequence. + * + */ +@DeveloperApi +abstract class StreamingLinearAlgorithm[ + M <: GeneralizedLinearModel, + A <: GeneralizedLinearAlgorithm[M]] extends Logging { + + /** The model to be updated and used for prediction. */ + protected var model: M + + /** The algorithm to use for updating. */ + protected val algorithm: A + + /** Return the latest model. */ + def latestModel(): M = { + model + } + + /** + * Update the model by training on batches of data from a DStream. + * This operation registers a DStream for training the model, + * and updates the model based on every subsequent + * batch of data from the stream. + * + * @param data DStream containing labeled data + */ + def trainOn(data: DStream[LabeledPoint]) { + if (Option(model.weights) == None) { + logError("Initial weights must be set before starting training") + throw new IllegalArgumentException + } + data.foreachRDD { (rdd, time) => + model = algorithm.run(rdd, model.weights) + logInfo("Model updated at time %s".format(time.toString)) + val display = model.weights.size match { + case x if x > 100 => model.weights.toArray.take(100).mkString("[", ",", "...") + case _ => model.weights.toArray.mkString("[", ",", "]") + } + logInfo("Current model: weights, %s".format (display)) + } + } + + /** + * Use the model to make predictions on batches of data from a DStream + * + * @param data DStream containing labeled data + * @return DStream containing predictions + */ + def predictOn(data: DStream[LabeledPoint]): DStream[Double] = { + if (Option(model.weights) == None) { + logError("Initial weights must be set before starting prediction") + throw new IllegalArgumentException + } + data.map(x => model.predict(x.features)) + } + +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionWithSGD.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionWithSGD.scala new file mode 100644 index 0000000000000..8851097050318 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionWithSGD.scala @@ -0,0 +1,88 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.regression + +import org.apache.spark.annotation.Experimental +import org.apache.spark.mllib.linalg.{Vector, Vectors} + +/** + * Train or predict a linear regression model on streaming data. Training uses + * Stochastic Gradient Descent to update the model based on each new batch of + * incoming data from a DStream (see `LinearRegressionWithSGD` for model equation) + * + * Each batch of data is assumed to be an RDD of LabeledPoints. + * The number of data points per batch can vary, but the number + * of features must be constant. An initial weight + * vector must be provided. + * + * Use a builder pattern to construct a streaming linear regression + * analysis in an application, like: + * + * val model = new StreamingLinearRegressionWithSGD() + * .setStepSize(0.5) + * .setNumIterations(10) + * .setInitialWeights(Vectors.dense(...)) + * .trainOn(DStream) + * + */ +@Experimental +class StreamingLinearRegressionWithSGD ( + private var stepSize: Double, + private var numIterations: Int, + private var miniBatchFraction: Double, + private var initialWeights: Vector) + extends StreamingLinearAlgorithm[ + LinearRegressionModel, LinearRegressionWithSGD] with Serializable { + + /** + * Construct a StreamingLinearRegression object with default parameters: + * {stepSize: 0.1, numIterations: 50, miniBatchFraction: 1.0}. + * Initial weights must be set before using trainOn or predictOn + * (see `StreamingLinearAlgorithm`) + */ + def this() = this(0.1, 50, 1.0, null) + + val algorithm = new LinearRegressionWithSGD(stepSize, numIterations, miniBatchFraction) + + var model = algorithm.createModel(initialWeights, 0.0) + + /** Set the step size for gradient descent. Default: 0.1. */ + def setStepSize(stepSize: Double): this.type = { + this.algorithm.optimizer.setStepSize(stepSize) + this + } + + /** Set the number of iterations of gradient descent to run per update. Default: 50. */ + def setNumIterations(numIterations: Int): this.type = { + this.algorithm.optimizer.setNumIterations(numIterations) + this + } + + /** Set the fraction of each batch to use for updates. Default: 1.0. */ + def setMiniBatchFraction(miniBatchFraction: Double): this.type = { + this.algorithm.optimizer.setMiniBatchFraction(miniBatchFraction) + this + } + + /** Set the initial weights. Default: [0.0, 0.0]. */ + def setInitialWeights(initialWeights: Vector): this.type = { + this.model = algorithm.createModel(initialWeights, 0.0) + this + } + +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala index dc10a194783ed..f4cce86a65ba7 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala @@ -30,6 +30,8 @@ import org.apache.spark.util.random.BernoulliSampler import org.apache.spark.mllib.regression.{LabeledPointParser, LabeledPoint} import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.storage.StorageLevel +import org.apache.spark.streaming.StreamingContext +import org.apache.spark.streaming.dstream.DStream /** * Helper methods to load, save and pre-process data used in ML Lib. @@ -192,6 +194,19 @@ object MLUtils { def loadLabeledPoints(sc: SparkContext, dir: String): RDD[LabeledPoint] = loadLabeledPoints(sc, dir, sc.defaultMinPartitions) + /** + * Loads streaming labeled points from a stream of text files + * where points are in the same format as used in `RDD[LabeledPoint].saveAsTextFile`. + * See `StreamingContext.textFileStream` for more details on how to + * generate a stream from files + * + * @param ssc Streaming context + * @param dir Directory path in any Hadoop-supported file system URI + * @return Labeled points stored as a DStream[LabeledPoint] + */ + def loadStreamingLabeledPoints(ssc: StreamingContext, dir: String): DStream[LabeledPoint] = + ssc.textFileStream(dir).map(LabeledPointParser.parse) + /** * Load labeled data from a file. The data format used here is * , ... diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala new file mode 100644 index 0000000000000..ed21f84472c9a --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala @@ -0,0 +1,135 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.regression + +import java.io.File +import java.nio.charset.Charset + +import scala.collection.mutable.ArrayBuffer + +import com.google.common.io.Files +import org.scalatest.FunSuite + +import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.util.{LinearDataGenerator, LocalSparkContext, MLUtils} +import org.apache.spark.streaming.{Milliseconds, StreamingContext} +import org.apache.spark.util.Utils + +class StreamingLinearRegressionSuite extends FunSuite with LocalSparkContext { + + // Assert that two values are equal within tolerance epsilon + def assertEqual(v1: Double, v2: Double, epsilon: Double) { + def errorMessage = v1.toString + " did not equal " + v2.toString + assert(math.abs(v1-v2) <= epsilon, errorMessage) + } + + // Assert that model predictions are correct + def validatePrediction(predictions: Seq[Double], input: Seq[LabeledPoint]) { + val numOffPredictions = predictions.zip(input).count { case (prediction, expected) => + // A prediction is off if the prediction is more than 0.5 away from expected value. + math.abs(prediction - expected.label) > 0.5 + } + // At least 80% of the predictions should be on. + assert(numOffPredictions < input.length / 5) + } + + // Test if we can accurately learn Y = 10*X1 + 10*X2 on streaming data + test("streaming linear regression parameter accuracy") { + + val testDir = Files.createTempDir() + val numBatches = 10 + val batchDuration = Milliseconds(1000) + val ssc = new StreamingContext(sc, batchDuration) + val data = MLUtils.loadStreamingLabeledPoints(ssc, testDir.toString) + val model = new StreamingLinearRegressionWithSGD() + .setInitialWeights(Vectors.dense(0.0, 0.0)) + .setStepSize(0.1) + .setNumIterations(50) + + model.trainOn(data) + + ssc.start() + + // write data to a file stream + for (i <- 0 until numBatches) { + val samples = LinearDataGenerator.generateLinearInput( + 0.0, Array(10.0, 10.0), 100, 42 * (i + 1)) + val file = new File(testDir, i.toString) + Files.write(samples.map(x => x.toString).mkString("\n"), file, Charset.forName("UTF-8")) + Thread.sleep(batchDuration.milliseconds) + } + + ssc.stop(stopSparkContext=false) + + System.clearProperty("spark.driver.port") + Utils.deleteRecursively(testDir) + + // check accuracy of final parameter estimates + assertEqual(model.latestModel().intercept, 0.0, 0.1) + assertEqual(model.latestModel().weights(0), 10.0, 0.1) + assertEqual(model.latestModel().weights(1), 10.0, 0.1) + + // check accuracy of predictions + val validationData = LinearDataGenerator.generateLinearInput(0.0, Array(10.0, 10.0), 100, 17) + validatePrediction(validationData.map(row => model.latestModel().predict(row.features)), + validationData) + } + + // Test that parameter estimates improve when learning Y = 10*X1 on streaming data + test("streaming linear regression parameter convergence") { + + val testDir = Files.createTempDir() + val batchDuration = Milliseconds(2000) + val ssc = new StreamingContext(sc, batchDuration) + val numBatches = 5 + val data = MLUtils.loadStreamingLabeledPoints(ssc, testDir.toString) + val model = new StreamingLinearRegressionWithSGD() + .setInitialWeights(Vectors.dense(0.0)) + .setStepSize(0.1) + .setNumIterations(50) + + model.trainOn(data) + + ssc.start() + + // write data to a file stream + val history = new ArrayBuffer[Double](numBatches) + for (i <- 0 until numBatches) { + val samples = LinearDataGenerator.generateLinearInput(0.0, Array(10.0), 100, 42 * (i + 1)) + val file = new File(testDir, i.toString) + Files.write(samples.map(x => x.toString).mkString("\n"), file, Charset.forName("UTF-8")) + Thread.sleep(batchDuration.milliseconds) + // wait an extra few seconds to make sure the update finishes before new data arrive + Thread.sleep(4000) + history.append(math.abs(model.latestModel().weights(0) - 10.0)) + } + + ssc.stop(stopSparkContext=false) + + System.clearProperty("spark.driver.port") + Utils.deleteRecursively(testDir) + + val deltas = history.drop(1).zip(history.dropRight(1)) + // check error stability (it always either shrinks, or increases with small tol) + assert(deltas.forall(x => (x._1 - x._2) <= 0.1)) + // check that error shrunk on at least 2 batches + assert(deltas.map(x => if ((x._1 - x._2) < 0) 1 else 0).sum > 1) + + } + +} From c281189222e645d2c87277c269e2102c3c8ccc95 Mon Sep 17 00:00:00 2001 From: Michael Giannakopoulos Date: Fri, 1 Aug 2014 21:00:31 -0700 Subject: [PATCH 104/170] [SPARK-2550][MLLIB][APACHE SPARK] Support regularization and intercept in pyspark's linear methods. Related to issue: [SPARK-2550](https://issues.apache.org/jira/browse/SPARK-2550?jql=project%20%3D%20SPARK%20AND%20resolution%20%3D%20Unresolved%20AND%20priority%20%3D%20Major%20ORDER%20BY%20key%20DESC). Author: Michael Giannakopoulos Closes #1624 from miccagiann/new-branch and squashes the following commits: c02e5f5 [Michael Giannakopoulos] Merge cleanly with upstream/master. 8dcb888 [Michael Giannakopoulos] Putting the if/else if statements in brackets. fed8eaa [Michael Giannakopoulos] Adding a space in the message related to the IllegalArgumentException. 44e6ff0 [Michael Giannakopoulos] Adding a blank line before python class LinearRegressionWithSGD. 8eba9c5 [Michael Giannakopoulos] Change function signatures. Exception is thrown from the scala component and not from the python one. 638be47 [Michael Giannakopoulos] Modified code to comply with code standards. ec50ee9 [Michael Giannakopoulos] Shorten the if-elif-else statement in regression.py file b962744 [Michael Giannakopoulos] Replaced the enum classes, with strings-keywords for defining the values of 'regType' parameter. 78853ec [Michael Giannakopoulos] Providing intercept and regualizer functionallity for linear methods in only one function. 3ac8874 [Michael Giannakopoulos] Added support for regularizer and intercection parameters for linear regression method. --- .../mllib/api/python/PythonMLLibAPI.scala | 28 ++++++++++++---- python/pyspark/mllib/regression.py | 32 ++++++++++++++++--- 2 files changed, 49 insertions(+), 11 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala index 122925d096e98..7d912737b8f0b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala @@ -23,6 +23,8 @@ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.api.java.{JavaRDD, JavaSparkContext} import org.apache.spark.mllib.classification._ import org.apache.spark.mllib.clustering._ +import org.apache.spark.mllib.linalg.{SparseVector, Vector, Vectors} +import org.apache.spark.mllib.optimization._ import org.apache.spark.mllib.linalg.{Matrix, SparseVector, Vector, Vectors} import org.apache.spark.mllib.random.{RandomRDDGenerators => RG} import org.apache.spark.mllib.recommendation._ @@ -252,15 +254,27 @@ class PythonMLLibAPI extends Serializable { numIterations: Int, stepSize: Double, miniBatchFraction: Double, - initialWeightsBA: Array[Byte]): java.util.List[java.lang.Object] = { + initialWeightsBA: Array[Byte], + regParam: Double, + regType: String, + intercept: Boolean): java.util.List[java.lang.Object] = { + val lrAlg = new LinearRegressionWithSGD() + lrAlg.setIntercept(intercept) + lrAlg.optimizer + .setNumIterations(numIterations) + .setRegParam(regParam) + .setStepSize(stepSize) + if (regType == "l2") { + lrAlg.optimizer.setUpdater(new SquaredL2Updater) + } else if (regType == "l1") { + lrAlg.optimizer.setUpdater(new L1Updater) + } else if (regType != "none") { + throw new java.lang.IllegalArgumentException("Invalid value for 'regType' parameter." + + " Can only be initialized using the following string values: [l1, l2, none].") + } trainRegressionModel( (data, initialWeights) => - LinearRegressionWithSGD.train( - data, - numIterations, - stepSize, - miniBatchFraction, - initialWeights), + lrAlg.run(data, initialWeights), dataBytesJRDD, initialWeightsBA) } diff --git a/python/pyspark/mllib/regression.py b/python/pyspark/mllib/regression.py index b84bc531dec8c..041b119269427 100644 --- a/python/pyspark/mllib/regression.py +++ b/python/pyspark/mllib/regression.py @@ -112,12 +112,36 @@ class LinearRegressionModel(LinearRegressionModelBase): class LinearRegressionWithSGD(object): @classmethod - def train(cls, data, iterations=100, step=1.0, - miniBatchFraction=1.0, initialWeights=None): - """Train a linear regression model on the given data.""" + def train(cls, data, iterations=100, step=1.0, miniBatchFraction=1.0, + initialWeights=None, regParam=1.0, regType=None, intercept=False): + """ + Train a linear regression model on the given data. + + @param data: The training data. + @param iterations: The number of iterations (default: 100). + @param step: The step parameter used in SGD + (default: 1.0). + @param miniBatchFraction: Fraction of data to be used for each SGD + iteration. + @param initialWeights: The initial weights (default: None). + @param regParam: The regularizer parameter (default: 1.0). + @param regType: The type of regularizer used for training + our model. + Allowed values: "l1" for using L1Updater, + "l2" for using + SquaredL2Updater, + "none" for no regularizer. + (default: "none") + @param intercept: Boolean parameter which indicates the use + or not of the augmented representation for + training data (i.e. whether bias features + are activated or not). + """ sc = data.context + if regType is None: + regType = "none" train_f = lambda d, i: sc._jvm.PythonMLLibAPI().trainLinearRegressionModelWithSGD( - d._jrdd, iterations, step, miniBatchFraction, i) + d._jrdd, iterations, step, miniBatchFraction, i, regParam, regType, intercept) return _regression_train_wrapper(sc, train_f, LinearRegressionModel, data, initialWeights) From e25ec06171e3ba95920cbfe9df3cd3d990f1a3a3 Mon Sep 17 00:00:00 2001 From: Tor Myklebust Date: Fri, 1 Aug 2014 21:25:02 -0700 Subject: [PATCH 105/170] [SPARK-1580][MLLIB] Estimate ALS communication and computation costs. Continue the work from #493. Closes #493 and Closes #593 Author: Tor Myklebust Author: Xiangrui Meng Closes #1731 from mengxr/tmyklebu-alscost and squashes the following commits: 9b56a8b [Xiangrui Meng] updated API and added a simple test 68a3229 [Xiangrui Meng] merge master 217bd1d [Tor Myklebust] Documentation and choleskies -> subproblems. 8cbb718 [Tor Myklebust] Braces get spaces. 0455cd4 [Tor Myklebust] Parens for collectAsMap. 2b2febe [Tor Myklebust] Use `makeLinkRDDs` when estimating costs. 2ab7a5d [Tor Myklebust] Reindent estimateCost's declaration and make it return Seqs. 8b21e6d [Tor Myklebust] Fix overlong lines. 8cbebf1 [Tor Myklebust] Rename and clean up the return format of cost estimator. 6615ed5 [Tor Myklebust] It's more useful to give per-partition estimates. Do that. 5530678 [Tor Myklebust] Merge branch 'master' of https://github.com/apache/spark into alscost 6c31324 [Tor Myklebust] Make it actually build... a1184d1 [Tor Myklebust] Mark ALS.evaluatePartitioner DeveloperApi. 657a71b [Tor Myklebust] Simple-minded estimates of computation and communication costs in ALS. dcf583a [Tor Myklebust] Remove the partitioner member variable; instead, thread that needle everywhere it needs to go. 23d6f91 [Tor Myklebust] Stop making the partitioner configurable. 495784f [Tor Myklebust] Merge branch 'master' of https://github.com/apache/spark 674933a [Tor Myklebust] Fix style. 40edc23 [Tor Myklebust] Fix missing space. f841345 [Tor Myklebust] Fix daft bug creating 'pairs', also for -> foreach. 5ec9e6c [Tor Myklebust] Clean a couple of things up using 'map'. 36a0f43 [Tor Myklebust] Make the partitioner private. d872b09 [Tor Myklebust] Add negative id ALS test. df27697 [Tor Myklebust] Support custom partitioners. Currently we use the same partitioner for users and products. c90b6d8 [Tor Myklebust] Scramble user and product ids before bucketing. c774d7d [Tor Myklebust] Make the partitioner a member variable and use it instead of modding directly. --- .../spark/mllib/recommendation/ALS.scala | 126 +++++++++++++++++- .../spark/mllib/recommendation/ALSSuite.scala | 26 +++- 2 files changed, 144 insertions(+), 8 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala index 36d262fed425a..8ebc7e27ed4dd 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala @@ -17,7 +17,8 @@ package org.apache.spark.mllib.recommendation -import scala.collection.mutable.{ArrayBuffer, BitSet} +import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer import scala.math.{abs, sqrt} import scala.util.Random import scala.util.Sorting @@ -25,7 +26,7 @@ import scala.util.hashing.byteswap32 import org.jblas.{DoubleMatrix, SimpleBlas, Solve} -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{DeveloperApi, Experimental} import org.apache.spark.broadcast.Broadcast import org.apache.spark.{Logging, HashPartitioner, Partitioner} import org.apache.spark.storage.StorageLevel @@ -39,7 +40,8 @@ import org.apache.spark.mllib.optimization.NNLS * of the elements within this block, and the list of destination blocks that each user or * product will need to send its feature vector to. */ -private[recommendation] case class OutLinkBlock(elementIds: Array[Int], shouldSend: Array[BitSet]) +private[recommendation] +case class OutLinkBlock(elementIds: Array[Int], shouldSend: Array[mutable.BitSet]) /** @@ -382,7 +384,7 @@ class ALS private ( val userIds = ratings.map(_.user).distinct.sorted val numUsers = userIds.length val userIdToPos = userIds.zipWithIndex.toMap - val shouldSend = Array.fill(numUsers)(new BitSet(numProductBlocks)) + val shouldSend = Array.fill(numUsers)(new mutable.BitSet(numProductBlocks)) for (r <- ratings) { shouldSend(userIdToPos(r.user))(productPartitioner.getPartition(r.product)) = true } @@ -797,4 +799,120 @@ object ALS { : MatrixFactorizationModel = { trainImplicit(ratings, rank, iterations, 0.01, -1, 1.0) } + + /** + * :: DeveloperApi :: + * Statistics of a block in ALS computation. + * + * @param category type of this block, "user" or "product" + * @param index index of this block + * @param count number of users or products inside this block, the same as the number of + * least-squares problems to solve on this block in each iteration + * @param numRatings total number of ratings inside this block, the same as the number of outer + * products we need to make on this block in each iteration + * @param numInLinks total number of incoming links, the same as the number of vectors to retrieve + * before each iteration + * @param numOutLinks total number of outgoing links, the same as the number of vectors to send + * for the next iteration + */ + @DeveloperApi + case class BlockStats( + category: String, + index: Int, + count: Long, + numRatings: Long, + numInLinks: Long, + numOutLinks: Long) + + /** + * :: DeveloperApi :: + * Given an RDD of ratings, number of user blocks, and number of product blocks, computes the + * statistics of each block in ALS computation. This is useful for estimating cost and diagnosing + * load balance. + * + * @param ratings an RDD of ratings + * @param numUserBlocks number of user blocks + * @param numProductBlocks number of product blocks + * @return statistics of user blocks and product blocks + */ + @DeveloperApi + def analyzeBlocks( + ratings: RDD[Rating], + numUserBlocks: Int, + numProductBlocks: Int): Array[BlockStats] = { + + val userPartitioner = new ALSPartitioner(numUserBlocks) + val productPartitioner = new ALSPartitioner(numProductBlocks) + + val ratingsByUserBlock = ratings.map { rating => + (userPartitioner.getPartition(rating.user), rating) + } + val ratingsByProductBlock = ratings.map { rating => + (productPartitioner.getPartition(rating.product), + Rating(rating.product, rating.user, rating.rating)) + } + + val als = new ALS() + val (userIn, userOut) = + als.makeLinkRDDs(numUserBlocks, numProductBlocks, ratingsByUserBlock, userPartitioner) + val (prodIn, prodOut) = + als.makeLinkRDDs(numProductBlocks, numUserBlocks, ratingsByProductBlock, productPartitioner) + + def sendGrid(outLinks: RDD[(Int, OutLinkBlock)]): Map[(Int, Int), Long] = { + outLinks.map { x => + val grid = new mutable.HashMap[(Int, Int), Long]() + val uPartition = x._1 + x._2.shouldSend.foreach { ss => + ss.foreach { pPartition => + val pair = (uPartition, pPartition) + grid.put(pair, grid.getOrElse(pair, 0L) + 1L) + } + } + grid + }.reduce { (grid1, grid2) => + grid2.foreach { x => + grid1.put(x._1, grid1.getOrElse(x._1, 0L) + x._2) + } + grid1 + }.toMap + } + + val userSendGrid = sendGrid(userOut) + val prodSendGrid = sendGrid(prodOut) + + val userInbound = new Array[Long](numUserBlocks) + val prodInbound = new Array[Long](numProductBlocks) + val userOutbound = new Array[Long](numUserBlocks) + val prodOutbound = new Array[Long](numProductBlocks) + + for (u <- 0 until numUserBlocks; p <- 0 until numProductBlocks) { + userOutbound(u) += userSendGrid.getOrElse((u, p), 0L) + prodInbound(p) += userSendGrid.getOrElse((u, p), 0L) + userInbound(u) += prodSendGrid.getOrElse((p, u), 0L) + prodOutbound(p) += prodSendGrid.getOrElse((p, u), 0L) + } + + val userCounts = userOut.mapValues(x => x.elementIds.length).collectAsMap() + val prodCounts = prodOut.mapValues(x => x.elementIds.length).collectAsMap() + + val userRatings = countRatings(userIn) + val prodRatings = countRatings(prodIn) + + val userStats = Array.tabulate(numUserBlocks)( + u => BlockStats("user", u, userCounts(u), userRatings(u), userInbound(u), userOutbound(u))) + val productStatus = Array.tabulate(numProductBlocks)( + p => BlockStats("product", p, prodCounts(p), prodRatings(p), prodInbound(p), prodOutbound(p))) + + (userStats ++ productStatus).toArray + } + + private def countRatings(inLinks: RDD[(Int, InLinkBlock)]): Map[Int, Long] = { + inLinks.mapValues { ilb => + var numRatings = 0L + ilb.ratingsForBlock.foreach { ar => + ar.foreach { p => numRatings += p._1.length } + } + numRatings + }.collectAsMap().toMap + } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala index 81bebec8c7a39..017c39edb185f 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala @@ -22,11 +22,11 @@ import scala.math.abs import scala.util.Random import org.scalatest.FunSuite - import org.jblas.DoubleMatrix -import org.apache.spark.mllib.util.LocalSparkContext import org.apache.spark.SparkContext._ +import org.apache.spark.mllib.util.LocalSparkContext +import org.apache.spark.mllib.recommendation.ALS.BlockStats object ALSSuite { @@ -67,8 +67,10 @@ object ALSSuite { case true => // Generate raw values from [0,9], or if negativeWeights, from [-2,7] val raw = new DoubleMatrix(users, products, - Array.fill(users * products)((if (negativeWeights) -2 else 0) + rand.nextInt(10).toDouble): _*) - val prefs = new DoubleMatrix(users, products, raw.data.map(v => if (v > 0) 1.0 else 0.0): _*) + Array.fill(users * products)( + (if (negativeWeights) -2 else 0) + rand.nextInt(10).toDouble): _*) + val prefs = + new DoubleMatrix(users, products, raw.data.map(v => if (v > 0) 1.0 else 0.0): _*) (raw, prefs) case false => (userMatrix.mmul(productMatrix), null) } @@ -160,6 +162,22 @@ class ALSSuite extends FunSuite with LocalSparkContext { testALS(100, 200, 2, 15, 0.7, 0.4, false, false, false, -1, -1, false) } + test("analyze one user block and one product block") { + val localRatings = Seq( + Rating(0, 100, 1.0), + Rating(0, 101, 2.0), + Rating(0, 102, 3.0), + Rating(1, 102, 4.0), + Rating(2, 103, 5.0)) + val ratings = sc.makeRDD(localRatings, 2) + val stats = ALS.analyzeBlocks(ratings, 1, 1) + assert(stats.size === 2) + assert(stats(0) === BlockStats("user", 0, 3, 5, 4, 3)) + assert(stats(1) === BlockStats("product", 0, 4, 5, 3, 4)) + } + + // TODO: add tests for analyzing multiple user/product blocks + /** * Test if we can correctly factorize R = U * P where U and P are of known rank. * From fda475987f3b8b37d563033b0e45706ce433824a Mon Sep 17 00:00:00 2001 From: Burak Date: Fri, 1 Aug 2014 22:32:12 -0700 Subject: [PATCH 106/170] [SPARK-2801][MLlib]: DistributionGenerator renamed to RandomDataGenerator. RandomRDD is now of generic type The RandomRDDGenerators used to only output RDD[Double]. Now RandomRDDGenerators.randomRDD can be used to generate a random RDD[T] via a class that extends RandomDataGenerator, by supplying a type T and overriding the nextValue() function as they wish. Author: Burak Closes #1732 from brkyvz/SPARK-2801 and squashes the following commits: c94a694 [Burak] [SPARK-2801][MLlib] Missing ClassTags added 22d96fe [Burak] [SPARK-2801][MLlib]: DistributionGenerator renamed to RandomDataGenerator, generic types added for RandomRDD instead of Double --- ...erator.scala => RandomDataGenerator.scala} | 18 +++++----- .../mllib/random/RandomRDDGenerators.scala | 32 +++++++++-------- .../apache/spark/mllib/rdd/RandomRDD.scala | 34 ++++++++++--------- ...e.scala => RandomDataGeneratorSuite.scala} | 6 ++-- .../random/RandomRDDGeneratorsSuite.scala | 8 +++-- 5 files changed, 52 insertions(+), 46 deletions(-) rename mllib/src/main/scala/org/apache/spark/mllib/random/{DistributionGenerator.scala => RandomDataGenerator.scala} (80%) rename mllib/src/test/scala/org/apache/spark/mllib/random/{DistributionGeneratorSuite.scala => RandomDataGeneratorSuite.scala} (95%) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/random/DistributionGenerator.scala b/mllib/src/main/scala/org/apache/spark/mllib/random/RandomDataGenerator.scala similarity index 80% rename from mllib/src/main/scala/org/apache/spark/mllib/random/DistributionGenerator.scala rename to mllib/src/main/scala/org/apache/spark/mllib/random/RandomDataGenerator.scala index 7ecb409c4a91a..9cab49f6ed1f0 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/random/DistributionGenerator.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/random/RandomDataGenerator.scala @@ -25,21 +25,21 @@ import org.apache.spark.util.random.{XORShiftRandom, Pseudorandom} /** * :: Experimental :: - * Trait for random number generators that generate i.i.d. values from a distribution. + * Trait for random data generators that generate i.i.d. data. */ @Experimental -trait DistributionGenerator extends Pseudorandom with Serializable { +trait RandomDataGenerator[T] extends Pseudorandom with Serializable { /** - * Returns an i.i.d. sample as a Double from an underlying distribution. + * Returns an i.i.d. sample as a generic type from an underlying distribution. */ - def nextValue(): Double + def nextValue(): T /** - * Returns a copy of the DistributionGenerator with a new instance of the rng object used in the + * Returns a copy of the RandomDataGenerator with a new instance of the rng object used in the * class when applicable for non-locking concurrent usage. */ - def copy(): DistributionGenerator + def copy(): RandomDataGenerator[T] } /** @@ -47,7 +47,7 @@ trait DistributionGenerator extends Pseudorandom with Serializable { * Generates i.i.d. samples from U[0.0, 1.0] */ @Experimental -class UniformGenerator extends DistributionGenerator { +class UniformGenerator extends RandomDataGenerator[Double] { // XORShiftRandom for better performance. Thread safety isn't necessary here. private val random = new XORShiftRandom() @@ -66,7 +66,7 @@ class UniformGenerator extends DistributionGenerator { * Generates i.i.d. samples from the standard normal distribution. */ @Experimental -class StandardNormalGenerator extends DistributionGenerator { +class StandardNormalGenerator extends RandomDataGenerator[Double] { // XORShiftRandom for better performance. Thread safety isn't necessary here. private val random = new XORShiftRandom() @@ -87,7 +87,7 @@ class StandardNormalGenerator extends DistributionGenerator { * @param mean mean for the Poisson distribution. */ @Experimental -class PoissonGenerator(val mean: Double) extends DistributionGenerator { +class PoissonGenerator(val mean: Double) extends RandomDataGenerator[Double] { private var rng = new Poisson(mean, new DRand) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDGenerators.scala b/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDGenerators.scala index 021d651d4dbaa..b0a0593223910 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDGenerators.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDGenerators.scala @@ -24,6 +24,8 @@ import org.apache.spark.mllib.rdd.{RandomVectorRDD, RandomRDD} import org.apache.spark.rdd.RDD import org.apache.spark.util.Utils +import scala.reflect.ClassTag + /** * :: Experimental :: * Generator methods for creating RDDs comprised of i.i.d. samples from some distribution. @@ -200,12 +202,12 @@ object RandomRDDGenerators { * @return RDD[Double] comprised of i.i.d. samples produced by generator. */ @Experimental - def randomRDD(sc: SparkContext, - generator: DistributionGenerator, + def randomRDD[T: ClassTag](sc: SparkContext, + generator: RandomDataGenerator[T], size: Long, numPartitions: Int, - seed: Long): RDD[Double] = { - new RandomRDD(sc, size, numPartitions, generator, seed) + seed: Long): RDD[T] = { + new RandomRDD[T](sc, size, numPartitions, generator, seed) } /** @@ -219,11 +221,11 @@ object RandomRDDGenerators { * @return RDD[Double] comprised of i.i.d. samples produced by generator. */ @Experimental - def randomRDD(sc: SparkContext, - generator: DistributionGenerator, + def randomRDD[T: ClassTag](sc: SparkContext, + generator: RandomDataGenerator[T], size: Long, - numPartitions: Int): RDD[Double] = { - randomRDD(sc, generator, size, numPartitions, Utils.random.nextLong) + numPartitions: Int): RDD[T] = { + randomRDD[T](sc, generator, size, numPartitions, Utils.random.nextLong) } /** @@ -237,10 +239,10 @@ object RandomRDDGenerators { * @return RDD[Double] comprised of i.i.d. samples produced by generator. */ @Experimental - def randomRDD(sc: SparkContext, - generator: DistributionGenerator, - size: Long): RDD[Double] = { - randomRDD(sc, generator, size, sc.defaultParallelism, Utils.random.nextLong) + def randomRDD[T: ClassTag](sc: SparkContext, + generator: RandomDataGenerator[T], + size: Long): RDD[T] = { + randomRDD[T](sc, generator, size, sc.defaultParallelism, Utils.random.nextLong) } // TODO Generate RDD[Vector] from multivariate distributions. @@ -439,7 +441,7 @@ object RandomRDDGenerators { */ @Experimental def randomVectorRDD(sc: SparkContext, - generator: DistributionGenerator, + generator: RandomDataGenerator[Double], numRows: Long, numCols: Int, numPartitions: Int, @@ -461,7 +463,7 @@ object RandomRDDGenerators { */ @Experimental def randomVectorRDD(sc: SparkContext, - generator: DistributionGenerator, + generator: RandomDataGenerator[Double], numRows: Long, numCols: Int, numPartitions: Int): RDD[Vector] = { @@ -482,7 +484,7 @@ object RandomRDDGenerators { */ @Experimental def randomVectorRDD(sc: SparkContext, - generator: DistributionGenerator, + generator: RandomDataGenerator[Double], numRows: Long, numCols: Int): RDD[Vector] = { randomVectorRDD(sc, generator, numRows, numCols, diff --git a/mllib/src/main/scala/org/apache/spark/mllib/rdd/RandomRDD.scala b/mllib/src/main/scala/org/apache/spark/mllib/rdd/RandomRDD.scala index f13282d07ff92..c8db3910c6eab 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/rdd/RandomRDD.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/rdd/RandomRDD.scala @@ -19,35 +19,36 @@ package org.apache.spark.mllib.rdd import org.apache.spark.{Partition, SparkContext, TaskContext} import org.apache.spark.mllib.linalg.{DenseVector, Vector} -import org.apache.spark.mllib.random.DistributionGenerator +import org.apache.spark.mllib.random.RandomDataGenerator import org.apache.spark.rdd.RDD import org.apache.spark.util.Utils +import scala.reflect.ClassTag import scala.util.Random -private[mllib] class RandomRDDPartition(override val index: Int, +private[mllib] class RandomRDDPartition[T](override val index: Int, val size: Int, - val generator: DistributionGenerator, + val generator: RandomDataGenerator[T], val seed: Long) extends Partition { require(size >= 0, "Non-negative partition size required.") } // These two classes are necessary since Range objects in Scala cannot have size > Int.MaxValue -private[mllib] class RandomRDD(@transient sc: SparkContext, +private[mllib] class RandomRDD[T: ClassTag](@transient sc: SparkContext, size: Long, numPartitions: Int, - @transient rng: DistributionGenerator, - @transient seed: Long = Utils.random.nextLong) extends RDD[Double](sc, Nil) { + @transient rng: RandomDataGenerator[T], + @transient seed: Long = Utils.random.nextLong) extends RDD[T](sc, Nil) { require(size > 0, "Positive RDD size required.") require(numPartitions > 0, "Positive number of partitions required") require(math.ceil(size.toDouble / numPartitions) <= Int.MaxValue, "Partition size cannot exceed Int.MaxValue") - override def compute(splitIn: Partition, context: TaskContext): Iterator[Double] = { - val split = splitIn.asInstanceOf[RandomRDDPartition] - RandomRDD.getPointIterator(split) + override def compute(splitIn: Partition, context: TaskContext): Iterator[T] = { + val split = splitIn.asInstanceOf[RandomRDDPartition[T]] + RandomRDD.getPointIterator[T](split) } override def getPartitions: Array[Partition] = { @@ -59,7 +60,7 @@ private[mllib] class RandomVectorRDD(@transient sc: SparkContext, size: Long, vectorSize: Int, numPartitions: Int, - @transient rng: DistributionGenerator, + @transient rng: RandomDataGenerator[Double], @transient seed: Long = Utils.random.nextLong) extends RDD[Vector](sc, Nil) { require(size > 0, "Positive RDD size required.") @@ -69,7 +70,7 @@ private[mllib] class RandomVectorRDD(@transient sc: SparkContext, "Partition size cannot exceed Int.MaxValue") override def compute(splitIn: Partition, context: TaskContext): Iterator[Vector] = { - val split = splitIn.asInstanceOf[RandomRDDPartition] + val split = splitIn.asInstanceOf[RandomRDDPartition[Double]] RandomRDD.getVectorIterator(split, vectorSize) } @@ -80,12 +81,12 @@ private[mllib] class RandomVectorRDD(@transient sc: SparkContext, private[mllib] object RandomRDD { - def getPartitions(size: Long, + def getPartitions[T](size: Long, numPartitions: Int, - rng: DistributionGenerator, + rng: RandomDataGenerator[T], seed: Long): Array[Partition] = { - val partitions = new Array[RandomRDDPartition](numPartitions) + val partitions = new Array[RandomRDDPartition[T]](numPartitions) var i = 0 var start: Long = 0 var end: Long = 0 @@ -101,7 +102,7 @@ private[mllib] object RandomRDD { // The RNG has to be reset every time the iterator is requested to guarantee same data // every time the content of the RDD is examined. - def getPointIterator(partition: RandomRDDPartition): Iterator[Double] = { + def getPointIterator[T: ClassTag](partition: RandomRDDPartition[T]): Iterator[T] = { val generator = partition.generator.copy() generator.setSeed(partition.seed) Array.fill(partition.size)(generator.nextValue()).toIterator @@ -109,7 +110,8 @@ private[mllib] object RandomRDD { // The RNG has to be reset every time the iterator is requested to guarantee same data // every time the content of the RDD is examined. - def getVectorIterator(partition: RandomRDDPartition, vectorSize: Int): Iterator[Vector] = { + def getVectorIterator(partition: RandomRDDPartition[Double], + vectorSize: Int): Iterator[Vector] = { val generator = partition.generator.copy() generator.setSeed(partition.seed) Array.fill(partition.size)(new DenseVector( diff --git a/mllib/src/test/scala/org/apache/spark/mllib/random/DistributionGeneratorSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/random/RandomDataGeneratorSuite.scala similarity index 95% rename from mllib/src/test/scala/org/apache/spark/mllib/random/DistributionGeneratorSuite.scala rename to mllib/src/test/scala/org/apache/spark/mllib/random/RandomDataGeneratorSuite.scala index 974dec4c0b5ee..3df7c128af5ab 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/random/DistributionGeneratorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/random/RandomDataGeneratorSuite.scala @@ -22,9 +22,9 @@ import org.scalatest.FunSuite import org.apache.spark.util.StatCounter // TODO update tests to use TestingUtils for floating point comparison after PR 1367 is merged -class DistributionGeneratorSuite extends FunSuite { +class RandomDataGeneratorSuite extends FunSuite { - def apiChecks(gen: DistributionGenerator) { + def apiChecks(gen: RandomDataGenerator[Double]) { // resetting seed should generate the same sequence of random numbers gen.setSeed(42L) @@ -53,7 +53,7 @@ class DistributionGeneratorSuite extends FunSuite { assert(array5.equals(array6)) } - def distributionChecks(gen: DistributionGenerator, + def distributionChecks(gen: RandomDataGenerator[Double], mean: Double = 0.0, stddev: Double = 1.0, epsilon: Double = 0.01) { diff --git a/mllib/src/test/scala/org/apache/spark/mllib/random/RandomRDDGeneratorsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/random/RandomRDDGeneratorsSuite.scala index 6aa4f803df0f7..96e0bc63b0fa4 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/random/RandomRDDGeneratorsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/random/RandomRDDGeneratorsSuite.scala @@ -78,7 +78,9 @@ class RandomRDDGeneratorsSuite extends FunSuite with LocalSparkContext with Seri assert(rdd.partitions.size === numPartitions) // check that partition sizes are balanced - val partSizes = rdd.partitions.map(p => p.asInstanceOf[RandomRDDPartition].size.toDouble) + val partSizes = rdd.partitions.map(p => + p.asInstanceOf[RandomRDDPartition[Double]].size.toDouble) + val partStats = new StatCounter(partSizes) assert(partStats.max - partStats.min <= 1) } @@ -89,7 +91,7 @@ class RandomRDDGeneratorsSuite extends FunSuite with LocalSparkContext with Seri val rdd = new RandomRDD(sc, size, numPartitions, new UniformGenerator, 0L) assert(rdd.partitions.size === numPartitions) val count = rdd.partitions.foldLeft(0L) { (count, part) => - count + part.asInstanceOf[RandomRDDPartition].size + count + part.asInstanceOf[RandomRDDPartition[Double]].size } assert(count === size) @@ -145,7 +147,7 @@ class RandomRDDGeneratorsSuite extends FunSuite with LocalSparkContext with Seri } } -private[random] class MockDistro extends DistributionGenerator { +private[random] class MockDistro extends RandomDataGenerator[Double] { var seed = 0L From 4bc3bb29a4b6ab24b6b7e1f8df26414c41c80ace Mon Sep 17 00:00:00 2001 From: Jeremy Freeman Date: Fri, 1 Aug 2014 22:33:25 -0700 Subject: [PATCH 107/170] StatCounter on NumPy arrays [PYSPARK][SPARK-2012] These changes allow StatCounters to work properly on NumPy arrays, to fix the issue reported here (https://issues.apache.org/jira/browse/SPARK-2012). If NumPy is installed, the NumPy functions ``maximum``, ``minimum``, and ``sqrt``, which work on arrays, are used to merge statistics. If not, we fall back on scalar operators, so it will work on arrays with NumPy, but will also work without NumPy. New unit tests added, along with a check for NumPy in the tests. Author: Jeremy Freeman Closes #1725 from freeman-lab/numpy-max-statcounter and squashes the following commits: fe973b1 [Jeremy Freeman] Avoid duplicate array import in tests 7f0e397 [Jeremy Freeman] Refactored check for numpy 8e764dd [Jeremy Freeman] Explicit numpy imports 875414c [Jeremy Freeman] Fixed indents 1c8a832 [Jeremy Freeman] Unit tests for StatCounter with NumPy arrays 176a127 [Jeremy Freeman] Use numpy arrays in StatCounter --- python/pyspark/statcounter.py | 21 +++++++++++++-------- python/pyspark/tests.py | 24 ++++++++++++++++++++++++ 2 files changed, 37 insertions(+), 8 deletions(-) diff --git a/python/pyspark/statcounter.py b/python/pyspark/statcounter.py index e287bd3da1f61..1e597d64e03fe 100644 --- a/python/pyspark/statcounter.py +++ b/python/pyspark/statcounter.py @@ -20,6 +20,13 @@ import copy import math +try: + from numpy import maximum, minimum, sqrt +except ImportError: + maximum = max + minimum = min + sqrt = math.sqrt + class StatCounter(object): @@ -39,10 +46,8 @@ def merge(self, value): self.n += 1 self.mu += delta / self.n self.m2 += delta * (value - self.mu) - if self.maxValue < value: - self.maxValue = value - if self.minValue > value: - self.minValue = value + self.maxValue = maximum(self.maxValue, value) + self.minValue = minimum(self.minValue, value) return self @@ -70,8 +75,8 @@ def mergeStats(self, other): else: self.mu = (self.mu * self.n + other.mu * other.n) / (self.n + other.n) - self.maxValue = max(self.maxValue, other.maxValue) - self.minValue = min(self.minValue, other.minValue) + self.maxValue = maximum(self.maxValue, other.maxValue) + self.minValue = minimum(self.minValue, other.minValue) self.m2 += other.m2 + (delta * delta * self.n * other.n) / (self.n + other.n) self.n += other.n @@ -115,14 +120,14 @@ def sampleVariance(self): # Return the standard deviation of the values. def stdev(self): - return math.sqrt(self.variance()) + return sqrt(self.variance()) # # Return the sample standard deviation of the values, which corrects for bias in estimating the # variance by dividing by N-1 instead of N. # def sampleStdev(self): - return math.sqrt(self.sampleVariance()) + return sqrt(self.sampleVariance()) def __repr__(self): return ("(count: %s, mean: %s, stdev: %s, max: %s, min: %s)" % diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index c29deb9574ea2..16fb5a9256220 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -38,12 +38,19 @@ from pyspark.shuffle import Aggregator, InMemoryMerger, ExternalMerger _have_scipy = False +_have_numpy = False try: import scipy.sparse _have_scipy = True except: # No SciPy, but that's okay, we'll skip those tests pass +try: + import numpy as np + _have_numpy = True +except: + # No NumPy, but that's okay, we'll skip those tests + pass SPARK_HOME = os.environ["SPARK_HOME"] @@ -914,9 +921,26 @@ def test_serialize(self): self.assertEqual(expected, observed) +@unittest.skipIf(not _have_numpy, "NumPy not installed") +class NumPyTests(PySparkTestCase): + """General PySpark tests that depend on numpy """ + + def test_statcounter_array(self): + x = self.sc.parallelize([np.array([1.0,1.0]), np.array([2.0,2.0]), np.array([3.0,3.0])]) + s = x.stats() + self.assertSequenceEqual([2.0,2.0], s.mean().tolist()) + self.assertSequenceEqual([1.0,1.0], s.min().tolist()) + self.assertSequenceEqual([3.0,3.0], s.max().tolist()) + self.assertSequenceEqual([1.0,1.0], s.sampleStdev().tolist()) + + if __name__ == "__main__": if not _have_scipy: print "NOTE: Skipping SciPy tests as it does not seem to be installed" + if not _have_numpy: + print "NOTE: Skipping NumPy tests as it does not seem to be installed" unittest.main() if not _have_scipy: print "NOTE: SciPy tests were skipped as it does not seem to be installed" + if not _have_numpy: + print "NOTE: NumPy tests were skipped as it does not seem to be installed" From adc8303294e26efb4ed15e5f5ba1062f7988625d Mon Sep 17 00:00:00 2001 From: GuoQiang Li Date: Fri, 1 Aug 2014 23:55:11 -0700 Subject: [PATCH 108/170] [SPARK-1470][SPARK-1842] Use the scala-logging wrapper instead of the directly sfl4j api Author: GuoQiang Li Closes #1369 from witgo/SPARK-1470_new and squashes the following commits: 66a1641 [GuoQiang Li] IncompatibleResultTypeProblem 73a89ba [GuoQiang Li] Use the scala-logging wrapper instead of the directly sfl4j api. --- core/pom.xml | 4 + .../main/scala/org/apache/spark/Logging.scala | 39 +++++--- .../org/apache/spark/util/SignalLogger.scala | 2 +- mllib/pom.xml | 4 + pom.xml | 5 + project/MimaExcludes.scala | 91 ++++++++++++++++++- sql/catalyst/pom.xml | 5 - .../sql/catalyst/analysis/Analyzer.scala | 4 +- .../catalyst/analysis/HiveTypeCoercion.scala | 8 +- .../catalyst/expressions/BoundAttribute.scala | 2 +- .../codegen/GenerateOrdering.scala | 4 +- .../apache/spark/sql/catalyst/package.scala | 1 - .../sql/catalyst/planning/QueryPlanner.scala | 2 +- .../sql/catalyst/planning/patterns.scala | 6 +- .../spark/sql/catalyst/rules/Rule.scala | 2 +- .../sql/catalyst/rules/RuleExecutor.scala | 12 +-- .../spark/sql/catalyst/trees/package.scala | 8 +- .../org/apache/spark/sql/SQLContext.scala | 2 +- .../CompressibleColumnBuilder.scala | 5 +- .../apache/spark/sql/execution/Exchange.scala | 2 +- .../org/apache/spark/sql/json/JsonRDD.scala | 2 +- .../scala/org/apache/spark/sql/package.scala | 2 - .../spark/sql/columnar/ColumnTypeSuite.scala | 4 +- .../hive/thriftserver/HiveThriftServer2.scala | 12 +-- .../hive/thriftserver/SparkSQLCLIDriver.scala | 2 +- .../hive/thriftserver/SparkSQLDriver.scala | 6 +- .../sql/hive/thriftserver/SparkSQLEnv.scala | 6 +- .../server/SparkSQLOperationManager.scala | 13 +-- .../thriftserver/HiveThriftServer2Suite.scala | 2 +- .../apache/spark/sql/hive/HiveContext.scala | 2 +- .../spark/sql/hive/HiveMetastoreCatalog.scala | 3 +- .../org/apache/spark/sql/hive/TestHive.scala | 10 +- .../org/apache/spark/sql/hive/hiveUdfs.scala | 4 +- .../hive/execution/HiveComparisonTest.scala | 22 ++--- .../hive/execution/HiveQueryFileTest.scala | 2 +- 35 files changed, 203 insertions(+), 97 deletions(-) diff --git a/core/pom.xml b/core/pom.xml index 7c60cf10c3dc2..47766ae5fbb3d 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -98,6 +98,10 @@ org.slf4j jcl-over-slf4j + + com.typesafe.scala-logging + scala-logging-slf4j_${scala.binary.version} + log4j log4j diff --git a/core/src/main/scala/org/apache/spark/Logging.scala b/core/src/main/scala/org/apache/spark/Logging.scala index 807ef3e9c9d60..6e61c00b8dbbf 100644 --- a/core/src/main/scala/org/apache/spark/Logging.scala +++ b/core/src/main/scala/org/apache/spark/Logging.scala @@ -18,8 +18,9 @@ package org.apache.spark import org.apache.log4j.{LogManager, PropertyConfigurator} -import org.slf4j.{Logger, LoggerFactory} +import org.slf4j.LoggerFactory import org.slf4j.impl.StaticLoggerBinder +import com.typesafe.scalalogging.slf4j.Logger import org.apache.spark.annotation.DeveloperApi import org.apache.spark.util.Utils @@ -39,61 +40,69 @@ trait Logging { // be serialized and used on another machine @transient private var log_ : Logger = null + // Method to get the logger name for this object + protected def logName = { + var className = this.getClass.getName + // Ignore trailing $'s in the class names for Scala objects + if (className.endsWith("$")) { + className = className.substring(0, className.length - 1) + } + className + } + // Method to get or create the logger for this object protected def log: Logger = { if (log_ == null) { initializeIfNecessary() - var className = this.getClass.getName - // Ignore trailing $'s in the class names for Scala objects - log_ = LoggerFactory.getLogger(className.stripSuffix("$")) + log_ = Logger(LoggerFactory.getLogger(logName)) } log_ } // Log methods that take only a String protected def logInfo(msg: => String) { - if (log.isInfoEnabled) log.info(msg) + log.info(msg) } protected def logDebug(msg: => String) { - if (log.isDebugEnabled) log.debug(msg) + log.debug(msg) } protected def logTrace(msg: => String) { - if (log.isTraceEnabled) log.trace(msg) + log.trace(msg) } protected def logWarning(msg: => String) { - if (log.isWarnEnabled) log.warn(msg) + log.warn(msg) } protected def logError(msg: => String) { - if (log.isErrorEnabled) log.error(msg) + log.error(msg) } // Log methods that take Throwables (Exceptions/Errors) too protected def logInfo(msg: => String, throwable: Throwable) { - if (log.isInfoEnabled) log.info(msg, throwable) + log.info(msg, throwable) } protected def logDebug(msg: => String, throwable: Throwable) { - if (log.isDebugEnabled) log.debug(msg, throwable) + log.debug(msg, throwable) } protected def logTrace(msg: => String, throwable: Throwable) { - if (log.isTraceEnabled) log.trace(msg, throwable) + log.trace(msg, throwable) } protected def logWarning(msg: => String, throwable: Throwable) { - if (log.isWarnEnabled) log.warn(msg, throwable) + log.warn(msg, throwable) } protected def logError(msg: => String, throwable: Throwable) { - if (log.isErrorEnabled) log.error(msg, throwable) + log.error(msg, throwable) } protected def isTraceEnabled(): Boolean = { - log.isTraceEnabled + log.underlying.isTraceEnabled } private def initializeIfNecessary() { diff --git a/core/src/main/scala/org/apache/spark/util/SignalLogger.scala b/core/src/main/scala/org/apache/spark/util/SignalLogger.scala index f77488ef3d449..e84a6b951f65e 100644 --- a/core/src/main/scala/org/apache/spark/util/SignalLogger.scala +++ b/core/src/main/scala/org/apache/spark/util/SignalLogger.scala @@ -18,7 +18,7 @@ package org.apache.spark.util import org.apache.commons.lang3.SystemUtils -import org.slf4j.Logger +import com.typesafe.scalalogging.slf4j.Logger import sun.misc.{Signal, SignalHandler} /** diff --git a/mllib/pom.xml b/mllib/pom.xml index 9a33bd1cf6ad1..3007681a44f1c 100644 --- a/mllib/pom.xml +++ b/mllib/pom.xml @@ -59,6 +59,10 @@ breeze_${scala.binary.version} 0.7 + + com.typesafe + scalalogging-slf4j_${scala.binary.version} + diff --git a/pom.xml b/pom.xml index ae97bf03c53a2..9d62cea68995f 100644 --- a/pom.xml +++ b/pom.xml @@ -279,6 +279,11 @@ slf4j-log4j12 ${slf4j.version} + + com.typesafe.scala-logging + scala-logging-slf4j_${scala.binary.version} + 2.1.2 + org.slf4j jul-to-slf4j diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 537ca0dcf267d..a0cee1d765c7f 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -103,14 +103,101 @@ object MimaExcludes { ProblemFilters.exclude[IncompatibleMethTypeProblem]( "org.apache.spark.mllib.tree.impurity.Variance.calculate") ) ++ - Seq ( // Package-private classes removed in SPARK-2341 + Seq( // Package-private classes removed in SPARK-2341 ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.mllib.util.BinaryLabelParser"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.mllib.util.BinaryLabelParser$"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.mllib.util.LabelParser"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.mllib.util.LabelParser$"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.mllib.util.MulticlassLabelParser"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.mllib.util.MulticlassLabelParser$") - ) + ) ++ + Seq( + ProblemFilters.exclude[IncompatibleResultTypeProblem] + ("org.apache.spark.bagel.Bagel.log"), + ProblemFilters.exclude[IncompatibleResultTypeProblem] + ("org.apache.spark.streaming.StreamingContext.log"), + ProblemFilters.exclude[IncompatibleResultTypeProblem] + ("org.apache.spark.streaming.dstream.DStream.log"), + ProblemFilters.exclude[IncompatibleResultTypeProblem] + ("org.apache.spark.mllib.recommendation.ALS.log"), + ProblemFilters.exclude[IncompatibleResultTypeProblem] + ("org.apache.spark.mllib.clustering.KMeans.log"), + ProblemFilters.exclude[IncompatibleResultTypeProblem] + ("org.apache.spark.mllib.classification.NaiveBayes.log"), + ProblemFilters.exclude[IncompatibleResultTypeProblem] + ("org.apache.spark.streaming.kafka.KafkaReceiver.log"), + ProblemFilters.exclude[IncompatibleResultTypeProblem] + ("org.apache.spark.SparkContext.log"), + ProblemFilters.exclude[IncompatibleResultTypeProblem] + ("org.apache.spark.rdd.PairRDDFunctions.log"), + ProblemFilters.exclude[IncompatibleResultTypeProblem] + ("org.apache.spark.rdd.OrderedRDDFunctions.log"), + ProblemFilters.exclude[IncompatibleResultTypeProblem] + ("org.apache.spark.rdd.SequenceFileRDDFunctions.log"), + ProblemFilters.exclude[IncompatibleResultTypeProblem] + ("org.apache.spark.rdd.DoubleRDDFunctions.log"), + ProblemFilters.exclude[IncompatibleResultTypeProblem] + ("org.apache.spark.streaming.twitter.TwitterReceiver.log"), + ProblemFilters.exclude[IncompatibleResultTypeProblem] + ("org.apache.spark.streaming.zeromq.ZeroMQReceiver.log"), + ProblemFilters.exclude[IncompatibleResultTypeProblem] + ("org.apache.spark.streaming.flume.FlumeReceiver.log"), + ProblemFilters.exclude[IncompatibleResultTypeProblem] + ("org.apache.spark.rdd.RDD.log"), + ProblemFilters.exclude[IncompatibleResultTypeProblem] + ("org.apache.spark.SparkConf.log"), + + ProblemFilters.exclude[IncompatibleMethTypeProblem] + ("org.apache.spark.SparkConf.org$apache$spark$Logging$$log__="), + ProblemFilters.exclude[IncompatibleMethTypeProblem] + ("org.apache.spark.bagel.Bagel.org$apache$spark$Logging$$log__="), + ProblemFilters.exclude[IncompatibleMethTypeProblem] + ("org.apache.spark.streaming.StreamingContext.org$apache$spark$Logging$$log__="), + ProblemFilters.exclude[IncompatibleMethTypeProblem] + ("org.apache.spark.streaming.dstream.DStream.org$apache$spark$Logging$$log__="), + ProblemFilters.exclude[IncompatibleMethTypeProblem] + ("org.apache.spark.mllib.recommendation.ALS.org$apache$spark$Logging$$log__="), + ProblemFilters.exclude[IncompatibleMethTypeProblem] + ("org.apache.spark.mllib.clustering.KMeans.org$apache$spark$Logging$$log__="), + ProblemFilters.exclude[IncompatibleMethTypeProblem] + ("org.apache.spark.mllib.classification.NaiveBayes.org$apache$spark$Logging$$log__="), + ProblemFilters.exclude[IncompatibleMethTypeProblem] + ("org.apache.spark.streaming.twitter.TwitterReceiver.org$apache$spark$Logging$$log__="), + ProblemFilters.exclude[IncompatibleMethTypeProblem] + ("org.apache.spark.streaming.zeromq.ZeroMQReceiver.org$apache$spark$Logging$$log__="), + ProblemFilters.exclude[IncompatibleMethTypeProblem] + ("org.apache.spark.SparkContext.org$apache$spark$Logging$$log__="), + ProblemFilters.exclude[IncompatibleMethTypeProblem] + ("org.apache.spark.rdd.RDD.org$apache$spark$Logging$$log__="), + ProblemFilters.exclude[IncompatibleMethTypeProblem] + ("org.apache.spark.rdd.SequenceFileRDDFunctions.org$apache$spark$Logging$$log__="), + ProblemFilters.exclude[IncompatibleMethTypeProblem] + ("org.apache.spark.rdd.OrderedRDDFunctions.org$apache$spark$Logging$$log__="), + ProblemFilters.exclude[IncompatibleMethTypeProblem] + ("org.apache.spark.rdd.PairRDDFunctions.org$apache$spark$Logging$$log__="), + ProblemFilters.exclude[IncompatibleMethTypeProblem] + ("org.apache.spark.streaming.kafka.KafkaReceiver.org$apache$spark$Logging$$log__="), + ProblemFilters.exclude[IncompatibleMethTypeProblem] + ("org.apache.spark.rdd.DoubleRDDFunctions.org$apache$spark$Logging$$log__="), + ProblemFilters.exclude[IncompatibleMethTypeProblem] + ("org.apache.spark.streaming.flume.FlumeReceiver.org$apache$spark$Logging$$log__="), + ProblemFilters.exclude[IncompatibleMethTypeProblem] + ("org.apache.spark.streaming.kafka.KafkaReceiver.org$apache$spark$Logging$$log_"), + ProblemFilters.exclude[IncompatibleMethTypeProblem] + ("org.apache.spark.streaming.twitter.TwitterReceiver.org$apache$spark$Logging$$log_"), + ProblemFilters.exclude[IncompatibleResultTypeProblem] + ("org.apache.spark.streaming.twitter.TwitterReceiver.org$apache$spark$Logging$$log_"), + ProblemFilters.exclude[IncompatibleResultTypeProblem] + ("org.apache.spark.streaming.zeromq.ZeroMQReceiver.org$apache$spark$Logging$$log_"), + ProblemFilters.exclude[IncompatibleResultTypeProblem] + ("org.apache.spark.bagel.Bagel.org$apache$spark$Logging$$log_"), + ProblemFilters.exclude[IncompatibleResultTypeProblem] + ("org.apache.spark.bagel.Bagel.org$apache$spark$Logging$$log_"), + ProblemFilters.exclude[IncompatibleResultTypeProblem] + ("org.apache.spark.streaming.flume.FlumeReceiver.org$apache$spark$Logging$$log_"), + ProblemFilters.exclude[IncompatibleResultTypeProblem] + ("org.apache.spark.streaming.kafka.KafkaReceiver.org$apache$spark$Logging$$log_") + ) case v if v.startsWith("1.0") => Seq( MimaBuild.excludeSparkPackage("api.java"), diff --git a/sql/catalyst/pom.xml b/sql/catalyst/pom.xml index 54fa96baa1e18..58d44e7923bee 100644 --- a/sql/catalyst/pom.xml +++ b/sql/catalyst/pom.xml @@ -54,11 +54,6 @@ spark-core_${scala.binary.version} ${project.version} - - com.typesafe - scalalogging-slf4j_${scala.binary.version} - 1.0.1 - org.scalatest scalatest_${scala.binary.version} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 74c0104e5b17f..2b36582215f24 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -109,12 +109,12 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool object ResolveReferences extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { case q: LogicalPlan if q.childrenResolved => - logger.trace(s"Attempting to resolve ${q.simpleString}") + log.trace(s"Attempting to resolve ${q.simpleString}") q transformExpressions { case u @ UnresolvedAttribute(name) => // Leave unchanged if resolution fails. Hopefully will be resolved next round. val result = q.resolve(name).getOrElse(u) - logger.debug(s"Resolving $u to $result") + log.debug(s"Resolving $u to $result") result } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index 47c7ad076ad07..eafbb70dc3fdd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -75,7 +75,7 @@ trait HiveTypeCoercion { // Leave the same if the dataTypes match. case Some(newType) if a.dataType == newType.dataType => a case Some(newType) => - logger.debug(s"Promoting $a to $newType in ${q.simpleString}}") + log.debug(s"Promoting $a to $newType in ${q.simpleString}}") newType } } @@ -154,7 +154,7 @@ trait HiveTypeCoercion { (Alias(Cast(l, StringType), l.name)(), r) case (l, r) if l.dataType != r.dataType => - logger.debug(s"Resolving mismatched union input ${l.dataType}, ${r.dataType}") + log.debug(s"Resolving mismatched union input ${l.dataType}, ${r.dataType}") findTightestCommonType(l.dataType, r.dataType).map { widestType => val newLeft = if (l.dataType == widestType) l else Alias(Cast(l, widestType), l.name)() @@ -170,7 +170,7 @@ trait HiveTypeCoercion { val newLeft = if (castedLeft.map(_.dataType) != left.output.map(_.dataType)) { - logger.debug(s"Widening numeric types in union $castedLeft ${left.output}") + log.debug(s"Widening numeric types in union $castedLeft ${left.output}") Project(castedLeft, left) } else { left @@ -178,7 +178,7 @@ trait HiveTypeCoercion { val newRight = if (castedRight.map(_.dataType) != right.output.map(_.dataType)) { - logger.debug(s"Widening numeric types in union $castedRight ${right.output}") + log.debug(s"Widening numeric types in union $castedRight ${right.output}") Project(castedRight, right) } else { right diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala index f38f99569f207..0913f15888780 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.expressions -import org.apache.spark.sql.catalyst.Logging +import org.apache.spark.Logging import org.apache.spark.sql.catalyst.errors.attachTree import org.apache.spark.sql.catalyst.types._ import org.apache.spark.sql.catalyst.trees diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala index 4211998f7511a..e2552d432cb71 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.expressions.codegen -import com.typesafe.scalalogging.slf4j.Logging +import org.apache.spark.Logging import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.types.{StringType, NumericType} @@ -92,7 +92,7 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[Row]] wit } new $orderingName() """ - logger.debug(s"Generated Ordering: $code") + log.debug(s"Generated Ordering: $code") toolBox.eval(code).asInstanceOf[Ordering[Row]] } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/package.scala index ca9642954eb27..bdd07bbeb2230 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/package.scala @@ -25,5 +25,4 @@ package object catalyst { */ protected[catalyst] object ScalaReflectionLock - protected[catalyst] type Logging = com.typesafe.scalalogging.slf4j.Logging } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/QueryPlanner.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/QueryPlanner.scala index 781ba489b44c6..5839c9f7c43ef 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/QueryPlanner.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/QueryPlanner.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.planning -import org.apache.spark.sql.catalyst.Logging +import org.apache.spark.Logging import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.trees.TreeNode diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala index bc763a4e06e67..06c5ffe92abc8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.planning import scala.annotation.tailrec import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.Logging +import org.apache.spark.Logging import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ @@ -184,7 +184,7 @@ object ExtractEquiJoinKeys extends Logging with PredicateHelper { def unapply(plan: LogicalPlan): Option[ReturnType] = plan match { case join @ Join(left, right, joinType, condition) => - logger.debug(s"Considering join on: $condition") + log.debug(s"Considering join on: $condition") // Find equi-join predicates that can be evaluated before the join, and thus can be used // as join keys. val (joinPredicates, otherPredicates) = @@ -202,7 +202,7 @@ object ExtractEquiJoinKeys extends Logging with PredicateHelper { val rightKeys = joinKeys.map(_._2) if (joinKeys.nonEmpty) { - logger.debug(s"leftKeys:${leftKeys} | rightKeys:${rightKeys}") + log.debug(s"leftKeys:${leftKeys} | rightKeys:${rightKeys}") Some((joinType, leftKeys, rightKeys, otherPredicates.reduceOption(And), left, right)) } else { None diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/Rule.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/Rule.scala index f8960b3fe7a17..03414b2301e81 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/Rule.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/Rule.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.rules -import org.apache.spark.sql.catalyst.Logging +import org.apache.spark.Logging import org.apache.spark.sql.catalyst.trees.TreeNode abstract class Rule[TreeType <: TreeNode[_]] extends Logging { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala index 6aa407c836aec..20bf8eed7ddf3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.rules -import org.apache.spark.sql.catalyst.Logging +import org.apache.spark.Logging import org.apache.spark.sql.catalyst.trees.TreeNode import org.apache.spark.sql.catalyst.util.sideBySide @@ -60,7 +60,7 @@ abstract class RuleExecutor[TreeType <: TreeNode[_]] extends Logging { case (plan, rule) => val result = rule(plan) if (!result.fastEquals(plan)) { - logger.trace( + log.trace( s""" |=== Applying Rule ${rule.ruleName} === |${sideBySide(plan.treeString, result.treeString).mkString("\n")} @@ -73,26 +73,26 @@ abstract class RuleExecutor[TreeType <: TreeNode[_]] extends Logging { if (iteration > batch.strategy.maxIterations) { // Only log if this is a rule that is supposed to run more than once. if (iteration != 2) { - logger.info(s"Max iterations (${iteration - 1}) reached for batch ${batch.name}") + log.info(s"Max iterations (${iteration - 1}) reached for batch ${batch.name}") } continue = false } if (curPlan.fastEquals(lastPlan)) { - logger.trace(s"Fixed point reached for batch ${batch.name} after $iteration iterations.") + log.trace(s"Fixed point reached for batch ${batch.name} after $iteration iterations.") continue = false } lastPlan = curPlan } if (!batchStartPlan.fastEquals(curPlan)) { - logger.debug( + log.debug( s""" |=== Result of Batch ${batch.name} === |${sideBySide(plan.treeString, curPlan.treeString).mkString("\n")} """.stripMargin) } else { - logger.trace(s"Batch ${batch.name} has no effect.") + log.trace(s"Batch ${batch.name} has no effect.") } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/package.scala index 9a28d035a10a3..d725a92c06f7b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/package.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst +import org.apache.spark.Logging + /** * A library for easily manipulating trees of operators. Operators that extend TreeNode are * granted the following interface: @@ -31,8 +33,8 @@ package org.apache.spark.sql.catalyst *
  • debugging support - pretty printing, easy splicing of trees, etc.
  • * */ -package object trees { +package object trees extends Logging { // Since we want tree nodes to be lightweight, we create one logger for all treenode instances. - protected val logger = - com.typesafe.scalalogging.slf4j.Logger(org.slf4j.LoggerFactory.getLogger("catalyst.trees")) + protected override def logName = "catalyst.trees" + } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index dad71079c29b9..00dd34aabc389 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -36,7 +36,7 @@ import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.SparkStrategies import org.apache.spark.sql.json._ import org.apache.spark.sql.parquet.ParquetRelation -import org.apache.spark.SparkContext +import org.apache.spark.{Logging, SparkContext} /** * :: AlphaComponent :: diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnBuilder.scala index 4c6675c3c87bf..828a8896ff60a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnBuilder.scala @@ -19,7 +19,8 @@ package org.apache.spark.sql.columnar.compression import java.nio.{ByteBuffer, ByteOrder} -import org.apache.spark.sql.{Logging, Row} +import org.apache.spark.Logging +import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.types.NativeType import org.apache.spark.sql.columnar.{ColumnBuilder, NativeColumnBuilder} @@ -101,7 +102,7 @@ private[sql] trait CompressibleColumnBuilder[T <: NativeType] copyColumnHeader(rawBuffer, compressedBuffer) - logger.info(s"Compressor for [$columnName]: $encoder, ratio: ${encoder.compressionRatio}") + log.info(s"Compressor for [$columnName]: $encoder, ratio: ${encoder.compressionRatio}") encoder.compress(rawBuffer, compressedBuffer, columnType) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala index 30712f03cab4c..0c3d537ccb494 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala @@ -101,7 +101,7 @@ private[sql] case class AddExchange(sqlContext: SQLContext) extends Rule[SparkPl !operator.requiredChildDistribution.zip(operator.children).map { case (required, child) => val valid = child.outputPartitioning.satisfies(required) - logger.debug( + log.debug( s"${if (valid) "Valid" else "Invalid"} distribution," + s"required: $required current: ${child.outputPartitioning}") valid diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala index 70db1ebd3a3e1..a3d2a1c7a51f8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.analysis.HiveTypeCoercion import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.types._ import org.apache.spark.sql.catalyst.ScalaReflection -import org.apache.spark.sql.Logging +import org.apache.spark.Logging private[sql] object JsonRDD extends Logging { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/package.scala index 0995a4eb6299f..f513eae9c2d13 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/package.scala @@ -32,8 +32,6 @@ import org.apache.spark.annotation.DeveloperApi */ package object sql { - protected[sql] type Logging = com.typesafe.scalalogging.slf4j.Logging - /** * :: DeveloperApi :: * diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala index 829342215e691..a165531573a20 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala @@ -22,7 +22,7 @@ import java.sql.Timestamp import org.scalatest.FunSuite -import org.apache.spark.sql.Logging +import org.apache.spark.Logging import org.apache.spark.sql.catalyst.types._ import org.apache.spark.sql.columnar.ColumnarTestUtils._ import org.apache.spark.sql.execution.SparkSqlSerializer @@ -166,7 +166,7 @@ class ColumnTypeSuite extends FunSuite with Logging { buffer.rewind() seq.foreach { expected => - logger.info("buffer = " + buffer + ", expected = " + expected) + log.info("buffer = " + buffer + ", expected = " + expected) val extracted = columnType.extract(buffer) assert( expected === extracted, diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala index ddbc2a79fb512..5959ba3d23f8e 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala @@ -25,7 +25,7 @@ import org.apache.hadoop.hive.ql.session.SessionState import org.apache.hive.service.cli.thrift.ThriftBinaryCLIService import org.apache.hive.service.server.{HiveServer2, ServerOptionsProcessor} -import org.apache.spark.sql.Logging +import org.apache.spark.Logging import org.apache.spark.sql.hive.HiveContext import org.apache.spark.sql.hive.thriftserver.ReflectionUtils._ @@ -40,7 +40,7 @@ private[hive] object HiveThriftServer2 extends Logging { val optionsProcessor = new ServerOptionsProcessor("HiveThriftServer2") if (!optionsProcessor.process(args)) { - logger.warn("Error starting HiveThriftServer2 with given arguments") + log.warn("Error starting HiveThriftServer2 with given arguments") System.exit(-1) } @@ -49,12 +49,12 @@ private[hive] object HiveThriftServer2 extends Logging { // Set all properties specified via command line. val hiveConf: HiveConf = ss.getConf hiveConf.getAllProperties.toSeq.sortBy(_._1).foreach { case (k, v) => - logger.debug(s"HiveConf var: $k=$v") + log.debug(s"HiveConf var: $k=$v") } SessionState.start(ss) - logger.info("Starting SparkContext") + log.info("Starting SparkContext") SparkSQLEnv.init() SessionState.start(ss) @@ -70,10 +70,10 @@ private[hive] object HiveThriftServer2 extends Logging { val server = new HiveThriftServer2(SparkSQLEnv.hiveContext) server.init(hiveConf) server.start() - logger.info("HiveThriftServer2 started") + log.info("HiveThriftServer2 started") } catch { case e: Exception => - logger.error("Error starting HiveThriftServer2", e) + log.error("Error starting HiveThriftServer2", e) System.exit(-1) } } diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala index cb17d7ce58ea0..4d0c506c5a397 100755 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala @@ -37,7 +37,7 @@ import org.apache.hadoop.hive.ql.session.SessionState import org.apache.hadoop.hive.shims.ShimLoader import org.apache.thrift.transport.TSocket -import org.apache.spark.sql.Logging +import org.apache.spark.Logging private[hive] object SparkSQLCLIDriver { private var prompt = "spark-sql" diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala index a56b19a4bcda0..276723990b2ad 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala @@ -26,7 +26,7 @@ import org.apache.hadoop.hive.metastore.api.{FieldSchema, Schema} import org.apache.hadoop.hive.ql.Driver import org.apache.hadoop.hive.ql.processors.CommandProcessorResponse -import org.apache.spark.sql.Logging +import org.apache.spark.Logging import org.apache.spark.sql.hive.{HiveContext, HiveMetastoreTypes} private[hive] class SparkSQLDriver(val context: HiveContext = SparkSQLEnv.hiveContext) @@ -40,7 +40,7 @@ private[hive] class SparkSQLDriver(val context: HiveContext = SparkSQLEnv.hiveCo private def getResultSetSchema(query: context.QueryExecution): Schema = { val analyzed = query.analyzed - logger.debug(s"Result Schema: ${analyzed.output}") + log.debug(s"Result Schema: ${analyzed.output}") if (analyzed.output.size == 0) { new Schema(new FieldSchema("Response code", "string", "") :: Nil, null) } else { @@ -61,7 +61,7 @@ private[hive] class SparkSQLDriver(val context: HiveContext = SparkSQLEnv.hiveCo new CommandProcessorResponse(0) } catch { case cause: Throwable => - logger.error(s"Failed in [$command]", cause) + log.error(s"Failed in [$command]", cause) new CommandProcessorResponse(-3, ExceptionUtils.getFullStackTrace(cause), null) } } diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala index 451c3bd7b9352..dfc93b19d019c 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala @@ -20,13 +20,13 @@ package org.apache.spark.sql.hive.thriftserver import org.apache.hadoop.hive.ql.session.SessionState import org.apache.spark.scheduler.{SplitInfo, StatsReportListener} -import org.apache.spark.sql.Logging +import org.apache.spark.Logging import org.apache.spark.sql.hive.HiveContext import org.apache.spark.{SparkConf, SparkContext} /** A singleton object for the master program. The slaves should not access this. */ private[hive] object SparkSQLEnv extends Logging { - logger.debug("Initializing SparkSQLEnv") + log.debug("Initializing SparkSQLEnv") var hiveContext: HiveContext = _ var sparkContext: SparkContext = _ @@ -47,7 +47,7 @@ private[hive] object SparkSQLEnv extends Logging { /** Cleans up and shuts down the Spark SQL environments. */ def stop() { - logger.debug("Shutting down Spark SQL Environment") + log.debug("Shutting down Spark SQL Environment") // Stop the SparkContext if (SparkSQLEnv.sparkContext != null) { sparkContext.stop() diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala index a4e1f3e762e89..2c6e24e80d6dd 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala @@ -30,10 +30,11 @@ import org.apache.hive.service.cli._ import org.apache.hive.service.cli.operation.{ExecuteStatementOperation, Operation, OperationManager} import org.apache.hive.service.cli.session.HiveSession +import org.apache.spark.Logging import org.apache.spark.sql.catalyst.types._ import org.apache.spark.sql.hive.thriftserver.ReflectionUtils import org.apache.spark.sql.hive.{HiveContext, HiveMetastoreTypes} -import org.apache.spark.sql.{Logging, SchemaRDD, Row => SparkRow} +import org.apache.spark.sql.{SchemaRDD, Row => SparkRow} /** * Executes queries using Spark SQL, and maintains a list of handles to active queries. @@ -55,7 +56,7 @@ class SparkSQLOperationManager(hiveContext: HiveContext) extends OperationManage def close(): Unit = { // RDDs will be cleaned automatically upon garbage collection. - logger.debug("CLOSING") + log.debug("CLOSING") } def getNextRowSet(order: FetchOrientation, maxRowsL: Long): RowSet = { @@ -112,7 +113,7 @@ class SparkSQLOperationManager(hiveContext: HiveContext) extends OperationManage } def getResultSetSchema: TableSchema = { - logger.warn(s"Result Schema: ${result.queryExecution.analyzed.output}") + log.warn(s"Result Schema: ${result.queryExecution.analyzed.output}") if (result.queryExecution.analyzed.output.size == 0) { new TableSchema(new FieldSchema("Result", "string", "") :: Nil) } else { @@ -124,11 +125,11 @@ class SparkSQLOperationManager(hiveContext: HiveContext) extends OperationManage } def run(): Unit = { - logger.info(s"Running query '$statement'") + log.info(s"Running query '$statement'") setState(OperationState.RUNNING) try { result = hiveContext.hql(statement) - logger.debug(result.queryExecution.toString()) + log.debug(result.queryExecution.toString()) val groupId = round(random * 1000000).toString hiveContext.sparkContext.setJobGroup(groupId, statement) iter = result.queryExecution.toRdd.toLocalIterator @@ -138,7 +139,7 @@ class SparkSQLOperationManager(hiveContext: HiveContext) extends OperationManage // Actually do need to catch Throwable as some failures don't inherit from Exception and // HiveServer will silently swallow them. case e: Throwable => - logger.error("Error executing query:",e) + log.error("Error executing query:",e) throw new HiveSQLException(e.toString) } setState(OperationState.FINISHED) diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suite.scala index fe3403b3292ec..b7b7c9957ac34 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suite.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suite.scala @@ -27,7 +27,7 @@ import java.sql.{Connection, DriverManager, Statement} import org.scalatest.{BeforeAndAfterAll, FunSuite} -import org.apache.spark.sql.Logging +import org.apache.spark.Logging import org.apache.spark.sql.catalyst.util.getTempFilePath /** diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index 7e3b8727bebed..1f31d35eaa10d 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -207,7 +207,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { } } catch { case e: Exception => - logger.error( + log.error( s""" |====================== |HIVE FAILURE OUTPUT diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index fa4e78439c26c..df3604439e483 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -28,7 +28,8 @@ import org.apache.hadoop.hive.ql.plan.TableDesc import org.apache.hadoop.hive.serde2.Deserializer import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.sql.{SQLContext, Logging} +import org.apache.spark.Logging +import org.apache.spark.sql.SQLContext import org.apache.spark.sql.catalyst.analysis.{EliminateAnalysisOperators, Catalog} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala index c50e8c4b5c5d3..7376fb5dc83f8 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala @@ -148,7 +148,7 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { describedTables ++ logical.collect { case UnresolvedRelation(databaseName, name, _) => name } val referencedTestTables = referencedTables.filter(testTables.contains) - logger.debug(s"Query references test tables: ${referencedTestTables.mkString(", ")}") + log.debug(s"Query references test tables: ${referencedTestTables.mkString(", ")}") referencedTestTables.foreach(loadTestTable) // Proceed with analysis. analyzer(logical) @@ -273,7 +273,7 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { if (!(loadedTables contains name)) { // Marks the table as loaded first to prevent infite mutually recursive table loading. loadedTables += name - logger.info(s"Loading test table $name") + log.info(s"Loading test table $name") val createCmds = testTables.get(name).map(_.commands).getOrElse(sys.error(s"Unknown test table $name")) createCmds.foreach(_()) @@ -312,7 +312,7 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { loadedTables.clear() catalog.client.getAllTables("default").foreach { t => - logger.debug(s"Deleting table $t") + log.debug(s"Deleting table $t") val table = catalog.client.getTable("default", t) catalog.client.getIndexes("default", t, 255).foreach { index => @@ -325,7 +325,7 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { } catalog.client.getAllDatabases.filterNot(_ == "default").foreach { db => - logger.debug(s"Dropping Database: $db") + log.debug(s"Dropping Database: $db") catalog.client.dropDatabase(db, true, false, true) } @@ -347,7 +347,7 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { loadTestTable("srcpart") } catch { case e: Exception => - logger.error(s"FATAL ERROR: Failed to reset TestDB state. $e") + log.error(s"FATAL ERROR: Failed to reset TestDB state. $e") // At this point there is really no reason to continue, but the test framework traps exits. // So instead we just pause forever so that at least the developer can see where things // started to go wrong. diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala index 7582b4743d404..4d8eaa18d7844 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala @@ -25,7 +25,7 @@ import org.apache.hadoop.hive.ql.exec.{FunctionInfo, FunctionRegistry} import org.apache.hadoop.hive.ql.udf.{UDFType => HiveUDFType} import org.apache.hadoop.hive.ql.udf.generic._ -import org.apache.spark.sql.Logging +import org.apache.spark.Logging import org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.types._ @@ -119,7 +119,7 @@ private[hive] case class HiveSimpleUdf(functionClassName: String, children: Seq[ sys.error(s"No matching wrapper found, options: ${argClass.getConstructors.toSeq}.")) (a: Any) => { - logger.debug( + log.debug( s"Wrapping $a of type ${if (a == null) "null" else a.getClass.getName} using $constructor.") // We must make sure that primitives get boxed java style. if (a == null) { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala index 6c8fe4b196dea..52cb1cf986f16 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala @@ -21,7 +21,7 @@ import java.io._ import org.scalatest.{BeforeAndAfterAll, FunSuite, GivenWhenThen} -import org.apache.spark.sql.Logging +import org.apache.spark.Logging import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.logical.{NativeCommand => LogicalNativeCommand} @@ -197,7 +197,7 @@ abstract class HiveComparisonTest // If test sharding is enable, skip tests that are not in the correct shard. shardInfo.foreach { case (shardId, numShards) if testCaseName.hashCode % numShards != shardId => return - case (shardId, _) => logger.debug(s"Shard $shardId includes test '$testCaseName'") + case (shardId, _) => log.debug(s"Shard $shardId includes test '$testCaseName'") } // Skip tests found in directories specified by user. @@ -213,13 +213,13 @@ abstract class HiveComparisonTest .map(new File(_, testCaseName)) .filter(_.exists) if (runOnlyDirectories.nonEmpty && runIndicators.isEmpty) { - logger.debug( + log.debug( s"Skipping test '$testCaseName' not found in ${runOnlyDirectories.map(_.getCanonicalPath)}") return } test(testCaseName) { - logger.debug(s"=== HIVE TEST: $testCaseName ===") + log.debug(s"=== HIVE TEST: $testCaseName ===") // Clear old output for this testcase. outputDirectories.map(new File(_, testCaseName)).filter(_.exists()).foreach(_.delete()) @@ -235,7 +235,7 @@ abstract class HiveComparisonTest .filterNot(_ contains "hive.outerjoin.supports.filters") if (allQueries != queryList) - logger.warn(s"Simplifications made on unsupported operations for test $testCaseName") + log.warn(s"Simplifications made on unsupported operations for test $testCaseName") lazy val consoleTestCase = { val quotes = "\"\"\"" @@ -257,11 +257,11 @@ abstract class HiveComparisonTest } val hiveCachedResults = hiveCacheFiles.flatMap { cachedAnswerFile => - logger.debug(s"Looking for cached answer file $cachedAnswerFile.") + log.debug(s"Looking for cached answer file $cachedAnswerFile.") if (cachedAnswerFile.exists) { Some(fileToString(cachedAnswerFile)) } else { - logger.debug(s"File $cachedAnswerFile not found") + log.debug(s"File $cachedAnswerFile not found") None } }.map { @@ -272,7 +272,7 @@ abstract class HiveComparisonTest val hiveResults: Seq[Seq[String]] = if (hiveCachedResults.size == queryList.size) { - logger.info(s"Using answer cache for test: $testCaseName") + log.info(s"Using answer cache for test: $testCaseName") hiveCachedResults } else { @@ -287,7 +287,7 @@ abstract class HiveComparisonTest if (installHooksCommand.findAllMatchIn(queryString).nonEmpty) sys.error("hive exec hooks not supported for tests.") - logger.warn(s"Running query ${i+1}/${queryList.size} with hive.") + log.warn(s"Running query ${i+1}/${queryList.size} with hive.") // Analyze the query with catalyst to ensure test tables are loaded. val answer = hiveQuery.analyzed match { case _: ExplainCommand => Nil // No need to execute EXPLAIN queries as we don't check the output. @@ -351,7 +351,7 @@ abstract class HiveComparisonTest val resultComparison = sideBySide(hivePrintOut, catalystPrintOut).mkString("\n") if (recomputeCache) { - logger.warn(s"Clearing cache files for failed test $testCaseName") + log.warn(s"Clearing cache files for failed test $testCaseName") hiveCacheFiles.foreach(_.delete()) } @@ -380,7 +380,7 @@ abstract class HiveComparisonTest TestHive.runSqlHive("SELECT key FROM src") } catch { case e: Exception => - logger.error(s"FATAL ERROR: Canary query threw $e This implies that the testing environment has likely been corrupted.") + log.error(s"FATAL ERROR: Canary query threw $e This implies that the testing environment has likely been corrupted.") // The testing setup traps exits so wait here for a long time so the developer can see when things started // to go wrong. Thread.sleep(1000000) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQueryFileTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQueryFileTest.scala index 50ab71a9003d3..9ca5575c1be8a 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQueryFileTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQueryFileTest.scala @@ -53,7 +53,7 @@ abstract class HiveQueryFileTest extends HiveComparisonTest { testCases.sorted.foreach { case (testCaseName, testCaseFile) => if (blackList.map(_.r.pattern.matcher(testCaseName).matches()).reduceLeft(_||_)) { - logger.debug(s"Blacklisted test skipped $testCaseName") + log.debug(s"Blacklisted test skipped $testCaseName") } else if (realWhiteList.map(_.r.pattern.matcher(testCaseName).matches()).reduceLeft(_||_) || runAll) { // Build a test case and submit it to scala test framework... val queriesString = fileToString(testCaseFile) From dab37966b0cfd290919ca5c005f59dde00615c0e Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Fri, 1 Aug 2014 23:55:30 -0700 Subject: [PATCH 109/170] Revert "[SPARK-1470][SPARK-1842] Use the scala-logging wrapper instead of the directly sfl4j api" This reverts commit adc8303294e26efb4ed15e5f5ba1062f7988625d. --- core/pom.xml | 4 - .../main/scala/org/apache/spark/Logging.scala | 39 +++----- .../org/apache/spark/util/SignalLogger.scala | 2 +- mllib/pom.xml | 4 - pom.xml | 5 - project/MimaExcludes.scala | 91 +------------------ sql/catalyst/pom.xml | 5 + .../sql/catalyst/analysis/Analyzer.scala | 4 +- .../catalyst/analysis/HiveTypeCoercion.scala | 8 +- .../catalyst/expressions/BoundAttribute.scala | 2 +- .../codegen/GenerateOrdering.scala | 4 +- .../apache/spark/sql/catalyst/package.scala | 1 + .../sql/catalyst/planning/QueryPlanner.scala | 2 +- .../sql/catalyst/planning/patterns.scala | 6 +- .../spark/sql/catalyst/rules/Rule.scala | 2 +- .../sql/catalyst/rules/RuleExecutor.scala | 12 +-- .../spark/sql/catalyst/trees/package.scala | 8 +- .../org/apache/spark/sql/SQLContext.scala | 2 +- .../CompressibleColumnBuilder.scala | 5 +- .../apache/spark/sql/execution/Exchange.scala | 2 +- .../org/apache/spark/sql/json/JsonRDD.scala | 2 +- .../scala/org/apache/spark/sql/package.scala | 2 + .../spark/sql/columnar/ColumnTypeSuite.scala | 4 +- .../hive/thriftserver/HiveThriftServer2.scala | 12 +-- .../hive/thriftserver/SparkSQLCLIDriver.scala | 2 +- .../hive/thriftserver/SparkSQLDriver.scala | 6 +- .../sql/hive/thriftserver/SparkSQLEnv.scala | 6 +- .../server/SparkSQLOperationManager.scala | 13 ++- .../thriftserver/HiveThriftServer2Suite.scala | 2 +- .../apache/spark/sql/hive/HiveContext.scala | 2 +- .../spark/sql/hive/HiveMetastoreCatalog.scala | 3 +- .../org/apache/spark/sql/hive/TestHive.scala | 10 +- .../org/apache/spark/sql/hive/hiveUdfs.scala | 4 +- .../hive/execution/HiveComparisonTest.scala | 22 ++--- .../hive/execution/HiveQueryFileTest.scala | 2 +- 35 files changed, 97 insertions(+), 203 deletions(-) diff --git a/core/pom.xml b/core/pom.xml index 47766ae5fbb3d..7c60cf10c3dc2 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -98,10 +98,6 @@ org.slf4j jcl-over-slf4j
    - - com.typesafe.scala-logging - scala-logging-slf4j_${scala.binary.version} - log4j log4j diff --git a/core/src/main/scala/org/apache/spark/Logging.scala b/core/src/main/scala/org/apache/spark/Logging.scala index 6e61c00b8dbbf..807ef3e9c9d60 100644 --- a/core/src/main/scala/org/apache/spark/Logging.scala +++ b/core/src/main/scala/org/apache/spark/Logging.scala @@ -18,9 +18,8 @@ package org.apache.spark import org.apache.log4j.{LogManager, PropertyConfigurator} -import org.slf4j.LoggerFactory +import org.slf4j.{Logger, LoggerFactory} import org.slf4j.impl.StaticLoggerBinder -import com.typesafe.scalalogging.slf4j.Logger import org.apache.spark.annotation.DeveloperApi import org.apache.spark.util.Utils @@ -40,69 +39,61 @@ trait Logging { // be serialized and used on another machine @transient private var log_ : Logger = null - // Method to get the logger name for this object - protected def logName = { - var className = this.getClass.getName - // Ignore trailing $'s in the class names for Scala objects - if (className.endsWith("$")) { - className = className.substring(0, className.length - 1) - } - className - } - // Method to get or create the logger for this object protected def log: Logger = { if (log_ == null) { initializeIfNecessary() - log_ = Logger(LoggerFactory.getLogger(logName)) + var className = this.getClass.getName + // Ignore trailing $'s in the class names for Scala objects + log_ = LoggerFactory.getLogger(className.stripSuffix("$")) } log_ } // Log methods that take only a String protected def logInfo(msg: => String) { - log.info(msg) + if (log.isInfoEnabled) log.info(msg) } protected def logDebug(msg: => String) { - log.debug(msg) + if (log.isDebugEnabled) log.debug(msg) } protected def logTrace(msg: => String) { - log.trace(msg) + if (log.isTraceEnabled) log.trace(msg) } protected def logWarning(msg: => String) { - log.warn(msg) + if (log.isWarnEnabled) log.warn(msg) } protected def logError(msg: => String) { - log.error(msg) + if (log.isErrorEnabled) log.error(msg) } // Log methods that take Throwables (Exceptions/Errors) too protected def logInfo(msg: => String, throwable: Throwable) { - log.info(msg, throwable) + if (log.isInfoEnabled) log.info(msg, throwable) } protected def logDebug(msg: => String, throwable: Throwable) { - log.debug(msg, throwable) + if (log.isDebugEnabled) log.debug(msg, throwable) } protected def logTrace(msg: => String, throwable: Throwable) { - log.trace(msg, throwable) + if (log.isTraceEnabled) log.trace(msg, throwable) } protected def logWarning(msg: => String, throwable: Throwable) { - log.warn(msg, throwable) + if (log.isWarnEnabled) log.warn(msg, throwable) } protected def logError(msg: => String, throwable: Throwable) { - log.error(msg, throwable) + if (log.isErrorEnabled) log.error(msg, throwable) } protected def isTraceEnabled(): Boolean = { - log.underlying.isTraceEnabled + log.isTraceEnabled } private def initializeIfNecessary() { diff --git a/core/src/main/scala/org/apache/spark/util/SignalLogger.scala b/core/src/main/scala/org/apache/spark/util/SignalLogger.scala index e84a6b951f65e..f77488ef3d449 100644 --- a/core/src/main/scala/org/apache/spark/util/SignalLogger.scala +++ b/core/src/main/scala/org/apache/spark/util/SignalLogger.scala @@ -18,7 +18,7 @@ package org.apache.spark.util import org.apache.commons.lang3.SystemUtils -import com.typesafe.scalalogging.slf4j.Logger +import org.slf4j.Logger import sun.misc.{Signal, SignalHandler} /** diff --git a/mllib/pom.xml b/mllib/pom.xml index 3007681a44f1c..9a33bd1cf6ad1 100644 --- a/mllib/pom.xml +++ b/mllib/pom.xml @@ -59,10 +59,6 @@ breeze_${scala.binary.version} 0.7 - - com.typesafe - scalalogging-slf4j_${scala.binary.version} - diff --git a/pom.xml b/pom.xml index 9d62cea68995f..ae97bf03c53a2 100644 --- a/pom.xml +++ b/pom.xml @@ -279,11 +279,6 @@ slf4j-log4j12 ${slf4j.version} - - com.typesafe.scala-logging - scala-logging-slf4j_${scala.binary.version} - 2.1.2 - org.slf4j jul-to-slf4j diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index a0cee1d765c7f..537ca0dcf267d 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -103,101 +103,14 @@ object MimaExcludes { ProblemFilters.exclude[IncompatibleMethTypeProblem]( "org.apache.spark.mllib.tree.impurity.Variance.calculate") ) ++ - Seq( // Package-private classes removed in SPARK-2341 + Seq ( // Package-private classes removed in SPARK-2341 ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.mllib.util.BinaryLabelParser"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.mllib.util.BinaryLabelParser$"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.mllib.util.LabelParser"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.mllib.util.LabelParser$"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.mllib.util.MulticlassLabelParser"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.mllib.util.MulticlassLabelParser$") - ) ++ - Seq( - ProblemFilters.exclude[IncompatibleResultTypeProblem] - ("org.apache.spark.bagel.Bagel.log"), - ProblemFilters.exclude[IncompatibleResultTypeProblem] - ("org.apache.spark.streaming.StreamingContext.log"), - ProblemFilters.exclude[IncompatibleResultTypeProblem] - ("org.apache.spark.streaming.dstream.DStream.log"), - ProblemFilters.exclude[IncompatibleResultTypeProblem] - ("org.apache.spark.mllib.recommendation.ALS.log"), - ProblemFilters.exclude[IncompatibleResultTypeProblem] - ("org.apache.spark.mllib.clustering.KMeans.log"), - ProblemFilters.exclude[IncompatibleResultTypeProblem] - ("org.apache.spark.mllib.classification.NaiveBayes.log"), - ProblemFilters.exclude[IncompatibleResultTypeProblem] - ("org.apache.spark.streaming.kafka.KafkaReceiver.log"), - ProblemFilters.exclude[IncompatibleResultTypeProblem] - ("org.apache.spark.SparkContext.log"), - ProblemFilters.exclude[IncompatibleResultTypeProblem] - ("org.apache.spark.rdd.PairRDDFunctions.log"), - ProblemFilters.exclude[IncompatibleResultTypeProblem] - ("org.apache.spark.rdd.OrderedRDDFunctions.log"), - ProblemFilters.exclude[IncompatibleResultTypeProblem] - ("org.apache.spark.rdd.SequenceFileRDDFunctions.log"), - ProblemFilters.exclude[IncompatibleResultTypeProblem] - ("org.apache.spark.rdd.DoubleRDDFunctions.log"), - ProblemFilters.exclude[IncompatibleResultTypeProblem] - ("org.apache.spark.streaming.twitter.TwitterReceiver.log"), - ProblemFilters.exclude[IncompatibleResultTypeProblem] - ("org.apache.spark.streaming.zeromq.ZeroMQReceiver.log"), - ProblemFilters.exclude[IncompatibleResultTypeProblem] - ("org.apache.spark.streaming.flume.FlumeReceiver.log"), - ProblemFilters.exclude[IncompatibleResultTypeProblem] - ("org.apache.spark.rdd.RDD.log"), - ProblemFilters.exclude[IncompatibleResultTypeProblem] - ("org.apache.spark.SparkConf.log"), - - ProblemFilters.exclude[IncompatibleMethTypeProblem] - ("org.apache.spark.SparkConf.org$apache$spark$Logging$$log__="), - ProblemFilters.exclude[IncompatibleMethTypeProblem] - ("org.apache.spark.bagel.Bagel.org$apache$spark$Logging$$log__="), - ProblemFilters.exclude[IncompatibleMethTypeProblem] - ("org.apache.spark.streaming.StreamingContext.org$apache$spark$Logging$$log__="), - ProblemFilters.exclude[IncompatibleMethTypeProblem] - ("org.apache.spark.streaming.dstream.DStream.org$apache$spark$Logging$$log__="), - ProblemFilters.exclude[IncompatibleMethTypeProblem] - ("org.apache.spark.mllib.recommendation.ALS.org$apache$spark$Logging$$log__="), - ProblemFilters.exclude[IncompatibleMethTypeProblem] - ("org.apache.spark.mllib.clustering.KMeans.org$apache$spark$Logging$$log__="), - ProblemFilters.exclude[IncompatibleMethTypeProblem] - ("org.apache.spark.mllib.classification.NaiveBayes.org$apache$spark$Logging$$log__="), - ProblemFilters.exclude[IncompatibleMethTypeProblem] - ("org.apache.spark.streaming.twitter.TwitterReceiver.org$apache$spark$Logging$$log__="), - ProblemFilters.exclude[IncompatibleMethTypeProblem] - ("org.apache.spark.streaming.zeromq.ZeroMQReceiver.org$apache$spark$Logging$$log__="), - ProblemFilters.exclude[IncompatibleMethTypeProblem] - ("org.apache.spark.SparkContext.org$apache$spark$Logging$$log__="), - ProblemFilters.exclude[IncompatibleMethTypeProblem] - ("org.apache.spark.rdd.RDD.org$apache$spark$Logging$$log__="), - ProblemFilters.exclude[IncompatibleMethTypeProblem] - ("org.apache.spark.rdd.SequenceFileRDDFunctions.org$apache$spark$Logging$$log__="), - ProblemFilters.exclude[IncompatibleMethTypeProblem] - ("org.apache.spark.rdd.OrderedRDDFunctions.org$apache$spark$Logging$$log__="), - ProblemFilters.exclude[IncompatibleMethTypeProblem] - ("org.apache.spark.rdd.PairRDDFunctions.org$apache$spark$Logging$$log__="), - ProblemFilters.exclude[IncompatibleMethTypeProblem] - ("org.apache.spark.streaming.kafka.KafkaReceiver.org$apache$spark$Logging$$log__="), - ProblemFilters.exclude[IncompatibleMethTypeProblem] - ("org.apache.spark.rdd.DoubleRDDFunctions.org$apache$spark$Logging$$log__="), - ProblemFilters.exclude[IncompatibleMethTypeProblem] - ("org.apache.spark.streaming.flume.FlumeReceiver.org$apache$spark$Logging$$log__="), - ProblemFilters.exclude[IncompatibleMethTypeProblem] - ("org.apache.spark.streaming.kafka.KafkaReceiver.org$apache$spark$Logging$$log_"), - ProblemFilters.exclude[IncompatibleMethTypeProblem] - ("org.apache.spark.streaming.twitter.TwitterReceiver.org$apache$spark$Logging$$log_"), - ProblemFilters.exclude[IncompatibleResultTypeProblem] - ("org.apache.spark.streaming.twitter.TwitterReceiver.org$apache$spark$Logging$$log_"), - ProblemFilters.exclude[IncompatibleResultTypeProblem] - ("org.apache.spark.streaming.zeromq.ZeroMQReceiver.org$apache$spark$Logging$$log_"), - ProblemFilters.exclude[IncompatibleResultTypeProblem] - ("org.apache.spark.bagel.Bagel.org$apache$spark$Logging$$log_"), - ProblemFilters.exclude[IncompatibleResultTypeProblem] - ("org.apache.spark.bagel.Bagel.org$apache$spark$Logging$$log_"), - ProblemFilters.exclude[IncompatibleResultTypeProblem] - ("org.apache.spark.streaming.flume.FlumeReceiver.org$apache$spark$Logging$$log_"), - ProblemFilters.exclude[IncompatibleResultTypeProblem] - ("org.apache.spark.streaming.kafka.KafkaReceiver.org$apache$spark$Logging$$log_") - ) + ) case v if v.startsWith("1.0") => Seq( MimaBuild.excludeSparkPackage("api.java"), diff --git a/sql/catalyst/pom.xml b/sql/catalyst/pom.xml index 58d44e7923bee..54fa96baa1e18 100644 --- a/sql/catalyst/pom.xml +++ b/sql/catalyst/pom.xml @@ -54,6 +54,11 @@ spark-core_${scala.binary.version} ${project.version} + + com.typesafe + scalalogging-slf4j_${scala.binary.version} + 1.0.1 + org.scalatest scalatest_${scala.binary.version} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 2b36582215f24..74c0104e5b17f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -109,12 +109,12 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool object ResolveReferences extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { case q: LogicalPlan if q.childrenResolved => - log.trace(s"Attempting to resolve ${q.simpleString}") + logger.trace(s"Attempting to resolve ${q.simpleString}") q transformExpressions { case u @ UnresolvedAttribute(name) => // Leave unchanged if resolution fails. Hopefully will be resolved next round. val result = q.resolve(name).getOrElse(u) - log.debug(s"Resolving $u to $result") + logger.debug(s"Resolving $u to $result") result } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index eafbb70dc3fdd..47c7ad076ad07 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -75,7 +75,7 @@ trait HiveTypeCoercion { // Leave the same if the dataTypes match. case Some(newType) if a.dataType == newType.dataType => a case Some(newType) => - log.debug(s"Promoting $a to $newType in ${q.simpleString}}") + logger.debug(s"Promoting $a to $newType in ${q.simpleString}}") newType } } @@ -154,7 +154,7 @@ trait HiveTypeCoercion { (Alias(Cast(l, StringType), l.name)(), r) case (l, r) if l.dataType != r.dataType => - log.debug(s"Resolving mismatched union input ${l.dataType}, ${r.dataType}") + logger.debug(s"Resolving mismatched union input ${l.dataType}, ${r.dataType}") findTightestCommonType(l.dataType, r.dataType).map { widestType => val newLeft = if (l.dataType == widestType) l else Alias(Cast(l, widestType), l.name)() @@ -170,7 +170,7 @@ trait HiveTypeCoercion { val newLeft = if (castedLeft.map(_.dataType) != left.output.map(_.dataType)) { - log.debug(s"Widening numeric types in union $castedLeft ${left.output}") + logger.debug(s"Widening numeric types in union $castedLeft ${left.output}") Project(castedLeft, left) } else { left @@ -178,7 +178,7 @@ trait HiveTypeCoercion { val newRight = if (castedRight.map(_.dataType) != right.output.map(_.dataType)) { - log.debug(s"Widening numeric types in union $castedRight ${right.output}") + logger.debug(s"Widening numeric types in union $castedRight ${right.output}") Project(castedRight, right) } else { right diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala index 0913f15888780..f38f99569f207 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.expressions -import org.apache.spark.Logging +import org.apache.spark.sql.catalyst.Logging import org.apache.spark.sql.catalyst.errors.attachTree import org.apache.spark.sql.catalyst.types._ import org.apache.spark.sql.catalyst.trees diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala index e2552d432cb71..4211998f7511a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.expressions.codegen -import org.apache.spark.Logging +import com.typesafe.scalalogging.slf4j.Logging import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.types.{StringType, NumericType} @@ -92,7 +92,7 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[Row]] wit } new $orderingName() """ - log.debug(s"Generated Ordering: $code") + logger.debug(s"Generated Ordering: $code") toolBox.eval(code).asInstanceOf[Ordering[Row]] } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/package.scala index bdd07bbeb2230..ca9642954eb27 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/package.scala @@ -25,4 +25,5 @@ package object catalyst { */ protected[catalyst] object ScalaReflectionLock + protected[catalyst] type Logging = com.typesafe.scalalogging.slf4j.Logging } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/QueryPlanner.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/QueryPlanner.scala index 5839c9f7c43ef..781ba489b44c6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/QueryPlanner.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/QueryPlanner.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.planning -import org.apache.spark.Logging +import org.apache.spark.sql.catalyst.Logging import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.trees.TreeNode diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala index 06c5ffe92abc8..bc763a4e06e67 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.planning import scala.annotation.tailrec import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.Logging +import org.apache.spark.sql.catalyst.Logging import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ @@ -184,7 +184,7 @@ object ExtractEquiJoinKeys extends Logging with PredicateHelper { def unapply(plan: LogicalPlan): Option[ReturnType] = plan match { case join @ Join(left, right, joinType, condition) => - log.debug(s"Considering join on: $condition") + logger.debug(s"Considering join on: $condition") // Find equi-join predicates that can be evaluated before the join, and thus can be used // as join keys. val (joinPredicates, otherPredicates) = @@ -202,7 +202,7 @@ object ExtractEquiJoinKeys extends Logging with PredicateHelper { val rightKeys = joinKeys.map(_._2) if (joinKeys.nonEmpty) { - log.debug(s"leftKeys:${leftKeys} | rightKeys:${rightKeys}") + logger.debug(s"leftKeys:${leftKeys} | rightKeys:${rightKeys}") Some((joinType, leftKeys, rightKeys, otherPredicates.reduceOption(And), left, right)) } else { None diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/Rule.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/Rule.scala index 03414b2301e81..f8960b3fe7a17 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/Rule.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/Rule.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.rules -import org.apache.spark.Logging +import org.apache.spark.sql.catalyst.Logging import org.apache.spark.sql.catalyst.trees.TreeNode abstract class Rule[TreeType <: TreeNode[_]] extends Logging { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala index 20bf8eed7ddf3..6aa407c836aec 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.rules -import org.apache.spark.Logging +import org.apache.spark.sql.catalyst.Logging import org.apache.spark.sql.catalyst.trees.TreeNode import org.apache.spark.sql.catalyst.util.sideBySide @@ -60,7 +60,7 @@ abstract class RuleExecutor[TreeType <: TreeNode[_]] extends Logging { case (plan, rule) => val result = rule(plan) if (!result.fastEquals(plan)) { - log.trace( + logger.trace( s""" |=== Applying Rule ${rule.ruleName} === |${sideBySide(plan.treeString, result.treeString).mkString("\n")} @@ -73,26 +73,26 @@ abstract class RuleExecutor[TreeType <: TreeNode[_]] extends Logging { if (iteration > batch.strategy.maxIterations) { // Only log if this is a rule that is supposed to run more than once. if (iteration != 2) { - log.info(s"Max iterations (${iteration - 1}) reached for batch ${batch.name}") + logger.info(s"Max iterations (${iteration - 1}) reached for batch ${batch.name}") } continue = false } if (curPlan.fastEquals(lastPlan)) { - log.trace(s"Fixed point reached for batch ${batch.name} after $iteration iterations.") + logger.trace(s"Fixed point reached for batch ${batch.name} after $iteration iterations.") continue = false } lastPlan = curPlan } if (!batchStartPlan.fastEquals(curPlan)) { - log.debug( + logger.debug( s""" |=== Result of Batch ${batch.name} === |${sideBySide(plan.treeString, curPlan.treeString).mkString("\n")} """.stripMargin) } else { - log.trace(s"Batch ${batch.name} has no effect.") + logger.trace(s"Batch ${batch.name} has no effect.") } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/package.scala index d725a92c06f7b..9a28d035a10a3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/package.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql.catalyst -import org.apache.spark.Logging - /** * A library for easily manipulating trees of operators. Operators that extend TreeNode are * granted the following interface: @@ -33,8 +31,8 @@ import org.apache.spark.Logging *
  • debugging support - pretty printing, easy splicing of trees, etc.
  • * */ -package object trees extends Logging { +package object trees { // Since we want tree nodes to be lightweight, we create one logger for all treenode instances. - protected override def logName = "catalyst.trees" - + protected val logger = + com.typesafe.scalalogging.slf4j.Logger(org.slf4j.LoggerFactory.getLogger("catalyst.trees")) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 00dd34aabc389..dad71079c29b9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -36,7 +36,7 @@ import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.SparkStrategies import org.apache.spark.sql.json._ import org.apache.spark.sql.parquet.ParquetRelation -import org.apache.spark.{Logging, SparkContext} +import org.apache.spark.SparkContext /** * :: AlphaComponent :: diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnBuilder.scala index 828a8896ff60a..4c6675c3c87bf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnBuilder.scala @@ -19,8 +19,7 @@ package org.apache.spark.sql.columnar.compression import java.nio.{ByteBuffer, ByteOrder} -import org.apache.spark.Logging -import org.apache.spark.sql.Row +import org.apache.spark.sql.{Logging, Row} import org.apache.spark.sql.catalyst.types.NativeType import org.apache.spark.sql.columnar.{ColumnBuilder, NativeColumnBuilder} @@ -102,7 +101,7 @@ private[sql] trait CompressibleColumnBuilder[T <: NativeType] copyColumnHeader(rawBuffer, compressedBuffer) - log.info(s"Compressor for [$columnName]: $encoder, ratio: ${encoder.compressionRatio}") + logger.info(s"Compressor for [$columnName]: $encoder, ratio: ${encoder.compressionRatio}") encoder.compress(rawBuffer, compressedBuffer, columnType) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala index 0c3d537ccb494..30712f03cab4c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala @@ -101,7 +101,7 @@ private[sql] case class AddExchange(sqlContext: SQLContext) extends Rule[SparkPl !operator.requiredChildDistribution.zip(operator.children).map { case (required, child) => val valid = child.outputPartitioning.satisfies(required) - log.debug( + logger.debug( s"${if (valid) "Valid" else "Invalid"} distribution," + s"required: $required current: ${child.outputPartitioning}") valid diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala index a3d2a1c7a51f8..70db1ebd3a3e1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.analysis.HiveTypeCoercion import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.types._ import org.apache.spark.sql.catalyst.ScalaReflection -import org.apache.spark.Logging +import org.apache.spark.sql.Logging private[sql] object JsonRDD extends Logging { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/package.scala index f513eae9c2d13..0995a4eb6299f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/package.scala @@ -32,6 +32,8 @@ import org.apache.spark.annotation.DeveloperApi */ package object sql { + protected[sql] type Logging = com.typesafe.scalalogging.slf4j.Logging + /** * :: DeveloperApi :: * diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala index a165531573a20..829342215e691 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala @@ -22,7 +22,7 @@ import java.sql.Timestamp import org.scalatest.FunSuite -import org.apache.spark.Logging +import org.apache.spark.sql.Logging import org.apache.spark.sql.catalyst.types._ import org.apache.spark.sql.columnar.ColumnarTestUtils._ import org.apache.spark.sql.execution.SparkSqlSerializer @@ -166,7 +166,7 @@ class ColumnTypeSuite extends FunSuite with Logging { buffer.rewind() seq.foreach { expected => - log.info("buffer = " + buffer + ", expected = " + expected) + logger.info("buffer = " + buffer + ", expected = " + expected) val extracted = columnType.extract(buffer) assert( expected === extracted, diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala index 5959ba3d23f8e..ddbc2a79fb512 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala @@ -25,7 +25,7 @@ import org.apache.hadoop.hive.ql.session.SessionState import org.apache.hive.service.cli.thrift.ThriftBinaryCLIService import org.apache.hive.service.server.{HiveServer2, ServerOptionsProcessor} -import org.apache.spark.Logging +import org.apache.spark.sql.Logging import org.apache.spark.sql.hive.HiveContext import org.apache.spark.sql.hive.thriftserver.ReflectionUtils._ @@ -40,7 +40,7 @@ private[hive] object HiveThriftServer2 extends Logging { val optionsProcessor = new ServerOptionsProcessor("HiveThriftServer2") if (!optionsProcessor.process(args)) { - log.warn("Error starting HiveThriftServer2 with given arguments") + logger.warn("Error starting HiveThriftServer2 with given arguments") System.exit(-1) } @@ -49,12 +49,12 @@ private[hive] object HiveThriftServer2 extends Logging { // Set all properties specified via command line. val hiveConf: HiveConf = ss.getConf hiveConf.getAllProperties.toSeq.sortBy(_._1).foreach { case (k, v) => - log.debug(s"HiveConf var: $k=$v") + logger.debug(s"HiveConf var: $k=$v") } SessionState.start(ss) - log.info("Starting SparkContext") + logger.info("Starting SparkContext") SparkSQLEnv.init() SessionState.start(ss) @@ -70,10 +70,10 @@ private[hive] object HiveThriftServer2 extends Logging { val server = new HiveThriftServer2(SparkSQLEnv.hiveContext) server.init(hiveConf) server.start() - log.info("HiveThriftServer2 started") + logger.info("HiveThriftServer2 started") } catch { case e: Exception => - log.error("Error starting HiveThriftServer2", e) + logger.error("Error starting HiveThriftServer2", e) System.exit(-1) } } diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala index 4d0c506c5a397..cb17d7ce58ea0 100755 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala @@ -37,7 +37,7 @@ import org.apache.hadoop.hive.ql.session.SessionState import org.apache.hadoop.hive.shims.ShimLoader import org.apache.thrift.transport.TSocket -import org.apache.spark.Logging +import org.apache.spark.sql.Logging private[hive] object SparkSQLCLIDriver { private var prompt = "spark-sql" diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala index 276723990b2ad..a56b19a4bcda0 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala @@ -26,7 +26,7 @@ import org.apache.hadoop.hive.metastore.api.{FieldSchema, Schema} import org.apache.hadoop.hive.ql.Driver import org.apache.hadoop.hive.ql.processors.CommandProcessorResponse -import org.apache.spark.Logging +import org.apache.spark.sql.Logging import org.apache.spark.sql.hive.{HiveContext, HiveMetastoreTypes} private[hive] class SparkSQLDriver(val context: HiveContext = SparkSQLEnv.hiveContext) @@ -40,7 +40,7 @@ private[hive] class SparkSQLDriver(val context: HiveContext = SparkSQLEnv.hiveCo private def getResultSetSchema(query: context.QueryExecution): Schema = { val analyzed = query.analyzed - log.debug(s"Result Schema: ${analyzed.output}") + logger.debug(s"Result Schema: ${analyzed.output}") if (analyzed.output.size == 0) { new Schema(new FieldSchema("Response code", "string", "") :: Nil, null) } else { @@ -61,7 +61,7 @@ private[hive] class SparkSQLDriver(val context: HiveContext = SparkSQLEnv.hiveCo new CommandProcessorResponse(0) } catch { case cause: Throwable => - log.error(s"Failed in [$command]", cause) + logger.error(s"Failed in [$command]", cause) new CommandProcessorResponse(-3, ExceptionUtils.getFullStackTrace(cause), null) } } diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala index dfc93b19d019c..451c3bd7b9352 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala @@ -20,13 +20,13 @@ package org.apache.spark.sql.hive.thriftserver import org.apache.hadoop.hive.ql.session.SessionState import org.apache.spark.scheduler.{SplitInfo, StatsReportListener} -import org.apache.spark.Logging +import org.apache.spark.sql.Logging import org.apache.spark.sql.hive.HiveContext import org.apache.spark.{SparkConf, SparkContext} /** A singleton object for the master program. The slaves should not access this. */ private[hive] object SparkSQLEnv extends Logging { - log.debug("Initializing SparkSQLEnv") + logger.debug("Initializing SparkSQLEnv") var hiveContext: HiveContext = _ var sparkContext: SparkContext = _ @@ -47,7 +47,7 @@ private[hive] object SparkSQLEnv extends Logging { /** Cleans up and shuts down the Spark SQL environments. */ def stop() { - log.debug("Shutting down Spark SQL Environment") + logger.debug("Shutting down Spark SQL Environment") // Stop the SparkContext if (SparkSQLEnv.sparkContext != null) { sparkContext.stop() diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala index 2c6e24e80d6dd..a4e1f3e762e89 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala @@ -30,11 +30,10 @@ import org.apache.hive.service.cli._ import org.apache.hive.service.cli.operation.{ExecuteStatementOperation, Operation, OperationManager} import org.apache.hive.service.cli.session.HiveSession -import org.apache.spark.Logging import org.apache.spark.sql.catalyst.types._ import org.apache.spark.sql.hive.thriftserver.ReflectionUtils import org.apache.spark.sql.hive.{HiveContext, HiveMetastoreTypes} -import org.apache.spark.sql.{SchemaRDD, Row => SparkRow} +import org.apache.spark.sql.{Logging, SchemaRDD, Row => SparkRow} /** * Executes queries using Spark SQL, and maintains a list of handles to active queries. @@ -56,7 +55,7 @@ class SparkSQLOperationManager(hiveContext: HiveContext) extends OperationManage def close(): Unit = { // RDDs will be cleaned automatically upon garbage collection. - log.debug("CLOSING") + logger.debug("CLOSING") } def getNextRowSet(order: FetchOrientation, maxRowsL: Long): RowSet = { @@ -113,7 +112,7 @@ class SparkSQLOperationManager(hiveContext: HiveContext) extends OperationManage } def getResultSetSchema: TableSchema = { - log.warn(s"Result Schema: ${result.queryExecution.analyzed.output}") + logger.warn(s"Result Schema: ${result.queryExecution.analyzed.output}") if (result.queryExecution.analyzed.output.size == 0) { new TableSchema(new FieldSchema("Result", "string", "") :: Nil) } else { @@ -125,11 +124,11 @@ class SparkSQLOperationManager(hiveContext: HiveContext) extends OperationManage } def run(): Unit = { - log.info(s"Running query '$statement'") + logger.info(s"Running query '$statement'") setState(OperationState.RUNNING) try { result = hiveContext.hql(statement) - log.debug(result.queryExecution.toString()) + logger.debug(result.queryExecution.toString()) val groupId = round(random * 1000000).toString hiveContext.sparkContext.setJobGroup(groupId, statement) iter = result.queryExecution.toRdd.toLocalIterator @@ -139,7 +138,7 @@ class SparkSQLOperationManager(hiveContext: HiveContext) extends OperationManage // Actually do need to catch Throwable as some failures don't inherit from Exception and // HiveServer will silently swallow them. case e: Throwable => - log.error("Error executing query:",e) + logger.error("Error executing query:",e) throw new HiveSQLException(e.toString) } setState(OperationState.FINISHED) diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suite.scala index b7b7c9957ac34..fe3403b3292ec 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suite.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suite.scala @@ -27,7 +27,7 @@ import java.sql.{Connection, DriverManager, Statement} import org.scalatest.{BeforeAndAfterAll, FunSuite} -import org.apache.spark.Logging +import org.apache.spark.sql.Logging import org.apache.spark.sql.catalyst.util.getTempFilePath /** diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index 1f31d35eaa10d..7e3b8727bebed 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -207,7 +207,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { } } catch { case e: Exception => - log.error( + logger.error( s""" |====================== |HIVE FAILURE OUTPUT diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index df3604439e483..fa4e78439c26c 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -28,8 +28,7 @@ import org.apache.hadoop.hive.ql.plan.TableDesc import org.apache.hadoop.hive.serde2.Deserializer import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.Logging -import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.{SQLContext, Logging} import org.apache.spark.sql.catalyst.analysis.{EliminateAnalysisOperators, Catalog} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala index 7376fb5dc83f8..c50e8c4b5c5d3 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala @@ -148,7 +148,7 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { describedTables ++ logical.collect { case UnresolvedRelation(databaseName, name, _) => name } val referencedTestTables = referencedTables.filter(testTables.contains) - log.debug(s"Query references test tables: ${referencedTestTables.mkString(", ")}") + logger.debug(s"Query references test tables: ${referencedTestTables.mkString(", ")}") referencedTestTables.foreach(loadTestTable) // Proceed with analysis. analyzer(logical) @@ -273,7 +273,7 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { if (!(loadedTables contains name)) { // Marks the table as loaded first to prevent infite mutually recursive table loading. loadedTables += name - log.info(s"Loading test table $name") + logger.info(s"Loading test table $name") val createCmds = testTables.get(name).map(_.commands).getOrElse(sys.error(s"Unknown test table $name")) createCmds.foreach(_()) @@ -312,7 +312,7 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { loadedTables.clear() catalog.client.getAllTables("default").foreach { t => - log.debug(s"Deleting table $t") + logger.debug(s"Deleting table $t") val table = catalog.client.getTable("default", t) catalog.client.getIndexes("default", t, 255).foreach { index => @@ -325,7 +325,7 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { } catalog.client.getAllDatabases.filterNot(_ == "default").foreach { db => - log.debug(s"Dropping Database: $db") + logger.debug(s"Dropping Database: $db") catalog.client.dropDatabase(db, true, false, true) } @@ -347,7 +347,7 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { loadTestTable("srcpart") } catch { case e: Exception => - log.error(s"FATAL ERROR: Failed to reset TestDB state. $e") + logger.error(s"FATAL ERROR: Failed to reset TestDB state. $e") // At this point there is really no reason to continue, but the test framework traps exits. // So instead we just pause forever so that at least the developer can see where things // started to go wrong. diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala index 4d8eaa18d7844..7582b4743d404 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala @@ -25,7 +25,7 @@ import org.apache.hadoop.hive.ql.exec.{FunctionInfo, FunctionRegistry} import org.apache.hadoop.hive.ql.udf.{UDFType => HiveUDFType} import org.apache.hadoop.hive.ql.udf.generic._ -import org.apache.spark.Logging +import org.apache.spark.sql.Logging import org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.types._ @@ -119,7 +119,7 @@ private[hive] case class HiveSimpleUdf(functionClassName: String, children: Seq[ sys.error(s"No matching wrapper found, options: ${argClass.getConstructors.toSeq}.")) (a: Any) => { - log.debug( + logger.debug( s"Wrapping $a of type ${if (a == null) "null" else a.getClass.getName} using $constructor.") // We must make sure that primitives get boxed java style. if (a == null) { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala index 52cb1cf986f16..6c8fe4b196dea 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala @@ -21,7 +21,7 @@ import java.io._ import org.scalatest.{BeforeAndAfterAll, FunSuite, GivenWhenThen} -import org.apache.spark.Logging +import org.apache.spark.sql.Logging import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.logical.{NativeCommand => LogicalNativeCommand} @@ -197,7 +197,7 @@ abstract class HiveComparisonTest // If test sharding is enable, skip tests that are not in the correct shard. shardInfo.foreach { case (shardId, numShards) if testCaseName.hashCode % numShards != shardId => return - case (shardId, _) => log.debug(s"Shard $shardId includes test '$testCaseName'") + case (shardId, _) => logger.debug(s"Shard $shardId includes test '$testCaseName'") } // Skip tests found in directories specified by user. @@ -213,13 +213,13 @@ abstract class HiveComparisonTest .map(new File(_, testCaseName)) .filter(_.exists) if (runOnlyDirectories.nonEmpty && runIndicators.isEmpty) { - log.debug( + logger.debug( s"Skipping test '$testCaseName' not found in ${runOnlyDirectories.map(_.getCanonicalPath)}") return } test(testCaseName) { - log.debug(s"=== HIVE TEST: $testCaseName ===") + logger.debug(s"=== HIVE TEST: $testCaseName ===") // Clear old output for this testcase. outputDirectories.map(new File(_, testCaseName)).filter(_.exists()).foreach(_.delete()) @@ -235,7 +235,7 @@ abstract class HiveComparisonTest .filterNot(_ contains "hive.outerjoin.supports.filters") if (allQueries != queryList) - log.warn(s"Simplifications made on unsupported operations for test $testCaseName") + logger.warn(s"Simplifications made on unsupported operations for test $testCaseName") lazy val consoleTestCase = { val quotes = "\"\"\"" @@ -257,11 +257,11 @@ abstract class HiveComparisonTest } val hiveCachedResults = hiveCacheFiles.flatMap { cachedAnswerFile => - log.debug(s"Looking for cached answer file $cachedAnswerFile.") + logger.debug(s"Looking for cached answer file $cachedAnswerFile.") if (cachedAnswerFile.exists) { Some(fileToString(cachedAnswerFile)) } else { - log.debug(s"File $cachedAnswerFile not found") + logger.debug(s"File $cachedAnswerFile not found") None } }.map { @@ -272,7 +272,7 @@ abstract class HiveComparisonTest val hiveResults: Seq[Seq[String]] = if (hiveCachedResults.size == queryList.size) { - log.info(s"Using answer cache for test: $testCaseName") + logger.info(s"Using answer cache for test: $testCaseName") hiveCachedResults } else { @@ -287,7 +287,7 @@ abstract class HiveComparisonTest if (installHooksCommand.findAllMatchIn(queryString).nonEmpty) sys.error("hive exec hooks not supported for tests.") - log.warn(s"Running query ${i+1}/${queryList.size} with hive.") + logger.warn(s"Running query ${i+1}/${queryList.size} with hive.") // Analyze the query with catalyst to ensure test tables are loaded. val answer = hiveQuery.analyzed match { case _: ExplainCommand => Nil // No need to execute EXPLAIN queries as we don't check the output. @@ -351,7 +351,7 @@ abstract class HiveComparisonTest val resultComparison = sideBySide(hivePrintOut, catalystPrintOut).mkString("\n") if (recomputeCache) { - log.warn(s"Clearing cache files for failed test $testCaseName") + logger.warn(s"Clearing cache files for failed test $testCaseName") hiveCacheFiles.foreach(_.delete()) } @@ -380,7 +380,7 @@ abstract class HiveComparisonTest TestHive.runSqlHive("SELECT key FROM src") } catch { case e: Exception => - log.error(s"FATAL ERROR: Canary query threw $e This implies that the testing environment has likely been corrupted.") + logger.error(s"FATAL ERROR: Canary query threw $e This implies that the testing environment has likely been corrupted.") // The testing setup traps exits so wait here for a long time so the developer can see when things started // to go wrong. Thread.sleep(1000000) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQueryFileTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQueryFileTest.scala index 9ca5575c1be8a..50ab71a9003d3 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQueryFileTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQueryFileTest.scala @@ -53,7 +53,7 @@ abstract class HiveQueryFileTest extends HiveComparisonTest { testCases.sorted.foreach { case (testCaseName, testCaseFile) => if (blackList.map(_.r.pattern.matcher(testCaseName).matches()).reduceLeft(_||_)) { - log.debug(s"Blacklisted test skipped $testCaseName") + logger.debug(s"Blacklisted test skipped $testCaseName") } else if (realWhiteList.map(_.r.pattern.matcher(testCaseName).matches()).reduceLeft(_||_) || runAll) { // Build a test case and submit it to scala test framework... val queriesString = fileToString(testCaseFile) From d934801d53fc2f1d57d3534ae4e1e9384c7dda99 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Fri, 1 Aug 2014 23:56:24 -0700 Subject: [PATCH 110/170] [SPARK-2316] Avoid O(blocks) operations in listeners The existing code in `StorageUtils` is not the most efficient. Every time we want to update an `RDDInfo` we end up iterating through all blocks on all block managers just to discard most of them. The symptoms manifest themselves in the bountiful UI bugs observed in the wild. Many of these bugs are caused by the slow consumption of events in `LiveListenerBus`, which frequently leads to the event queue overflowing and `SparkListenerEvent`s being dropped on the floor. The changes made in this PR avoid this by first filtering out only the blocks relevant to us before computing storage information from them. It's worth a mention that this corner of the Spark code is also not very well-tested at all. The bulk of the changes in this PR (more than 60%) is actually test cases for the various logic in `StorageUtils.scala` as well as `StorageTab.scala`. These will eventually be extended to cover the various listeners that constitute the `SparkUI`. Author: Andrew Or Closes #1679 from andrewor14/fix-drop-events and squashes the following commits: f80c1fa [Andrew Or] Rewrite fold and reduceOption as sum e132d69 [Andrew Or] Merge branch 'master' of github.com:apache/spark into fix-drop-events 14fa1c3 [Andrew Or] Simplify some code + update a few comments a91be46 [Andrew Or] Make ExecutorsPage blazingly fast bf6f09b [Andrew Or] Minor changes 8981de1 [Andrew Or] Merge branch 'master' of github.com:apache/spark into fix-drop-events af19bc0 [Andrew Or] *UsedByRDD -> *UsedByRdd (minor) 6970bc8 [Andrew Or] Add extensive tests for StorageListener and the new code in StorageUtils e080b9e [Andrew Or] Reduce run time of StorageUtils.updateRddInfo to near constant 2c3ef6a [Andrew Or] Actually filter out only the relevant RDDs 6fef86a [Andrew Or] Add extensive tests for new code in StorageStatus b66b6b0 [Andrew Or] Use more efficient underlying data structures for blocks 6a7b7c0 [Andrew Or] Avoid chained operations on TraversableLike a9ec384 [Andrew Or] Merge branch 'master' of github.com:apache/spark into fix-drop-events b12fcd7 [Andrew Or] Fix tests + simplify sc.getRDDStorageInfo da8e322 [Andrew Or] Merge branch 'master' of github.com:apache/spark into fix-drop-events 8e91921 [Andrew Or] Iterate through a filtered set of blocks when updating RDDInfo 7b2c4aa [Andrew Or] Rewrite blockLocationsFromStorageStatus + clean up method signatures 41fa50d [Andrew Or] Add a legacy constructor for StorageStatus 53af15d [Andrew Or] Refactor StorageStatus + add a bunch of tests --- .../scala/org/apache/spark/SparkContext.scala | 6 +- .../storage/BlockManagerMasterActor.scala | 14 +- .../spark/storage/BlockManagerSource.scala | 14 +- .../org/apache/spark/storage/RDDInfo.scala | 2 + .../spark/storage/StorageStatusListener.scala | 12 +- .../apache/spark/storage/StorageUtils.scala | 316 +++++++++++----- .../apache/spark/ui/exec/ExecutorsPage.scala | 12 +- .../org/apache/spark/ui/storage/RDDPage.scala | 17 +- .../apache/spark/ui/storage/StorageTab.scala | 13 +- .../apache/spark/SparkContextInfoSuite.scala | 22 +- .../storage/StorageStatusListenerSuite.scala | 72 ++-- .../apache/spark/storage/StorageSuite.scala | 354 ++++++++++++++++++ .../spark/ui/storage/StorageTabSuite.scala | 165 ++++++++ 13 files changed, 843 insertions(+), 176 deletions(-) create mode 100644 core/src/test/scala/org/apache/spark/storage/StorageSuite.scala create mode 100644 core/src/test/scala/org/apache/spark/ui/storage/StorageTabSuite.scala diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 368835a867493..9ba21cfcde01a 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -48,7 +48,7 @@ import org.apache.spark.scheduler._ import org.apache.spark.scheduler.cluster.{CoarseGrainedSchedulerBackend, SparkDeploySchedulerBackend, SimrSchedulerBackend} import org.apache.spark.scheduler.cluster.mesos.{CoarseMesosSchedulerBackend, MesosSchedulerBackend} import org.apache.spark.scheduler.local.LocalBackend -import org.apache.spark.storage.{BlockManagerSource, RDDInfo, StorageStatus, StorageUtils} +import org.apache.spark.storage._ import org.apache.spark.ui.SparkUI import org.apache.spark.util.{CallSite, ClosureCleaner, MetadataCleaner, MetadataCleanerType, TimeStampedWeakValueHashMap, Utils} @@ -843,7 +843,9 @@ class SparkContext(config: SparkConf) extends Logging { */ @DeveloperApi def getRDDStorageInfo: Array[RDDInfo] = { - StorageUtils.rddInfoFromStorageStatus(getExecutorStorageStatus, this) + val rddInfos = persistentRdds.values.map(RDDInfo.fromRdd).toArray + StorageUtils.updateRddInfo(rddInfos, getExecutorStorageStatus) + rddInfos.filter(_.isCached) } /** diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala index 94f5a4bb2e9cd..bd31e3c5a187f 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala @@ -267,9 +267,8 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus } private def storageStatus: Array[StorageStatus] = { - blockManagerInfo.map { case(blockManagerId, info) => - val blockMap = mutable.Map[BlockId, BlockStatus](info.blocks.toSeq: _*) - new StorageStatus(blockManagerId, info.maxMem, blockMap) + blockManagerInfo.map { case (blockManagerId, info) => + new StorageStatus(blockManagerId, info.maxMem, info.blocks) }.toArray } @@ -424,7 +423,14 @@ case class BlockStatus( storageLevel: StorageLevel, memSize: Long, diskSize: Long, - tachyonSize: Long) + tachyonSize: Long) { + def isCached: Boolean = memSize + diskSize + tachyonSize > 0 +} + +@DeveloperApi +object BlockStatus { + def empty: BlockStatus = BlockStatus(StorageLevel.NONE, 0L, 0L, 0L) +} private[spark] class BlockManagerInfo( val blockManagerId: BlockManagerId, diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerSource.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerSource.scala index 687586490abfe..e939318a029dd 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerSource.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerSource.scala @@ -30,7 +30,7 @@ private[spark] class BlockManagerSource(val blockManager: BlockManager, sc: Spar metricRegistry.register(MetricRegistry.name("memory", "maxMem_MB"), new Gauge[Long] { override def getValue: Long = { val storageStatusList = blockManager.master.getStorageStatus - val maxMem = storageStatusList.map(_.maxMem).reduce(_ + _) + val maxMem = storageStatusList.map(_.maxMem).sum maxMem / 1024 / 1024 } }) @@ -38,7 +38,7 @@ private[spark] class BlockManagerSource(val blockManager: BlockManager, sc: Spar metricRegistry.register(MetricRegistry.name("memory", "remainingMem_MB"), new Gauge[Long] { override def getValue: Long = { val storageStatusList = blockManager.master.getStorageStatus - val remainingMem = storageStatusList.map(_.memRemaining).reduce(_ + _) + val remainingMem = storageStatusList.map(_.memRemaining).sum remainingMem / 1024 / 1024 } }) @@ -46,8 +46,8 @@ private[spark] class BlockManagerSource(val blockManager: BlockManager, sc: Spar metricRegistry.register(MetricRegistry.name("memory", "memUsed_MB"), new Gauge[Long] { override def getValue: Long = { val storageStatusList = blockManager.master.getStorageStatus - val maxMem = storageStatusList.map(_.maxMem).reduce(_ + _) - val remainingMem = storageStatusList.map(_.memRemaining).reduce(_ + _) + val maxMem = storageStatusList.map(_.maxMem).sum + val remainingMem = storageStatusList.map(_.memRemaining).sum (maxMem - remainingMem) / 1024 / 1024 } }) @@ -55,11 +55,7 @@ private[spark] class BlockManagerSource(val blockManager: BlockManager, sc: Spar metricRegistry.register(MetricRegistry.name("disk", "diskSpaceUsed_MB"), new Gauge[Long] { override def getValue: Long = { val storageStatusList = blockManager.master.getStorageStatus - val diskSpaceUsed = storageStatusList - .flatMap(_.blocks.values.map(_.diskSize)) - .reduceOption(_ + _) - .getOrElse(0L) - + val diskSpaceUsed = storageStatusList.map(_.diskUsed).sum diskSpaceUsed / 1024 / 1024 } }) diff --git a/core/src/main/scala/org/apache/spark/storage/RDDInfo.scala b/core/src/main/scala/org/apache/spark/storage/RDDInfo.scala index 5a72e216872a6..120c327a7e580 100644 --- a/core/src/main/scala/org/apache/spark/storage/RDDInfo.scala +++ b/core/src/main/scala/org/apache/spark/storage/RDDInfo.scala @@ -34,6 +34,8 @@ class RDDInfo( var diskSize = 0L var tachyonSize = 0L + def isCached: Boolean = (memSize + diskSize + tachyonSize > 0) && numCachedPartitions > 0 + override def toString = { import Utils.bytesToString ("RDD \"%s\" (%d) StorageLevel: %s; CachedPartitions: %d; TotalPartitions: %d; " + diff --git a/core/src/main/scala/org/apache/spark/storage/StorageStatusListener.scala b/core/src/main/scala/org/apache/spark/storage/StorageStatusListener.scala index 41c960c867e2e..d9066f766476e 100644 --- a/core/src/main/scala/org/apache/spark/storage/StorageStatusListener.scala +++ b/core/src/main/scala/org/apache/spark/storage/StorageStatusListener.scala @@ -35,13 +35,12 @@ class StorageStatusListener extends SparkListener { /** Update storage status list to reflect updated block statuses */ private def updateStorageStatus(execId: String, updatedBlocks: Seq[(BlockId, BlockStatus)]) { - val filteredStatus = executorIdToStorageStatus.get(execId) - filteredStatus.foreach { storageStatus => + executorIdToStorageStatus.get(execId).foreach { storageStatus => updatedBlocks.foreach { case (blockId, updatedStatus) => if (updatedStatus.storageLevel == StorageLevel.NONE) { - storageStatus.blocks.remove(blockId) + storageStatus.removeBlock(blockId) } else { - storageStatus.blocks(blockId) = updatedStatus + storageStatus.updateBlock(blockId, updatedStatus) } } } @@ -50,9 +49,8 @@ class StorageStatusListener extends SparkListener { /** Update storage status list to reflect the removal of an RDD from the cache */ private def updateStorageStatus(unpersistedRDDId: Int) { storageStatusList.foreach { storageStatus => - val unpersistedBlocksIds = storageStatus.rddBlocks.keys.filter(_.rddId == unpersistedRDDId) - unpersistedBlocksIds.foreach { blockId => - storageStatus.blocks.remove(blockId) + storageStatus.rddBlocksById(unpersistedRDDId).foreach { case (blockId, _) => + storageStatus.removeBlock(blockId) } } } diff --git a/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala b/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala index 177281f663367..0a0a448baa2ef 100644 --- a/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala +++ b/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala @@ -20,122 +20,258 @@ package org.apache.spark.storage import scala.collection.Map import scala.collection.mutable -import org.apache.spark.SparkContext import org.apache.spark.annotation.DeveloperApi /** * :: DeveloperApi :: * Storage information for each BlockManager. + * + * This class assumes BlockId and BlockStatus are immutable, such that the consumers of this + * class cannot mutate the source of the information. Accesses are not thread-safe. */ @DeveloperApi -class StorageStatus( - val blockManagerId: BlockManagerId, - val maxMem: Long, - val blocks: mutable.Map[BlockId, BlockStatus] = mutable.Map.empty) { +class StorageStatus(val blockManagerId: BlockManagerId, val maxMem: Long) { - def memUsed = blocks.values.map(_.memSize).reduceOption(_ + _).getOrElse(0L) + /** + * Internal representation of the blocks stored in this block manager. + * + * We store RDD blocks and non-RDD blocks separately to allow quick retrievals of RDD blocks. + * These collections should only be mutated through the add/update/removeBlock methods. + */ + private val _rddBlocks = new mutable.HashMap[Int, mutable.Map[BlockId, BlockStatus]] + private val _nonRddBlocks = new mutable.HashMap[BlockId, BlockStatus] - def memUsedByRDD(rddId: Int) = - rddBlocks.filterKeys(_.rddId == rddId).values.map(_.memSize).reduceOption(_ + _).getOrElse(0L) + /** + * Storage information of the blocks that entails memory, disk, and off-heap memory usage. + * + * As with the block maps, we store the storage information separately for RDD blocks and + * non-RDD blocks for the same reason. In particular, RDD storage information is stored + * in a map indexed by the RDD ID to the following 4-tuple: + * + * (memory size, disk size, off-heap size, storage level) + * + * We assume that all the blocks that belong to the same RDD have the same storage level. + * This field is not relevant to non-RDD blocks, however, so the storage information for + * non-RDD blocks contains only the first 3 fields (in the same order). + */ + private val _rddStorageInfo = new mutable.HashMap[Int, (Long, Long, Long, StorageLevel)] + private var _nonRddStorageInfo: (Long, Long, Long) = (0L, 0L, 0L) - def diskUsed = blocks.values.map(_.diskSize).reduceOption(_ + _).getOrElse(0L) + /** Create a storage status with an initial set of blocks, leaving the source unmodified. */ + def this(bmid: BlockManagerId, maxMem: Long, initialBlocks: Map[BlockId, BlockStatus]) { + this(bmid, maxMem) + initialBlocks.foreach { case (bid, bstatus) => addBlock(bid, bstatus) } + } - def diskUsedByRDD(rddId: Int) = - rddBlocks.filterKeys(_.rddId == rddId).values.map(_.diskSize).reduceOption(_ + _).getOrElse(0L) + /** + * Return the blocks stored in this block manager. + * + * Note that this is somewhat expensive, as it involves cloning the underlying maps and then + * concatenating them together. Much faster alternatives exist for common operations such as + * contains, get, and size. + */ + def blocks: Map[BlockId, BlockStatus] = _nonRddBlocks ++ rddBlocks - def memRemaining: Long = maxMem - memUsed + /** + * Return the RDD blocks stored in this block manager. + * + * Note that this is somewhat expensive, as it involves cloning the underlying maps and then + * concatenating them together. Much faster alternatives exist for common operations such as + * getting the memory, disk, and off-heap memory sizes occupied by this RDD. + */ + def rddBlocks: Map[BlockId, BlockStatus] = _rddBlocks.flatMap { case (_, blocks) => blocks } - def rddBlocks = blocks.collect { case (rdd: RDDBlockId, status) => (rdd, status) } -} + /** Return the blocks that belong to the given RDD stored in this block manager. */ + def rddBlocksById(rddId: Int): Map[BlockId, BlockStatus] = { + _rddBlocks.get(rddId).getOrElse(Map.empty) + } -/** Helper methods for storage-related objects. */ -private[spark] object StorageUtils { + /** Add the given block to this storage status. If it already exists, overwrite it. */ + private[spark] def addBlock(blockId: BlockId, blockStatus: BlockStatus): Unit = { + updateStorageInfo(blockId, blockStatus) + blockId match { + case RDDBlockId(rddId, _) => + _rddBlocks.getOrElseUpdate(rddId, new mutable.HashMap)(blockId) = blockStatus + case _ => + _nonRddBlocks(blockId) = blockStatus + } + } + + /** Update the given block in this storage status. If it doesn't already exist, add it. */ + private[spark] def updateBlock(blockId: BlockId, blockStatus: BlockStatus): Unit = { + addBlock(blockId, blockStatus) + } + + /** Remove the given block from this storage status. */ + private[spark] def removeBlock(blockId: BlockId): Option[BlockStatus] = { + updateStorageInfo(blockId, BlockStatus.empty) + blockId match { + case RDDBlockId(rddId, _) => + // Actually remove the block, if it exists + if (_rddBlocks.contains(rddId)) { + val removed = _rddBlocks(rddId).remove(blockId) + // If the given RDD has no more blocks left, remove the RDD + if (_rddBlocks(rddId).isEmpty) { + _rddBlocks.remove(rddId) + } + removed + } else { + None + } + case _ => + _nonRddBlocks.remove(blockId) + } + } /** - * Returns basic information of all RDDs persisted in the given SparkContext. This does not - * include storage information. + * Return whether the given block is stored in this block manager in O(1) time. + * Note that this is much faster than `this.blocks.contains`, which is O(blocks) time. */ - def rddInfoFromSparkContext(sc: SparkContext): Array[RDDInfo] = { - sc.persistentRdds.values.map { rdd => - val rddName = Option(rdd.name).getOrElse(rdd.id.toString) - val rddNumPartitions = rdd.partitions.size - val rddStorageLevel = rdd.getStorageLevel - val rddInfo = new RDDInfo(rdd.id, rddName, rddNumPartitions, rddStorageLevel) - rddInfo - }.toArray + def containsBlock(blockId: BlockId): Boolean = { + blockId match { + case RDDBlockId(rddId, _) => + _rddBlocks.get(rddId).exists(_.contains(blockId)) + case _ => + _nonRddBlocks.contains(blockId) + } } - /** Returns storage information of all RDDs persisted in the given SparkContext. */ - def rddInfoFromStorageStatus( - storageStatuses: Seq[StorageStatus], - sc: SparkContext): Array[RDDInfo] = { - rddInfoFromStorageStatus(storageStatuses, rddInfoFromSparkContext(sc)) + /** + * Return the given block stored in this block manager in O(1) time. + * Note that this is much faster than `this.blocks.get`, which is O(blocks) time. + */ + def getBlock(blockId: BlockId): Option[BlockStatus] = { + blockId match { + case RDDBlockId(rddId, _) => + _rddBlocks.get(rddId).map(_.get(blockId)).flatten + case _ => + _nonRddBlocks.get(blockId) + } } - /** Returns storage information of all RDDs in the given list. */ - def rddInfoFromStorageStatus( - storageStatuses: Seq[StorageStatus], - rddInfos: Seq[RDDInfo], - updatedBlocks: Seq[(BlockId, BlockStatus)] = Seq.empty): Array[RDDInfo] = { - - // Mapping from a block ID -> its status - val blockMap = mutable.Map(storageStatuses.flatMap(_.rddBlocks): _*) - - // Record updated blocks, if any - updatedBlocks - .collect { case (id: RDDBlockId, status) => (id, status) } - .foreach { case (id, status) => blockMap(id) = status } - - // Mapping from RDD ID -> an array of associated BlockStatuses - val rddBlockMap = blockMap - .groupBy { case (k, _) => k.rddId } - .mapValues(_.values.toArray) - - // Mapping from RDD ID -> the associated RDDInfo (with potentially outdated storage information) - val rddInfoMap = rddInfos.map { info => (info.id, info) }.toMap - - val rddStorageInfos = rddBlockMap.flatMap { case (rddId, blocks) => - // Add up memory, disk and Tachyon sizes - val persistedBlocks = - blocks.filter { status => status.memSize + status.diskSize + status.tachyonSize > 0 } - val _storageLevel = - if (persistedBlocks.length > 0) persistedBlocks(0).storageLevel else StorageLevel.NONE - val memSize = persistedBlocks.map(_.memSize).reduceOption(_ + _).getOrElse(0L) - val diskSize = persistedBlocks.map(_.diskSize).reduceOption(_ + _).getOrElse(0L) - val tachyonSize = persistedBlocks.map(_.tachyonSize).reduceOption(_ + _).getOrElse(0L) - rddInfoMap.get(rddId).map { rddInfo => - rddInfo.storageLevel = _storageLevel - rddInfo.numCachedPartitions = persistedBlocks.length - rddInfo.memSize = memSize - rddInfo.diskSize = diskSize - rddInfo.tachyonSize = tachyonSize - rddInfo - } - }.toArray + /** + * Return the number of blocks stored in this block manager in O(RDDs) time. + * Note that this is much faster than `this.blocks.size`, which is O(blocks) time. + */ + def numBlocks: Int = _nonRddBlocks.size + numRddBlocks + + /** + * Return the number of RDD blocks stored in this block manager in O(RDDs) time. + * Note that this is much faster than `this.rddBlocks.size`, which is O(RDD blocks) time. + */ + def numRddBlocks: Int = _rddBlocks.values.map(_.size).sum - scala.util.Sorting.quickSort(rddStorageInfos) - rddStorageInfos + /** + * Return the number of blocks that belong to the given RDD in O(1) time. + * Note that this is much faster than `this.rddBlocksById(rddId).size`, which is + * O(blocks in this RDD) time. + */ + def numRddBlocksById(rddId: Int): Int = _rddBlocks.get(rddId).map(_.size).getOrElse(0) + + /** Return the memory remaining in this block manager. */ + def memRemaining: Long = maxMem - memUsed + + /** Return the memory used by this block manager. */ + def memUsed: Long = + _nonRddStorageInfo._1 + _rddBlocks.keys.toSeq.map(memUsedByRdd).sum + + /** Return the disk space used by this block manager. */ + def diskUsed: Long = + _nonRddStorageInfo._2 + _rddBlocks.keys.toSeq.map(diskUsedByRdd).sum + + /** Return the off-heap space used by this block manager. */ + def offHeapUsed: Long = + _nonRddStorageInfo._3 + _rddBlocks.keys.toSeq.map(offHeapUsedByRdd).sum + + /** Return the memory used by the given RDD in this block manager in O(1) time. */ + def memUsedByRdd(rddId: Int): Long = _rddStorageInfo.get(rddId).map(_._1).getOrElse(0L) + + /** Return the disk space used by the given RDD in this block manager in O(1) time. */ + def diskUsedByRdd(rddId: Int): Long = _rddStorageInfo.get(rddId).map(_._2).getOrElse(0L) + + /** Return the off-heap space used by the given RDD in this block manager in O(1) time. */ + def offHeapUsedByRdd(rddId: Int): Long = _rddStorageInfo.get(rddId).map(_._3).getOrElse(0L) + + /** Return the storage level, if any, used by the given RDD in this block manager. */ + def rddStorageLevel(rddId: Int): Option[StorageLevel] = _rddStorageInfo.get(rddId).map(_._4) + + /** + * Update the relevant storage info, taking into account any existing status for this block. + */ + private def updateStorageInfo(blockId: BlockId, newBlockStatus: BlockStatus): Unit = { + val oldBlockStatus = getBlock(blockId).getOrElse(BlockStatus.empty) + val changeInMem = newBlockStatus.memSize - oldBlockStatus.memSize + val changeInDisk = newBlockStatus.diskSize - oldBlockStatus.diskSize + val changeInTachyon = newBlockStatus.tachyonSize - oldBlockStatus.tachyonSize + val level = newBlockStatus.storageLevel + + // Compute new info from old info + val (oldMem, oldDisk, oldTachyon) = blockId match { + case RDDBlockId(rddId, _) => + _rddStorageInfo.get(rddId) + .map { case (mem, disk, tachyon, _) => (mem, disk, tachyon) } + .getOrElse((0L, 0L, 0L)) + case _ => + _nonRddStorageInfo + } + val newMem = math.max(oldMem + changeInMem, 0L) + val newDisk = math.max(oldDisk + changeInDisk, 0L) + val newTachyon = math.max(oldTachyon + changeInTachyon, 0L) + + // Set the correct info + blockId match { + case RDDBlockId(rddId, _) => + // If this RDD is no longer persisted, remove it + if (newMem + newDisk + newTachyon == 0) { + _rddStorageInfo.remove(rddId) + } else { + _rddStorageInfo(rddId) = (newMem, newDisk, newTachyon, level) + } + case _ => + _nonRddStorageInfo = (newMem, newDisk, newTachyon) + } } - /** Returns a mapping from BlockId to the locations of the associated block. */ - def blockLocationsFromStorageStatus( - storageStatuses: Seq[StorageStatus]): Map[BlockId, Seq[String]] = { - val blockLocationPairs = storageStatuses.flatMap { storageStatus => - storageStatus.blocks.map { case (bid, _) => (bid, storageStatus.blockManagerId.hostPort) } +} + +/** Helper methods for storage-related objects. */ +private[spark] object StorageUtils { + + /** + * Update the given list of RDDInfo with the given list of storage statuses. + * This method overwrites the old values stored in the RDDInfo's. + */ + def updateRddInfo(rddInfos: Seq[RDDInfo], statuses: Seq[StorageStatus]): Unit = { + rddInfos.foreach { rddInfo => + val rddId = rddInfo.id + // Assume all blocks belonging to the same RDD have the same storage level + val storageLevel = statuses + .map(_.rddStorageLevel(rddId)).flatMap(s => s).headOption.getOrElse(StorageLevel.NONE) + val numCachedPartitions = statuses.map(_.numRddBlocksById(rddId)).sum + val memSize = statuses.map(_.memUsedByRdd(rddId)).sum + val diskSize = statuses.map(_.diskUsedByRdd(rddId)).sum + val tachyonSize = statuses.map(_.offHeapUsedByRdd(rddId)).sum + + rddInfo.storageLevel = storageLevel + rddInfo.numCachedPartitions = numCachedPartitions + rddInfo.memSize = memSize + rddInfo.diskSize = diskSize + rddInfo.tachyonSize = tachyonSize } - blockLocationPairs.toMap - .groupBy { case (blockId, _) => blockId } - .mapValues(_.values.toSeq) } - /** Filters the given list of StorageStatus by the given RDD ID. */ - def filterStorageStatusByRDD( - storageStatuses: Seq[StorageStatus], - rddId: Int): Array[StorageStatus] = { - storageStatuses.map { status => - val filteredBlocks = status.rddBlocks.filterKeys(_.rddId == rddId).toSeq - val filteredBlockMap = mutable.Map[BlockId, BlockStatus](filteredBlocks: _*) - new StorageStatus(status.blockManagerId, status.maxMem, filteredBlockMap) - }.toArray + /** + * Return a mapping from block ID to its locations for each block that belongs to the given RDD. + */ + def getRddBlockLocations(rddId: Int, statuses: Seq[StorageStatus]): Map[BlockId, Seq[String]] = { + val blockLocations = new mutable.HashMap[BlockId, mutable.ListBuffer[String]] + statuses.foreach { status => + status.rddBlocksById(rddId).foreach { case (bid, _) => + val location = status.blockManagerId.hostPort + blockLocations.getOrElseUpdate(bid, mutable.ListBuffer.empty) += location + } + } + blockLocations } + } diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala index b358c855e1c88..b814b0e6b8509 100644 --- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala @@ -49,9 +49,9 @@ private[ui] class ExecutorsPage(parent: ExecutorsTab) extends WebUIPage("") { def render(request: HttpServletRequest): Seq[Node] = { val storageStatusList = listener.storageStatusList - val maxMem = storageStatusList.map(_.maxMem).fold(0L)(_ + _) - val memUsed = storageStatusList.map(_.memUsed).fold(0L)(_ + _) - val diskSpaceUsed = storageStatusList.flatMap(_.blocks.values.map(_.diskSize)).fold(0L)(_ + _) + val maxMem = storageStatusList.map(_.maxMem).sum + val memUsed = storageStatusList.map(_.memUsed).sum + val diskUsed = storageStatusList.map(_.diskUsed).sum val execInfo = for (statusId <- 0 until storageStatusList.size) yield getExecInfo(statusId) val execInfoSorted = execInfo.sortBy(_.id) @@ -80,7 +80,7 @@ private[ui] class ExecutorsPage(parent: ExecutorsTab) extends WebUIPage("") { - {execInfoSorted.map(execRow(_))} + {execInfoSorted.map(execRow)} @@ -91,7 +91,7 @@ private[ui] class ExecutorsPage(parent: ExecutorsTab) extends WebUIPage("") {
  • Memory: {Utils.bytesToString(memUsed)} Used ({Utils.bytesToString(maxMem)} Total)
  • -
  • Disk: {Utils.bytesToString(diskSpaceUsed)} Used
  • +
  • Disk: {Utils.bytesToString(diskUsed)} Used
  • @@ -145,7 +145,7 @@ private[ui] class ExecutorsPage(parent: ExecutorsTab) extends WebUIPage("") { val status = listener.storageStatusList(statusId) val execId = status.blockManagerId.executorId val hostPort = status.blockManagerId.hostPort - val rddBlocks = status.blocks.size + val rddBlocks = status.numBlocks val memUsed = status.memUsed val maxMem = status.maxMem val diskUsed = status.diskUsed diff --git a/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala b/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala index 2155633b8096f..84ac53da47552 100644 --- a/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala @@ -45,12 +45,13 @@ private[ui] class RDDPage(parent: StorageTab) extends WebUIPage("rdd") { val workerTable = UIUtils.listingTable(workerHeader, workerRow, workers) // Block table - val filteredStorageStatusList = StorageUtils.filterStorageStatusByRDD(storageStatusList, rddId) - val blockStatuses = filteredStorageStatusList.flatMap(_.blocks).sortWith(_._1.name < _._1.name) - val blockLocations = StorageUtils.blockLocationsFromStorageStatus(filteredStorageStatusList) - val blocks = blockStatuses.map { case (blockId, status) => - (blockId, status, blockLocations.get(blockId).getOrElse(Seq[String]("Unknown"))) - } + val blockLocations = StorageUtils.getRddBlockLocations(rddId, storageStatusList) + val blocks = storageStatusList + .flatMap(_.rddBlocksById(rddId)) + .sortWith(_._1.name < _._1.name) + .map { case (blockId, status) => + (blockId, status, blockLocations.get(blockId).getOrElse(Seq[String]("Unknown"))) + } val blockTable = UIUtils.listingTable(blockHeader, blockRow, blocks) val content = @@ -119,10 +120,10 @@ private[ui] class RDDPage(parent: StorageTab) extends WebUIPage("rdd") { {status.blockManagerId.host + ":" + status.blockManagerId.port} - {Utils.bytesToString(status.memUsedByRDD(rddId))} + {Utils.bytesToString(status.memUsedByRdd(rddId))} ({Utils.bytesToString(status.memRemaining)} Remaining) - {Utils.bytesToString(status.diskUsedByRDD(rddId))} + {Utils.bytesToString(status.diskUsedByRdd(rddId))} } diff --git a/core/src/main/scala/org/apache/spark/ui/storage/StorageTab.scala b/core/src/main/scala/org/apache/spark/ui/storage/StorageTab.scala index 0cc0cf3117173..5f6740d495521 100644 --- a/core/src/main/scala/org/apache/spark/ui/storage/StorageTab.scala +++ b/core/src/main/scala/org/apache/spark/ui/storage/StorageTab.scala @@ -41,19 +41,18 @@ private[ui] class StorageTab(parent: SparkUI) extends WebUITab(parent, "storage" */ @DeveloperApi class StorageListener(storageStatusListener: StorageStatusListener) extends SparkListener { - private val _rddInfoMap = mutable.Map[Int, RDDInfo]() + private[ui] val _rddInfoMap = mutable.Map[Int, RDDInfo]() // exposed for testing def storageStatusList = storageStatusListener.storageStatusList /** Filter RDD info to include only those with cached partitions */ def rddInfoList = _rddInfoMap.values.filter(_.numCachedPartitions > 0).toSeq - /** Update each RDD's info to reflect any updates to the RDD's storage status */ - private def updateRDDInfo(updatedBlocks: Seq[(BlockId, BlockStatus)] = Seq.empty) { - val rddInfos = _rddInfoMap.values.toSeq - val updatedRddInfos = - StorageUtils.rddInfoFromStorageStatus(storageStatusList, rddInfos, updatedBlocks) - updatedRddInfos.foreach { info => _rddInfoMap(info.id) = info } + /** Update the storage info of the RDDs whose blocks are among the given updated blocks */ + private def updateRDDInfo(updatedBlocks: Seq[(BlockId, BlockStatus)]): Unit = { + val rddIdsToUpdate = updatedBlocks.flatMap { case (bid, _) => bid.asRDDId.map(_.rddId) }.toSet + val rddInfosToUpdate = _rddInfoMap.values.toSeq.filter { s => rddIdsToUpdate.contains(s.id) } + StorageUtils.updateRddInfo(rddInfosToUpdate, storageStatusList) } /** diff --git a/core/src/test/scala/org/apache/spark/SparkContextInfoSuite.scala b/core/src/test/scala/org/apache/spark/SparkContextInfoSuite.scala index fb18c3ebfe46f..e6ab538d77bcc 100644 --- a/core/src/test/scala/org/apache/spark/SparkContextInfoSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkContextInfoSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark import org.scalatest.{Assertions, FunSuite} +import org.apache.spark.storage.StorageLevel class SparkContextInfoSuite extends FunSuite with LocalSparkContext { test("getPersistentRDDs only returns RDDs that are marked as cached") { @@ -35,26 +36,33 @@ class SparkContextInfoSuite extends FunSuite with LocalSparkContext { test("getPersistentRDDs returns an immutable map") { sc = new SparkContext("local", "test") val rdd1 = sc.makeRDD(Array(1, 2, 3, 4), 2).cache() - val myRdds = sc.getPersistentRDDs assert(myRdds.size === 1) - assert(myRdds.values.head === rdd1) + assert(myRdds(0) === rdd1) + assert(myRdds(0).getStorageLevel === StorageLevel.MEMORY_ONLY) + // myRdds2 should have 2 RDDs, but myRdds should not change val rdd2 = sc.makeRDD(Array(5, 6, 7, 8), 1).cache() - - // getPersistentRDDs should have 2 RDDs, but myRdds should not change - assert(sc.getPersistentRDDs.size === 2) + val myRdds2 = sc.getPersistentRDDs + assert(myRdds2.size === 2) + assert(myRdds2(0) === rdd1) + assert(myRdds2(1) === rdd2) + assert(myRdds2(0).getStorageLevel === StorageLevel.MEMORY_ONLY) + assert(myRdds2(1).getStorageLevel === StorageLevel.MEMORY_ONLY) assert(myRdds.size === 1) + assert(myRdds(0) === rdd1) + assert(myRdds(0).getStorageLevel === StorageLevel.MEMORY_ONLY) } test("getRDDStorageInfo only reports on RDDs that actually persist data") { sc = new SparkContext("local", "test") val rdd = sc.makeRDD(Array(1, 2, 3, 4), 2).cache() - assert(sc.getRDDStorageInfo.size === 0) - rdd.collect() assert(sc.getRDDStorageInfo.size === 1) + assert(sc.getRDDStorageInfo.head.isCached) + assert(sc.getRDDStorageInfo.head.memSize > 0) + assert(sc.getRDDStorageInfo.head.storageLevel === StorageLevel.MEMORY_ONLY) } test("call sites report correct locations") { diff --git a/core/src/test/scala/org/apache/spark/storage/StorageStatusListenerSuite.scala b/core/src/test/scala/org/apache/spark/storage/StorageStatusListenerSuite.scala index 2179c6dd3302e..51fb646a3cb61 100644 --- a/core/src/test/scala/org/apache/spark/storage/StorageStatusListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/StorageStatusListenerSuite.scala @@ -41,13 +41,13 @@ class StorageStatusListenerSuite extends FunSuite { assert(listener.executorIdToStorageStatus.get("big").isDefined) assert(listener.executorIdToStorageStatus("big").blockManagerId === bm1) assert(listener.executorIdToStorageStatus("big").maxMem === 1000L) - assert(listener.executorIdToStorageStatus("big").blocks.isEmpty) + assert(listener.executorIdToStorageStatus("big").numBlocks === 0) listener.onBlockManagerAdded(SparkListenerBlockManagerAdded(bm2, 2000L)) assert(listener.executorIdToStorageStatus.size === 2) assert(listener.executorIdToStorageStatus.get("fat").isDefined) assert(listener.executorIdToStorageStatus("fat").blockManagerId === bm2) assert(listener.executorIdToStorageStatus("fat").maxMem === 2000L) - assert(listener.executorIdToStorageStatus("fat").blocks.isEmpty) + assert(listener.executorIdToStorageStatus("fat").numBlocks === 0) // Block manager remove listener.onBlockManagerRemoved(SparkListenerBlockManagerRemoved(bm1)) @@ -67,14 +67,14 @@ class StorageStatusListenerSuite extends FunSuite { val taskMetrics = new TaskMetrics // Task end with no updated blocks - assert(listener.executorIdToStorageStatus("big").blocks.isEmpty) - assert(listener.executorIdToStorageStatus("fat").blocks.isEmpty) + assert(listener.executorIdToStorageStatus("big").numBlocks === 0) + assert(listener.executorIdToStorageStatus("fat").numBlocks === 0) listener.onTaskEnd(SparkListenerTaskEnd(1, "obliteration", Success, taskInfo1, taskMetrics)) - assert(listener.executorIdToStorageStatus("big").blocks.isEmpty) - assert(listener.executorIdToStorageStatus("fat").blocks.isEmpty) + assert(listener.executorIdToStorageStatus("big").numBlocks === 0) + assert(listener.executorIdToStorageStatus("fat").numBlocks === 0) listener.onTaskEnd(SparkListenerTaskEnd(1, "obliteration", Success, taskInfo2, taskMetrics)) - assert(listener.executorIdToStorageStatus("big").blocks.isEmpty) - assert(listener.executorIdToStorageStatus("fat").blocks.isEmpty) + assert(listener.executorIdToStorageStatus("big").numBlocks === 0) + assert(listener.executorIdToStorageStatus("fat").numBlocks === 0) } test("task end with updated blocks") { @@ -90,20 +90,20 @@ class StorageStatusListenerSuite extends FunSuite { taskMetrics2.updatedBlocks = Some(Seq(block3)) // Task end with new blocks - assert(listener.executorIdToStorageStatus("big").blocks.isEmpty) - assert(listener.executorIdToStorageStatus("fat").blocks.isEmpty) + assert(listener.executorIdToStorageStatus("big").numBlocks === 0) + assert(listener.executorIdToStorageStatus("fat").numBlocks === 0) listener.onTaskEnd(SparkListenerTaskEnd(1, "obliteration", Success, taskInfo1, taskMetrics1)) - assert(listener.executorIdToStorageStatus("big").blocks.size === 2) - assert(listener.executorIdToStorageStatus("fat").blocks.size === 0) - assert(listener.executorIdToStorageStatus("big").blocks.contains(RDDBlockId(1, 1))) - assert(listener.executorIdToStorageStatus("big").blocks.contains(RDDBlockId(1, 2))) - assert(listener.executorIdToStorageStatus("fat").blocks.isEmpty) + assert(listener.executorIdToStorageStatus("big").numBlocks === 2) + assert(listener.executorIdToStorageStatus("fat").numBlocks === 0) + assert(listener.executorIdToStorageStatus("big").containsBlock(RDDBlockId(1, 1))) + assert(listener.executorIdToStorageStatus("big").containsBlock(RDDBlockId(1, 2))) + assert(listener.executorIdToStorageStatus("fat").numBlocks === 0) listener.onTaskEnd(SparkListenerTaskEnd(1, "obliteration", Success, taskInfo2, taskMetrics2)) - assert(listener.executorIdToStorageStatus("big").blocks.size === 2) - assert(listener.executorIdToStorageStatus("fat").blocks.size === 1) - assert(listener.executorIdToStorageStatus("big").blocks.contains(RDDBlockId(1, 1))) - assert(listener.executorIdToStorageStatus("big").blocks.contains(RDDBlockId(1, 2))) - assert(listener.executorIdToStorageStatus("fat").blocks.contains(RDDBlockId(4, 0))) + assert(listener.executorIdToStorageStatus("big").numBlocks === 2) + assert(listener.executorIdToStorageStatus("fat").numBlocks === 1) + assert(listener.executorIdToStorageStatus("big").containsBlock(RDDBlockId(1, 1))) + assert(listener.executorIdToStorageStatus("big").containsBlock(RDDBlockId(1, 2))) + assert(listener.executorIdToStorageStatus("fat").containsBlock(RDDBlockId(4, 0))) // Task end with dropped blocks val droppedBlock1 = (RDDBlockId(1, 1), BlockStatus(StorageLevel.NONE, 0L, 0L, 0L)) @@ -112,17 +112,17 @@ class StorageStatusListenerSuite extends FunSuite { taskMetrics1.updatedBlocks = Some(Seq(droppedBlock1, droppedBlock3)) taskMetrics2.updatedBlocks = Some(Seq(droppedBlock2, droppedBlock3)) listener.onTaskEnd(SparkListenerTaskEnd(1, "obliteration", Success, taskInfo1, taskMetrics1)) - assert(listener.executorIdToStorageStatus("big").blocks.size === 1) - assert(listener.executorIdToStorageStatus("fat").blocks.size === 1) - assert(!listener.executorIdToStorageStatus("big").blocks.contains(RDDBlockId(1, 1))) - assert(listener.executorIdToStorageStatus("big").blocks.contains(RDDBlockId(1, 2))) - assert(listener.executorIdToStorageStatus("fat").blocks.contains(RDDBlockId(4, 0))) + assert(listener.executorIdToStorageStatus("big").numBlocks === 1) + assert(listener.executorIdToStorageStatus("fat").numBlocks === 1) + assert(!listener.executorIdToStorageStatus("big").containsBlock(RDDBlockId(1, 1))) + assert(listener.executorIdToStorageStatus("big").containsBlock(RDDBlockId(1, 2))) + assert(listener.executorIdToStorageStatus("fat").containsBlock(RDDBlockId(4, 0))) listener.onTaskEnd(SparkListenerTaskEnd(1, "obliteration", Success, taskInfo2, taskMetrics2)) - assert(listener.executorIdToStorageStatus("big").blocks.size === 1) - assert(listener.executorIdToStorageStatus("fat").blocks.size === 0) - assert(!listener.executorIdToStorageStatus("big").blocks.contains(RDDBlockId(1, 1))) - assert(listener.executorIdToStorageStatus("big").blocks.contains(RDDBlockId(1, 2))) - assert(listener.executorIdToStorageStatus("fat").blocks.isEmpty) + assert(listener.executorIdToStorageStatus("big").numBlocks === 1) + assert(listener.executorIdToStorageStatus("fat").numBlocks === 0) + assert(!listener.executorIdToStorageStatus("big").containsBlock(RDDBlockId(1, 1))) + assert(listener.executorIdToStorageStatus("big").containsBlock(RDDBlockId(1, 2))) + assert(listener.executorIdToStorageStatus("fat").numBlocks === 0) } test("unpersist RDD") { @@ -137,16 +137,16 @@ class StorageStatusListenerSuite extends FunSuite { taskMetrics2.updatedBlocks = Some(Seq(block3)) listener.onTaskEnd(SparkListenerTaskEnd(1, "obliteration", Success, taskInfo1, taskMetrics1)) listener.onTaskEnd(SparkListenerTaskEnd(1, "obliteration", Success, taskInfo1, taskMetrics2)) - assert(listener.executorIdToStorageStatus("big").blocks.size === 3) + assert(listener.executorIdToStorageStatus("big").numBlocks === 3) // Unpersist RDD listener.onUnpersistRDD(SparkListenerUnpersistRDD(9090)) - assert(listener.executorIdToStorageStatus("big").blocks.size === 3) + assert(listener.executorIdToStorageStatus("big").numBlocks === 3) listener.onUnpersistRDD(SparkListenerUnpersistRDD(4)) - assert(listener.executorIdToStorageStatus("big").blocks.size === 2) - assert(listener.executorIdToStorageStatus("big").blocks.contains(RDDBlockId(1, 1))) - assert(listener.executorIdToStorageStatus("big").blocks.contains(RDDBlockId(1, 2))) + assert(listener.executorIdToStorageStatus("big").numBlocks === 2) + assert(listener.executorIdToStorageStatus("big").containsBlock(RDDBlockId(1, 1))) + assert(listener.executorIdToStorageStatus("big").containsBlock(RDDBlockId(1, 2))) listener.onUnpersistRDD(SparkListenerUnpersistRDD(1)) - assert(listener.executorIdToStorageStatus("big").blocks.isEmpty) + assert(listener.executorIdToStorageStatus("big").numBlocks === 0) } } diff --git a/core/src/test/scala/org/apache/spark/storage/StorageSuite.scala b/core/src/test/scala/org/apache/spark/storage/StorageSuite.scala new file mode 100644 index 0000000000000..38678bbd1dd28 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/storage/StorageSuite.scala @@ -0,0 +1,354 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.storage + +import org.scalatest.FunSuite + +/** + * Test various functionalities in StorageUtils and StorageStatus. + */ +class StorageSuite extends FunSuite { + private val memAndDisk = StorageLevel.MEMORY_AND_DISK + + // For testing add, update, and remove (for non-RDD blocks) + private def storageStatus1: StorageStatus = { + val status = new StorageStatus(BlockManagerId("big", "dog", 1, 1), 1000L) + assert(status.blocks.isEmpty) + assert(status.rddBlocks.isEmpty) + assert(status.memUsed === 0L) + assert(status.memRemaining === 1000L) + assert(status.diskUsed === 0L) + assert(status.offHeapUsed === 0L) + status.addBlock(TestBlockId("foo"), BlockStatus(memAndDisk, 10L, 20L, 1L)) + status.addBlock(TestBlockId("fee"), BlockStatus(memAndDisk, 10L, 20L, 1L)) + status.addBlock(TestBlockId("faa"), BlockStatus(memAndDisk, 10L, 20L, 1L)) + status + } + + test("storage status add non-RDD blocks") { + val status = storageStatus1 + assert(status.blocks.size === 3) + assert(status.blocks.contains(TestBlockId("foo"))) + assert(status.blocks.contains(TestBlockId("fee"))) + assert(status.blocks.contains(TestBlockId("faa"))) + assert(status.rddBlocks.isEmpty) + assert(status.memUsed === 30L) + assert(status.memRemaining === 970L) + assert(status.diskUsed === 60L) + assert(status.offHeapUsed === 3L) + } + + test("storage status update non-RDD blocks") { + val status = storageStatus1 + status.updateBlock(TestBlockId("foo"), BlockStatus(memAndDisk, 50L, 100L, 1L)) + status.updateBlock(TestBlockId("fee"), BlockStatus(memAndDisk, 100L, 20L, 0L)) + assert(status.blocks.size === 3) + assert(status.memUsed === 160L) + assert(status.memRemaining === 840L) + assert(status.diskUsed === 140L) + assert(status.offHeapUsed === 2L) + } + + test("storage status remove non-RDD blocks") { + val status = storageStatus1 + status.removeBlock(TestBlockId("foo")) + status.removeBlock(TestBlockId("faa")) + assert(status.blocks.size === 1) + assert(status.blocks.contains(TestBlockId("fee"))) + assert(status.memUsed === 10L) + assert(status.memRemaining === 990L) + assert(status.diskUsed === 20L) + assert(status.offHeapUsed === 1L) + } + + // For testing add, update, remove, get, and contains etc. for both RDD and non-RDD blocks + private def storageStatus2: StorageStatus = { + val status = new StorageStatus(BlockManagerId("big", "dog", 1, 1), 1000L) + assert(status.rddBlocks.isEmpty) + status.addBlock(TestBlockId("dan"), BlockStatus(memAndDisk, 10L, 20L, 0L)) + status.addBlock(TestBlockId("man"), BlockStatus(memAndDisk, 10L, 20L, 0L)) + status.addBlock(RDDBlockId(0, 0), BlockStatus(memAndDisk, 10L, 20L, 1L)) + status.addBlock(RDDBlockId(1, 1), BlockStatus(memAndDisk, 100L, 200L, 1L)) + status.addBlock(RDDBlockId(2, 2), BlockStatus(memAndDisk, 10L, 20L, 1L)) + status.addBlock(RDDBlockId(2, 3), BlockStatus(memAndDisk, 10L, 20L, 0L)) + status.addBlock(RDDBlockId(2, 4), BlockStatus(memAndDisk, 10L, 40L, 0L)) + status + } + + test("storage status add RDD blocks") { + val status = storageStatus2 + assert(status.blocks.size === 7) + assert(status.rddBlocks.size === 5) + assert(status.rddBlocks.contains(RDDBlockId(0, 0))) + assert(status.rddBlocks.contains(RDDBlockId(1, 1))) + assert(status.rddBlocks.contains(RDDBlockId(2, 2))) + assert(status.rddBlocks.contains(RDDBlockId(2, 3))) + assert(status.rddBlocks.contains(RDDBlockId(2, 4))) + assert(status.rddBlocksById(0).size === 1) + assert(status.rddBlocksById(0).contains(RDDBlockId(0, 0))) + assert(status.rddBlocksById(1).size === 1) + assert(status.rddBlocksById(1).contains(RDDBlockId(1, 1))) + assert(status.rddBlocksById(2).size === 3) + assert(status.rddBlocksById(2).contains(RDDBlockId(2, 2))) + assert(status.rddBlocksById(2).contains(RDDBlockId(2, 3))) + assert(status.rddBlocksById(2).contains(RDDBlockId(2, 4))) + assert(status.memUsedByRdd(0) === 10L) + assert(status.memUsedByRdd(1) === 100L) + assert(status.memUsedByRdd(2) === 30L) + assert(status.diskUsedByRdd(0) === 20L) + assert(status.diskUsedByRdd(1) === 200L) + assert(status.diskUsedByRdd(2) === 80L) + assert(status.offHeapUsedByRdd(0) === 1L) + assert(status.offHeapUsedByRdd(1) === 1L) + assert(status.offHeapUsedByRdd(2) === 1L) + assert(status.rddStorageLevel(0) === Some(memAndDisk)) + assert(status.rddStorageLevel(1) === Some(memAndDisk)) + assert(status.rddStorageLevel(2) === Some(memAndDisk)) + + // Verify default values for RDDs that don't exist + assert(status.rddBlocksById(10).isEmpty) + assert(status.memUsedByRdd(10) === 0L) + assert(status.diskUsedByRdd(10) === 0L) + assert(status.offHeapUsedByRdd(10) === 0L) + assert(status.rddStorageLevel(10) === None) + } + + test("storage status update RDD blocks") { + val status = storageStatus2 + status.updateBlock(TestBlockId("dan"), BlockStatus(memAndDisk, 5000L, 0L, 0L)) + status.updateBlock(RDDBlockId(0, 0), BlockStatus(memAndDisk, 0L, 0L, 0L)) + status.updateBlock(RDDBlockId(2, 2), BlockStatus(memAndDisk, 0L, 1000L, 0L)) + assert(status.blocks.size === 7) + assert(status.rddBlocks.size === 5) + assert(status.rddBlocksById(0).size === 1) + assert(status.rddBlocksById(1).size === 1) + assert(status.rddBlocksById(2).size === 3) + assert(status.memUsedByRdd(0) === 0L) + assert(status.memUsedByRdd(1) === 100L) + assert(status.memUsedByRdd(2) === 20L) + assert(status.diskUsedByRdd(0) === 0L) + assert(status.diskUsedByRdd(1) === 200L) + assert(status.diskUsedByRdd(2) === 1060L) + assert(status.offHeapUsedByRdd(0) === 0L) + assert(status.offHeapUsedByRdd(1) === 1L) + assert(status.offHeapUsedByRdd(2) === 0L) + } + + test("storage status remove RDD blocks") { + val status = storageStatus2 + status.removeBlock(TestBlockId("man")) + status.removeBlock(RDDBlockId(1, 1)) + status.removeBlock(RDDBlockId(2, 2)) + status.removeBlock(RDDBlockId(2, 4)) + assert(status.blocks.size === 3) + assert(status.rddBlocks.size === 2) + assert(status.rddBlocks.contains(RDDBlockId(0, 0))) + assert(status.rddBlocks.contains(RDDBlockId(2, 3))) + assert(status.rddBlocksById(0).size === 1) + assert(status.rddBlocksById(0).contains(RDDBlockId(0, 0))) + assert(status.rddBlocksById(1).size === 0) + assert(status.rddBlocksById(2).size === 1) + assert(status.rddBlocksById(2).contains(RDDBlockId(2, 3))) + assert(status.memUsedByRdd(0) === 10L) + assert(status.memUsedByRdd(1) === 0L) + assert(status.memUsedByRdd(2) === 10L) + assert(status.diskUsedByRdd(0) === 20L) + assert(status.diskUsedByRdd(1) === 0L) + assert(status.diskUsedByRdd(2) === 20L) + assert(status.offHeapUsedByRdd(0) === 1L) + assert(status.offHeapUsedByRdd(1) === 0L) + assert(status.offHeapUsedByRdd(2) === 0L) + } + + test("storage status containsBlock") { + val status = storageStatus2 + // blocks that actually exist + assert(status.blocks.contains(TestBlockId("dan")) === status.containsBlock(TestBlockId("dan"))) + assert(status.blocks.contains(TestBlockId("man")) === status.containsBlock(TestBlockId("man"))) + assert(status.blocks.contains(RDDBlockId(0, 0)) === status.containsBlock(RDDBlockId(0, 0))) + assert(status.blocks.contains(RDDBlockId(1, 1)) === status.containsBlock(RDDBlockId(1, 1))) + assert(status.blocks.contains(RDDBlockId(2, 2)) === status.containsBlock(RDDBlockId(2, 2))) + assert(status.blocks.contains(RDDBlockId(2, 3)) === status.containsBlock(RDDBlockId(2, 3))) + assert(status.blocks.contains(RDDBlockId(2, 4)) === status.containsBlock(RDDBlockId(2, 4))) + // blocks that don't exist + assert(status.blocks.contains(TestBlockId("fan")) === status.containsBlock(TestBlockId("fan"))) + assert(status.blocks.contains(RDDBlockId(100, 0)) === status.containsBlock(RDDBlockId(100, 0))) + } + + test("storage status getBlock") { + val status = storageStatus2 + // blocks that actually exist + assert(status.blocks.get(TestBlockId("dan")) === status.getBlock(TestBlockId("dan"))) + assert(status.blocks.get(TestBlockId("man")) === status.getBlock(TestBlockId("man"))) + assert(status.blocks.get(RDDBlockId(0, 0)) === status.getBlock(RDDBlockId(0, 0))) + assert(status.blocks.get(RDDBlockId(1, 1)) === status.getBlock(RDDBlockId(1, 1))) + assert(status.blocks.get(RDDBlockId(2, 2)) === status.getBlock(RDDBlockId(2, 2))) + assert(status.blocks.get(RDDBlockId(2, 3)) === status.getBlock(RDDBlockId(2, 3))) + assert(status.blocks.get(RDDBlockId(2, 4)) === status.getBlock(RDDBlockId(2, 4))) + // blocks that don't exist + assert(status.blocks.get(TestBlockId("fan")) === status.getBlock(TestBlockId("fan"))) + assert(status.blocks.get(RDDBlockId(100, 0)) === status.getBlock(RDDBlockId(100, 0))) + } + + test("storage status num[Rdd]Blocks") { + val status = storageStatus2 + assert(status.blocks.size === status.numBlocks) + assert(status.rddBlocks.size === status.numRddBlocks) + status.addBlock(TestBlockId("Foo"), BlockStatus(memAndDisk, 0L, 0L, 100L)) + status.addBlock(RDDBlockId(4, 4), BlockStatus(memAndDisk, 0L, 0L, 100L)) + status.addBlock(RDDBlockId(4, 8), BlockStatus(memAndDisk, 0L, 0L, 100L)) + assert(status.blocks.size === status.numBlocks) + assert(status.rddBlocks.size === status.numRddBlocks) + assert(status.rddBlocksById(4).size === status.numRddBlocksById(4)) + assert(status.rddBlocksById(10).size === status.numRddBlocksById(10)) + status.updateBlock(TestBlockId("Foo"), BlockStatus(memAndDisk, 0L, 10L, 400L)) + status.updateBlock(RDDBlockId(4, 0), BlockStatus(memAndDisk, 0L, 0L, 100L)) + status.updateBlock(RDDBlockId(4, 8), BlockStatus(memAndDisk, 0L, 0L, 100L)) + status.updateBlock(RDDBlockId(10, 10), BlockStatus(memAndDisk, 0L, 0L, 100L)) + assert(status.blocks.size === status.numBlocks) + assert(status.rddBlocks.size === status.numRddBlocks) + assert(status.rddBlocksById(4).size === status.numRddBlocksById(4)) + assert(status.rddBlocksById(10).size === status.numRddBlocksById(10)) + assert(status.rddBlocksById(100).size === status.numRddBlocksById(100)) + status.removeBlock(RDDBlockId(4, 0)) + status.removeBlock(RDDBlockId(10, 10)) + assert(status.blocks.size === status.numBlocks) + assert(status.rddBlocks.size === status.numRddBlocks) + assert(status.rddBlocksById(4).size === status.numRddBlocksById(4)) + assert(status.rddBlocksById(10).size === status.numRddBlocksById(10)) + // remove a block that doesn't exist + status.removeBlock(RDDBlockId(1000, 999)) + assert(status.blocks.size === status.numBlocks) + assert(status.rddBlocks.size === status.numRddBlocks) + assert(status.rddBlocksById(4).size === status.numRddBlocksById(4)) + assert(status.rddBlocksById(10).size === status.numRddBlocksById(10)) + assert(status.rddBlocksById(1000).size === status.numRddBlocksById(1000)) + } + + test("storage status memUsed, diskUsed, tachyonUsed") { + val status = storageStatus2 + def actualMemUsed: Long = status.blocks.values.map(_.memSize).sum + def actualDiskUsed: Long = status.blocks.values.map(_.diskSize).sum + def actualOffHeapUsed: Long = status.blocks.values.map(_.tachyonSize).sum + assert(status.memUsed === actualMemUsed) + assert(status.diskUsed === actualDiskUsed) + assert(status.offHeapUsed === actualOffHeapUsed) + status.addBlock(TestBlockId("fire"), BlockStatus(memAndDisk, 4000L, 5000L, 6000L)) + status.addBlock(TestBlockId("wire"), BlockStatus(memAndDisk, 400L, 500L, 600L)) + status.addBlock(RDDBlockId(25, 25), BlockStatus(memAndDisk, 40L, 50L, 60L)) + assert(status.memUsed === actualMemUsed) + assert(status.diskUsed === actualDiskUsed) + assert(status.offHeapUsed === actualOffHeapUsed) + status.updateBlock(TestBlockId("dan"), BlockStatus(memAndDisk, 4L, 5L, 6L)) + status.updateBlock(RDDBlockId(0, 0), BlockStatus(memAndDisk, 4L, 5L, 6L)) + status.updateBlock(RDDBlockId(1, 1), BlockStatus(memAndDisk, 4L, 5L, 6L)) + assert(status.memUsed === actualMemUsed) + assert(status.diskUsed === actualDiskUsed) + assert(status.offHeapUsed === actualOffHeapUsed) + status.removeBlock(TestBlockId("fire")) + status.removeBlock(TestBlockId("man")) + status.removeBlock(RDDBlockId(2, 2)) + status.removeBlock(RDDBlockId(2, 3)) + assert(status.memUsed === actualMemUsed) + assert(status.diskUsed === actualDiskUsed) + assert(status.offHeapUsed === actualOffHeapUsed) + } + + // For testing StorageUtils.updateRddInfo and StorageUtils.getRddBlockLocations + private def stockStorageStatuses: Seq[StorageStatus] = { + val status1 = new StorageStatus(BlockManagerId("big", "dog", 1, 1), 1000L) + val status2 = new StorageStatus(BlockManagerId("fat", "duck", 2, 2), 2000L) + val status3 = new StorageStatus(BlockManagerId("fat", "cat", 3, 3), 3000L) + status1.addBlock(RDDBlockId(0, 0), BlockStatus(memAndDisk, 1L, 2L, 0L)) + status1.addBlock(RDDBlockId(0, 1), BlockStatus(memAndDisk, 1L, 2L, 0L)) + status2.addBlock(RDDBlockId(0, 2), BlockStatus(memAndDisk, 1L, 2L, 0L)) + status2.addBlock(RDDBlockId(0, 3), BlockStatus(memAndDisk, 1L, 2L, 0L)) + status2.addBlock(RDDBlockId(1, 0), BlockStatus(memAndDisk, 1L, 2L, 0L)) + status2.addBlock(RDDBlockId(1, 1), BlockStatus(memAndDisk, 1L, 2L, 0L)) + status3.addBlock(RDDBlockId(0, 4), BlockStatus(memAndDisk, 1L, 2L, 0L)) + status3.addBlock(RDDBlockId(1, 2), BlockStatus(memAndDisk, 1L, 2L, 0L)) + Seq(status1, status2, status3) + } + + // For testing StorageUtils.updateRddInfo + private def stockRDDInfos: Seq[RDDInfo] = { + val info0 = new RDDInfo(0, "0", 10, memAndDisk) + val info1 = new RDDInfo(1, "1", 3, memAndDisk) + Seq(info0, info1) + } + + test("StorageUtils.updateRddInfo") { + val storageStatuses = stockStorageStatuses + val rddInfos = stockRDDInfos + StorageUtils.updateRddInfo(rddInfos, storageStatuses) + assert(rddInfos(0).storageLevel === memAndDisk) + assert(rddInfos(0).numCachedPartitions === 5) + assert(rddInfos(0).memSize === 5L) + assert(rddInfos(0).diskSize === 10L) + assert(rddInfos(0).tachyonSize === 0L) + assert(rddInfos(1).storageLevel === memAndDisk) + assert(rddInfos(1).numCachedPartitions === 3) + assert(rddInfos(1).memSize === 3L) + assert(rddInfos(1).diskSize === 6L) + assert(rddInfos(1).tachyonSize === 0L) + } + + test("StorageUtils.getRddBlockLocations") { + val storageStatuses = stockStorageStatuses + val blockLocations0 = StorageUtils.getRddBlockLocations(0, storageStatuses) + val blockLocations1 = StorageUtils.getRddBlockLocations(1, storageStatuses) + assert(blockLocations0.size === 5) + assert(blockLocations1.size === 3) + assert(blockLocations0.contains(RDDBlockId(0, 0))) + assert(blockLocations0.contains(RDDBlockId(0, 1))) + assert(blockLocations0.contains(RDDBlockId(0, 2))) + assert(blockLocations0.contains(RDDBlockId(0, 3))) + assert(blockLocations0.contains(RDDBlockId(0, 4))) + assert(blockLocations1.contains(RDDBlockId(1, 0))) + assert(blockLocations1.contains(RDDBlockId(1, 1))) + assert(blockLocations1.contains(RDDBlockId(1, 2))) + assert(blockLocations0(RDDBlockId(0, 0)) === Seq("dog:1")) + assert(blockLocations0(RDDBlockId(0, 1)) === Seq("dog:1")) + assert(blockLocations0(RDDBlockId(0, 2)) === Seq("duck:2")) + assert(blockLocations0(RDDBlockId(0, 3)) === Seq("duck:2")) + assert(blockLocations0(RDDBlockId(0, 4)) === Seq("cat:3")) + assert(blockLocations1(RDDBlockId(1, 0)) === Seq("duck:2")) + assert(blockLocations1(RDDBlockId(1, 1)) === Seq("duck:2")) + assert(blockLocations1(RDDBlockId(1, 2)) === Seq("cat:3")) + } + + test("StorageUtils.getRddBlockLocations with multiple locations") { + val storageStatuses = stockStorageStatuses + storageStatuses(0).addBlock(RDDBlockId(1, 0), BlockStatus(memAndDisk, 1L, 2L, 0L)) + storageStatuses(0).addBlock(RDDBlockId(0, 4), BlockStatus(memAndDisk, 1L, 2L, 0L)) + storageStatuses(2).addBlock(RDDBlockId(0, 0), BlockStatus(memAndDisk, 1L, 2L, 0L)) + val blockLocations0 = StorageUtils.getRddBlockLocations(0, storageStatuses) + val blockLocations1 = StorageUtils.getRddBlockLocations(1, storageStatuses) + assert(blockLocations0.size === 5) + assert(blockLocations1.size === 3) + assert(blockLocations0(RDDBlockId(0, 0)) === Seq("dog:1", "cat:3")) + assert(blockLocations0(RDDBlockId(0, 1)) === Seq("dog:1")) + assert(blockLocations0(RDDBlockId(0, 2)) === Seq("duck:2")) + assert(blockLocations0(RDDBlockId(0, 3)) === Seq("duck:2")) + assert(blockLocations0(RDDBlockId(0, 4)) === Seq("dog:1", "cat:3")) + assert(blockLocations1(RDDBlockId(1, 0)) === Seq("dog:1", "duck:2")) + assert(blockLocations1(RDDBlockId(1, 1)) === Seq("duck:2")) + assert(blockLocations1(RDDBlockId(1, 2)) === Seq("cat:3")) + } + +} diff --git a/core/src/test/scala/org/apache/spark/ui/storage/StorageTabSuite.scala b/core/src/test/scala/org/apache/spark/ui/storage/StorageTabSuite.scala new file mode 100644 index 0000000000000..6e68dcb3425aa --- /dev/null +++ b/core/src/test/scala/org/apache/spark/ui/storage/StorageTabSuite.scala @@ -0,0 +1,165 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ui.storage + +import org.scalatest.{BeforeAndAfter, FunSuite} +import org.apache.spark.Success +import org.apache.spark.executor.TaskMetrics +import org.apache.spark.scheduler._ +import org.apache.spark.storage._ + +/** + * Test various functionality in the StorageListener that supports the StorageTab. + */ +class StorageTabSuite extends FunSuite with BeforeAndAfter { + private var bus: LiveListenerBus = _ + private var storageStatusListener: StorageStatusListener = _ + private var storageListener: StorageListener = _ + private val memAndDisk = StorageLevel.MEMORY_AND_DISK + private val memOnly = StorageLevel.MEMORY_ONLY + private val none = StorageLevel.NONE + private val taskInfo = new TaskInfo(0, 0, 0, 0, "big", "dog", TaskLocality.ANY, false) + private def rddInfo0 = new RDDInfo(0, "freedom", 100, memOnly) + private def rddInfo1 = new RDDInfo(1, "hostage", 200, memOnly) + private def rddInfo2 = new RDDInfo(2, "sanity", 300, memAndDisk) + private def rddInfo3 = new RDDInfo(3, "grace", 400, memAndDisk) + private val bm1 = BlockManagerId("big", "dog", 1, 1) + + before { + bus = new LiveListenerBus + storageStatusListener = new StorageStatusListener + storageListener = new StorageListener(storageStatusListener) + bus.addListener(storageStatusListener) + bus.addListener(storageListener) + } + + test("stage submitted / completed") { + assert(storageListener._rddInfoMap.isEmpty) + assert(storageListener.rddInfoList.isEmpty) + + // 2 RDDs are known, but none are cached + val stageInfo0 = new StageInfo(0, "0", 100, Seq(rddInfo0, rddInfo1), "details") + bus.postToAll(SparkListenerStageSubmitted(stageInfo0)) + assert(storageListener._rddInfoMap.size === 2) + assert(storageListener.rddInfoList.isEmpty) + + // 4 RDDs are known, but only 2 are cached + val rddInfo2Cached = rddInfo2 + val rddInfo3Cached = rddInfo3 + rddInfo2Cached.numCachedPartitions = 1 + rddInfo3Cached.numCachedPartitions = 1 + val stageInfo1 = new StageInfo(1, "0", 100, Seq(rddInfo2Cached, rddInfo3Cached), "details") + bus.postToAll(SparkListenerStageSubmitted(stageInfo1)) + assert(storageListener._rddInfoMap.size === 4) + assert(storageListener.rddInfoList.size === 2) + + // Submitting RDDInfos with duplicate IDs does nothing + val rddInfo0Cached = new RDDInfo(0, "freedom", 100, StorageLevel.MEMORY_ONLY) + rddInfo0Cached.numCachedPartitions = 1 + val stageInfo0Cached = new StageInfo(0, "0", 100, Seq(rddInfo0), "details") + bus.postToAll(SparkListenerStageSubmitted(stageInfo0Cached)) + assert(storageListener._rddInfoMap.size === 4) + assert(storageListener.rddInfoList.size === 2) + + // We only keep around the RDDs that are cached + bus.postToAll(SparkListenerStageCompleted(stageInfo0)) + assert(storageListener._rddInfoMap.size === 2) + assert(storageListener.rddInfoList.size === 2) + } + + test("unpersist") { + val rddInfo0Cached = rddInfo0 + val rddInfo1Cached = rddInfo1 + rddInfo0Cached.numCachedPartitions = 1 + rddInfo1Cached.numCachedPartitions = 1 + val stageInfo0 = new StageInfo(0, "0", 100, Seq(rddInfo0Cached, rddInfo1Cached), "details") + bus.postToAll(SparkListenerStageSubmitted(stageInfo0)) + assert(storageListener._rddInfoMap.size === 2) + assert(storageListener.rddInfoList.size === 2) + bus.postToAll(SparkListenerUnpersistRDD(0)) + assert(storageListener._rddInfoMap.size === 1) + assert(storageListener.rddInfoList.size === 1) + bus.postToAll(SparkListenerUnpersistRDD(4)) // doesn't exist + assert(storageListener._rddInfoMap.size === 1) + assert(storageListener.rddInfoList.size === 1) + bus.postToAll(SparkListenerUnpersistRDD(1)) + assert(storageListener._rddInfoMap.size === 0) + assert(storageListener.rddInfoList.size === 0) + } + + test("task end") { + val myRddInfo0 = rddInfo0 + val myRddInfo1 = rddInfo1 + val myRddInfo2 = rddInfo2 + val stageInfo0 = new StageInfo(0, "0", 100, Seq(myRddInfo0, myRddInfo1, myRddInfo2), "details") + bus.postToAll(SparkListenerBlockManagerAdded(bm1, 1000L)) + bus.postToAll(SparkListenerStageSubmitted(stageInfo0)) + assert(storageListener._rddInfoMap.size === 3) + assert(storageListener.rddInfoList.size === 0) // not cached + assert(!storageListener._rddInfoMap(0).isCached) + assert(!storageListener._rddInfoMap(1).isCached) + assert(!storageListener._rddInfoMap(2).isCached) + + // Task end with no updated blocks. This should not change anything. + bus.postToAll(SparkListenerTaskEnd(0, "obliteration", Success, taskInfo, new TaskMetrics)) + assert(storageListener._rddInfoMap.size === 3) + assert(storageListener.rddInfoList.size === 0) + + // Task end with a few new persisted blocks, some from the same RDD + val metrics1 = new TaskMetrics + metrics1.updatedBlocks = Some(Seq( + (RDDBlockId(0, 100), BlockStatus(memAndDisk, 400L, 0L, 0L)), + (RDDBlockId(0, 101), BlockStatus(memAndDisk, 0L, 400L, 0L)), + (RDDBlockId(0, 102), BlockStatus(memAndDisk, 400L, 0L, 200L)), + (RDDBlockId(1, 20), BlockStatus(memAndDisk, 0L, 240L, 0L)) + )) + bus.postToAll(SparkListenerTaskEnd(1, "obliteration", Success, taskInfo, metrics1)) + assert(storageListener._rddInfoMap(0).memSize === 800L) + assert(storageListener._rddInfoMap(0).diskSize === 400L) + assert(storageListener._rddInfoMap(0).tachyonSize === 200L) + assert(storageListener._rddInfoMap(0).numCachedPartitions === 3) + assert(storageListener._rddInfoMap(0).isCached) + assert(storageListener._rddInfoMap(1).memSize === 0L) + assert(storageListener._rddInfoMap(1).diskSize === 240L) + assert(storageListener._rddInfoMap(1).tachyonSize === 0L) + assert(storageListener._rddInfoMap(1).numCachedPartitions === 1) + assert(storageListener._rddInfoMap(1).isCached) + assert(!storageListener._rddInfoMap(2).isCached) + assert(storageListener._rddInfoMap(2).numCachedPartitions === 0) + + // Task end with a few dropped blocks + val metrics2 = new TaskMetrics + metrics2.updatedBlocks = Some(Seq( + (RDDBlockId(0, 100), BlockStatus(none, 0L, 0L, 0L)), + (RDDBlockId(1, 20), BlockStatus(none, 0L, 0L, 0L)), + (RDDBlockId(2, 40), BlockStatus(none, 0L, 0L, 0L)), // doesn't actually exist + (RDDBlockId(4, 80), BlockStatus(none, 0L, 0L, 0L)) // doesn't actually exist + )) + bus.postToAll(SparkListenerTaskEnd(2, "obliteration", Success, taskInfo, metrics2)) + assert(storageListener._rddInfoMap(0).memSize === 400L) + assert(storageListener._rddInfoMap(0).diskSize === 400L) + assert(storageListener._rddInfoMap(0).tachyonSize === 200L) + assert(storageListener._rddInfoMap(0).numCachedPartitions === 2) + assert(storageListener._rddInfoMap(0).isCached) + assert(!storageListener._rddInfoMap(1).isCached) + assert(storageListener._rddInfoMap(2).numCachedPartitions === 0) + assert(!storageListener._rddInfoMap(2).isCached) + assert(storageListener._rddInfoMap(2).numCachedPartitions === 0) + } + +} From 148af6082cdb44840bbd61c7a4f67a95badad10b Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Sat, 2 Aug 2014 00:45:38 -0700 Subject: [PATCH 111/170] [SPARK-2454] Do not ship spark home to Workers When standalone Workers launch executors, they inherit the Spark home set by the driver. This means if the worker machines do not share the same directory structure as the driver node, the Workers will attempt to run scripts (e.g. bin/compute-classpath.sh) that do not exist locally and fail. This is a common scenario if the driver is launched from outside of the cluster. The solution is to simply not pass the driver's Spark home to the Workers. This PR further makes an attempt to avoid overloading the usages of `spark.home`, which is now only used for setting executor Spark home on Mesos and in python. This is based on top of #1392 and originally reported by YanTangZhai. Tested on standalone cluster. Author: Andrew Or Closes #1734 from andrewor14/spark-home-reprise and squashes the following commits: f71f391 [Andrew Or] Revert changes in python 1c2532c [Andrew Or] Merge branch 'master' of github.com:apache/spark into spark-home-reprise 188fc5d [Andrew Or] Avoid using spark.home where possible 09272b7 [Andrew Or] Always use Worker's working directory as spark home --- .../org/apache/spark/deploy/ApplicationDescription.scala | 1 - .../main/scala/org/apache/spark/deploy/JsonProtocol.scala | 1 - .../scala/org/apache/spark/deploy/client/TestClient.scala | 5 ++--- .../main/scala/org/apache/spark/deploy/worker/Worker.scala | 7 +++---- .../scheduler/cluster/SparkDeploySchedulerBackend.scala | 3 +-- core/src/test/scala/org/apache/spark/DriverSuite.scala | 2 +- .../scala/org/apache/spark/deploy/JsonProtocolSuite.scala | 5 ++--- .../scala/org/apache/spark/deploy/SparkSubmitSuite.scala | 2 +- .../apache/spark/deploy/worker/ExecutorRunnerTest.scala | 7 +++---- project/SparkBuild.scala | 2 +- python/pyspark/context.py | 2 +- repl/src/main/scala/org/apache/spark/repl/SparkILoop.scala | 3 --- .../main/scala/org/apache/spark/streaming/Checkpoint.scala | 1 - 13 files changed, 15 insertions(+), 26 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/ApplicationDescription.scala b/core/src/main/scala/org/apache/spark/deploy/ApplicationDescription.scala index 86305d2ea8a09..65a1a8fd7e929 100644 --- a/core/src/main/scala/org/apache/spark/deploy/ApplicationDescription.scala +++ b/core/src/main/scala/org/apache/spark/deploy/ApplicationDescription.scala @@ -22,7 +22,6 @@ private[spark] class ApplicationDescription( val maxCores: Option[Int], val memoryPerSlave: Int, val command: Command, - val sparkHome: Option[String], var appUiUrl: String, val eventLogDir: Option[String] = None) extends Serializable { diff --git a/core/src/main/scala/org/apache/spark/deploy/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/deploy/JsonProtocol.scala index c4f5e294a393e..696f32a6f5730 100644 --- a/core/src/main/scala/org/apache/spark/deploy/JsonProtocol.scala +++ b/core/src/main/scala/org/apache/spark/deploy/JsonProtocol.scala @@ -56,7 +56,6 @@ private[spark] object JsonProtocol { ("cores" -> obj.maxCores) ~ ("memoryperslave" -> obj.memoryPerSlave) ~ ("user" -> obj.user) ~ - ("sparkhome" -> obj.sparkHome) ~ ("command" -> obj.command.toString) } diff --git a/core/src/main/scala/org/apache/spark/deploy/client/TestClient.scala b/core/src/main/scala/org/apache/spark/deploy/client/TestClient.scala index b8ffa9afb69cb..88a0862b96afe 100644 --- a/core/src/main/scala/org/apache/spark/deploy/client/TestClient.scala +++ b/core/src/main/scala/org/apache/spark/deploy/client/TestClient.scala @@ -48,9 +48,8 @@ private[spark] object TestClient { val conf = new SparkConf val (actorSystem, _) = AkkaUtils.createActorSystem("spark", Utils.localIpAddress, 0, conf = conf, securityManager = new SecurityManager(conf)) - val desc = new ApplicationDescription( - "TestClient", Some(1), 512, Command("spark.deploy.client.TestExecutor", Seq(), Map(), - Seq(), Seq(), Seq()), Some("dummy-spark-home"), "ignored") + val desc = new ApplicationDescription("TestClient", Some(1), 512, + Command("spark.deploy.client.TestExecutor", Seq(), Map(), Seq(), Seq(), Seq()), "ignored") val listener = new TestListener val client = new AppClient(actorSystem, Array(url), desc, listener, new SparkConf) client.start() diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala index fb5252da96519..c6ea42fceb659 100755 --- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala @@ -81,7 +81,8 @@ private[spark] class Worker( @volatile var registered = false @volatile var connected = false val workerId = generateWorkerId() - val sparkHome = new File(Option(System.getenv("SPARK_HOME")).getOrElse(".")) + val sparkHome = + new File(sys.props.get("spark.test.home").orElse(sys.env.get("SPARK_HOME")).getOrElse(".")) var workDir: File = null val executors = new HashMap[String, ExecutorRunner] val finishedExecutors = new HashMap[String, ExecutorRunner] @@ -233,9 +234,7 @@ private[spark] class Worker( try { logInfo("Asked to launch executor %s/%d for %s".format(appId, execId, appDesc.name)) val manager = new ExecutorRunner(appId, execId, appDesc, cores_, memory_, - self, workerId, host, - appDesc.sparkHome.map(userSparkHome => new File(userSparkHome)).getOrElse(sparkHome), - workDir, akkaUrl, conf, ExecutorState.RUNNING) + self, workerId, host, sparkHome, workDir, akkaUrl, conf, ExecutorState.RUNNING) executors(appId + "/" + execId) = manager manager.start() coresUsed += cores_ diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala index 48aaaa54bdb35..a28446f6c8a6b 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala @@ -60,9 +60,8 @@ private[spark] class SparkDeploySchedulerBackend( val javaOpts = sparkJavaOpts ++ extraJavaOpts val command = Command("org.apache.spark.executor.CoarseGrainedExecutorBackend", args, sc.executorEnvs, classPathEntries, libraryPathEntries, javaOpts) - val sparkHome = sc.getSparkHome() val appDesc = new ApplicationDescription(sc.appName, maxCores, sc.executorMemory, command, - sparkHome, sc.ui.appUIAddress, sc.eventLogger.map(_.logDir)) + sc.ui.appUIAddress, sc.eventLogger.map(_.logDir)) client = new AppClient(sc.env.actorSystem, masters, appDesc, this, conf) client.start() diff --git a/core/src/test/scala/org/apache/spark/DriverSuite.scala b/core/src/test/scala/org/apache/spark/DriverSuite.scala index de4bd90c8f7e5..e36902ec81e08 100644 --- a/core/src/test/scala/org/apache/spark/DriverSuite.scala +++ b/core/src/test/scala/org/apache/spark/DriverSuite.scala @@ -34,7 +34,7 @@ import scala.language.postfixOps class DriverSuite extends FunSuite with Timeouts { test("driver should exit after finishing") { - val sparkHome = sys.env.get("SPARK_HOME").orElse(sys.props.get("spark.home")).get + val sparkHome = sys.props("spark.test.home") // Regression test for SPARK-530: "Spark driver process doesn't exit after finishing" val masters = Table(("master"), ("local"), ("local-cluster[2,1,512]")) forAll(masters) { (master: String) => diff --git a/core/src/test/scala/org/apache/spark/deploy/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/deploy/JsonProtocolSuite.scala index 093394ad6d142..31aa7ec837f43 100644 --- a/core/src/test/scala/org/apache/spark/deploy/JsonProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/JsonProtocolSuite.scala @@ -89,7 +89,7 @@ class JsonProtocolSuite extends FunSuite { def createAppDesc(): ApplicationDescription = { val cmd = new Command("mainClass", List("arg1", "arg2"), Map(), Seq(), Seq(), Seq()) - new ApplicationDescription("name", Some(4), 1234, cmd, Some("sparkHome"), "appUiUrl") + new ApplicationDescription("name", Some(4), 1234, cmd, "appUiUrl") } def createAppInfo() : ApplicationInfo = { @@ -169,8 +169,7 @@ object JsonConstants { val appDescJsonStr = """ |{"name":"name","cores":4,"memoryperslave":1234, - |"user":"%s","sparkhome":"sparkHome", - |"command":"Command(mainClass,List(arg1, arg2),Map(),List(),List(),List())"} + |"user":"%s","command":"Command(mainClass,List(arg1, arg2),Map(),List(),List(),List())"} """.format(System.getProperty("user.name", "")).stripMargin val executorRunnerJsonStr = diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala index 9190b05e2dba2..8126ef1bb23aa 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala @@ -295,7 +295,7 @@ class SparkSubmitSuite extends FunSuite with Matchers { // NOTE: This is an expensive operation in terms of time (10 seconds+). Use sparingly. def runSparkSubmit(args: Seq[String]): String = { - val sparkHome = sys.env.get("SPARK_HOME").orElse(sys.props.get("spark.home")).get + val sparkHome = sys.props("spark.test.home") Utils.executeAndGetOutput( Seq("./bin/spark-submit") ++ args, new File(sparkHome), diff --git a/core/src/test/scala/org/apache/spark/deploy/worker/ExecutorRunnerTest.scala b/core/src/test/scala/org/apache/spark/deploy/worker/ExecutorRunnerTest.scala index ca4d987619c91..149a2b3d95b86 100644 --- a/core/src/test/scala/org/apache/spark/deploy/worker/ExecutorRunnerTest.scala +++ b/core/src/test/scala/org/apache/spark/deploy/worker/ExecutorRunnerTest.scala @@ -27,12 +27,11 @@ import org.apache.spark.SparkConf class ExecutorRunnerTest extends FunSuite { test("command includes appId") { def f(s:String) = new File(s) - val sparkHome = sys.env.get("SPARK_HOME").orElse(sys.props.get("spark.home")) + val sparkHome = sys.props("spark.test.home") val appDesc = new ApplicationDescription("app name", Some(8), 500, - Command("foo", Seq(), Map(), Seq(), Seq(), Seq()), - sparkHome, "appUiUrl") + Command("foo", Seq(), Map(), Seq(), Seq(), Seq()), "appUiUrl") val appId = "12345-worker321-9876" - val er = new ExecutorRunner(appId, 1, appDesc, 8, 500, null, "blah", "worker321", f(sparkHome.getOrElse(".")), + val er = new ExecutorRunner(appId, 1, appDesc, 8, 500, null, "blah", "worker321", f(sparkHome), f("ooga"), "blah", new SparkConf, ExecutorState.RUNNING) assert(er.getCommandSeq.last === appId) diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index a8bbd55861954..1d7cc6dd6aef3 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -328,7 +328,7 @@ object TestSettings { lazy val settings = Seq ( // Fork new JVMs for tests and set Java options for those fork := true, - javaOptions in Test += "-Dspark.home=" + sparkHome, + javaOptions in Test += "-Dspark.test.home=" + sparkHome, javaOptions in Test += "-Dspark.testing=1", javaOptions in Test += "-Dsun.io.serialization.extendedDebugInfo=true", javaOptions in Test ++= System.getProperties.filter(_._1 startsWith "spark") diff --git a/python/pyspark/context.py b/python/pyspark/context.py index 7b0f8d83aedc5..2e80eb50f2207 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -84,7 +84,7 @@ def __init__(self, master=None, appName=None, sparkHome=None, pyFiles=None, @param serializer: The serializer for RDDs. @param conf: A L{SparkConf} object setting Spark properties. @param gateway: Use an existing gateway and JVM, otherwise a new JVM - will be instatiated. + will be instantiated. >>> from pyspark.context import SparkContext diff --git a/repl/src/main/scala/org/apache/spark/repl/SparkILoop.scala b/repl/src/main/scala/org/apache/spark/repl/SparkILoop.scala index 42c7e511dc3f5..65788f4646d91 100644 --- a/repl/src/main/scala/org/apache/spark/repl/SparkILoop.scala +++ b/repl/src/main/scala/org/apache/spark/repl/SparkILoop.scala @@ -969,9 +969,6 @@ class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter, if (execUri != null) { conf.set("spark.executor.uri", execUri) } - if (System.getenv("SPARK_HOME") != null) { - conf.setSparkHome(System.getenv("SPARK_HOME")) - } sparkContext = new SparkContext(conf) logInfo("Created spark context..") sparkContext diff --git a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala index ac56ff709c1c4..b780282bdac37 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala @@ -35,7 +35,6 @@ class Checkpoint(@transient ssc: StreamingContext, val checkpointTime: Time) extends Logging with Serializable { val master = ssc.sc.master val framework = ssc.sc.appName - val sparkHome = ssc.sc.getSparkHome.getOrElse(null) val jars = ssc.sc.jars val graph = ssc.graph val checkpointDir = ssc.checkpointDir From 08c095b6647033285e8f6703922bdacecce3fc71 Mon Sep 17 00:00:00 2001 From: Anand Avati Date: Sat, 2 Aug 2014 00:48:17 -0700 Subject: [PATCH 112/170] [SPARK-1812] sql/catalyst - Provide explicit type information For Scala 2.11 compatibility. Without the explicit type specification, withNullability return type is inferred to be Attribute, and thus calling at() on the returned object fails in these tests: [ERROR] /Users/avati/work/spark/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala:370: value at is not a [ERROR] val c4_notNull = 'a.boolean.notNull.at(3) [ERROR] ^ [ERROR] /Users/avati/work/spark/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala:371: value at is not a [ERROR] val c5_notNull = 'a.boolean.notNull.at(4) [ERROR] ^ [ERROR] /Users/avati/work/spark/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala:372: value at is not a [ERROR] val c6_notNull = 'a.boolean.notNull.at(5) [ERROR] ^ [ERROR] /Users/avati/work/spark/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala:558: value at is not a [ERROR] val s_notNull = 'a.string.notNull.at(0) Signed-off-by: Anand Avati Author: Anand Avati Closes #1709 from avati/SPARK-1812-notnull and squashes the following commits: 0470eb3 [Anand Avati] SPARK-1812: sql/catalyst - Provide explicit type information --- .../spark/sql/catalyst/expressions/namedExpressions.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala index ed69928ae9eb8..02d04762629f5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala @@ -134,7 +134,7 @@ case class AttributeReference(name: String, dataType: DataType, nullable: Boolea /** * Returns a copy of this [[AttributeReference]] with changed nullability. */ - override def withNullability(newNullability: Boolean) = { + override def withNullability(newNullability: Boolean): AttributeReference = { if (nullable == newNullability) { this } else { From 25cad6adf6479fb00265df06d5f77599f8defd26 Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Sat, 2 Aug 2014 00:57:47 -0700 Subject: [PATCH 113/170] HOTFIX: Fixing test error in maven for flume-sink. We needed to add an explicit dependency on scalatest since this module will not get it from spark core like others do. --- external/flume-sink/pom.xml | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/external/flume-sink/pom.xml b/external/flume-sink/pom.xml index d11129ce8d89d..d0bf1cf1ea796 100644 --- a/external/flume-sink/pom.xml +++ b/external/flume-sink/pom.xml @@ -67,7 +67,10 @@ org.scala-lang scala-library - 2.10.4 + + + org.scalatest + scalatest_${scala.binary.version}
    From 44460ba594fbfe5a6ee66e5121ead914bf16f9f6 Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Sat, 2 Aug 2014 01:11:03 -0700 Subject: [PATCH 114/170] HOTFIX: Fix concurrency issue in FlumePollingStreamSuite. This has been failing on master. One possible cause is that the port gets contended if multiple test runs happen concurrently and they hit this test at the same time. Since this test takes a long time (60 seconds) that's very plausible. This patch randomizes the port used in this test to avoid contention. --- .../spark/streaming/flume/FlumePollingStreamSuite.scala | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala index 47071d0cc4714..27bf2ac962721 100644 --- a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala +++ b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.streaming.flume import java.net.InetSocketAddress import java.util.concurrent.{Callable, ExecutorCompletionService, Executors} +import java.util.Random import scala.collection.JavaConversions._ import scala.collection.mutable.{SynchronizedBuffer, ArrayBuffer} @@ -37,13 +38,16 @@ import org.apache.spark.streaming.flume.sink._ class FlumePollingStreamSuite extends TestSuiteBase { - val testPort = 9999 + val random = new Random() + /** Return a port in the ephemeral range. */ + def getTestPort = random.nextInt(16382) + 49152 val batchCount = 5 val eventsPerBatch = 100 val totalEventsPerChannel = batchCount * eventsPerBatch val channelCapacity = 5000 test("flume polling test") { + val testPort = getTestPort // Set up the streaming context and input streams val ssc = new StreamingContext(conf, batchDuration) val flumeStream: ReceiverInputDStream[SparkFlumeEvent] = @@ -77,6 +81,7 @@ class FlumePollingStreamSuite extends TestSuiteBase { } test("flume polling test multiple hosts") { + val testPort = getTestPort // Set up the streaming context and input streams val ssc = new StreamingContext(conf, batchDuration) val addresses = Seq(testPort, testPort + 1).map(new InetSocketAddress("localhost", _)) From 87738bfa4051771ddfb8c4a4c1eb142fd77e3a46 Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Sat, 2 Aug 2014 01:26:16 -0700 Subject: [PATCH 115/170] MAINTENANCE: Automated closing of pull requests. This commit exists to close the following pull requests on Github: Closes #706 (close requested by 'pwendell') Closes #453 (close requested by 'pwendell') Closes #557 (close requested by 'tdas') Closes #495 (close requested by 'tdas') Closes #1232 (close requested by 'pwendell') Closes #82 (close requested by 'pwendell') Closes #600 (close requested by 'pwendell') Closes #473 (close requested by 'pwendell') Closes #351 (close requested by 'pwendell') From e09e18b3123c20e9b9497cf606473da500349d4d Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Sat, 2 Aug 2014 12:11:50 -0700 Subject: [PATCH 116/170] [HOTFIX] Do not throw NPE if spark.test.home is not set `spark.test.home` was introduced in #1734. This is fine for SBT but is failing maven tests. Either way it shouldn't throw an NPE. Author: Andrew Or Closes #1739 from andrewor14/fix-spark-test-home and squashes the following commits: ce2624c [Andrew Or] Do not throw NPE if spark.test.home is not set --- .../scala/org/apache/spark/deploy/worker/Worker.scala | 9 +++++++-- core/src/test/scala/org/apache/spark/DriverSuite.scala | 2 +- .../scala/org/apache/spark/deploy/SparkSubmitSuite.scala | 2 +- .../apache/spark/deploy/worker/ExecutorRunnerTest.scala | 2 +- pom.xml | 8 ++++---- 5 files changed, 14 insertions(+), 9 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala index c6ea42fceb659..458d9947bd873 100755 --- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala @@ -71,7 +71,7 @@ private[spark] class Worker( // TTL for app folders/data; after TTL expires it will be cleaned up val APP_DATA_RETENTION_SECS = conf.getLong("spark.worker.cleanup.appDataTtl", 7 * 24 * 3600) - + val testing: Boolean = sys.props.contains("spark.testing") val masterLock: Object = new Object() var master: ActorSelection = null var masterAddress: Address = null @@ -82,7 +82,12 @@ private[spark] class Worker( @volatile var connected = false val workerId = generateWorkerId() val sparkHome = - new File(sys.props.get("spark.test.home").orElse(sys.env.get("SPARK_HOME")).getOrElse(".")) + if (testing) { + assert(sys.props.contains("spark.test.home"), "spark.test.home is not set!") + new File(sys.props("spark.test.home")) + } else { + new File(sys.env.get("SPARK_HOME").getOrElse(".")) + } var workDir: File = null val executors = new HashMap[String, ExecutorRunner] val finishedExecutors = new HashMap[String, ExecutorRunner] diff --git a/core/src/test/scala/org/apache/spark/DriverSuite.scala b/core/src/test/scala/org/apache/spark/DriverSuite.scala index e36902ec81e08..a73e1ef0288a5 100644 --- a/core/src/test/scala/org/apache/spark/DriverSuite.scala +++ b/core/src/test/scala/org/apache/spark/DriverSuite.scala @@ -34,7 +34,7 @@ import scala.language.postfixOps class DriverSuite extends FunSuite with Timeouts { test("driver should exit after finishing") { - val sparkHome = sys.props("spark.test.home") + val sparkHome = sys.props.getOrElse("spark.test.home", fail("spark.test.home is not set!")) // Regression test for SPARK-530: "Spark driver process doesn't exit after finishing" val masters = Table(("master"), ("local"), ("local-cluster[2,1,512]")) forAll(masters) { (master: String) => diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala index 8126ef1bb23aa..a5cdcfb5de03b 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala @@ -295,7 +295,7 @@ class SparkSubmitSuite extends FunSuite with Matchers { // NOTE: This is an expensive operation in terms of time (10 seconds+). Use sparingly. def runSparkSubmit(args: Seq[String]): String = { - val sparkHome = sys.props("spark.test.home") + val sparkHome = sys.props.getOrElse("spark.test.home", fail("spark.test.home is not set!")) Utils.executeAndGetOutput( Seq("./bin/spark-submit") ++ args, new File(sparkHome), diff --git a/core/src/test/scala/org/apache/spark/deploy/worker/ExecutorRunnerTest.scala b/core/src/test/scala/org/apache/spark/deploy/worker/ExecutorRunnerTest.scala index 149a2b3d95b86..39ab53cf0b5b1 100644 --- a/core/src/test/scala/org/apache/spark/deploy/worker/ExecutorRunnerTest.scala +++ b/core/src/test/scala/org/apache/spark/deploy/worker/ExecutorRunnerTest.scala @@ -27,7 +27,7 @@ import org.apache.spark.SparkConf class ExecutorRunnerTest extends FunSuite { test("command includes appId") { def f(s:String) = new File(s) - val sparkHome = sys.props("spark.test.home") + val sparkHome = sys.props.getOrElse("spark.test.home", fail("spark.test.home is not set!")) val appDesc = new ApplicationDescription("app name", Some(8), 500, Command("foo", Seq(), Map(), Seq(), Seq(), Seq()), "appUiUrl") val appId = "12345-worker321-9876" diff --git a/pom.xml b/pom.xml index ae97bf03c53a2..99ae4b8b33f94 100644 --- a/pom.xml +++ b/pom.xml @@ -868,10 +868,10 @@ ${project.build.directory}/SparkTestSuite.txt -Xmx3g -XX:MaxPermSize=${MaxPermGen} -XX:ReservedCodeCacheSize=512m - - ${session.executionRootDirectory} - 1 - + + ${session.executionRootDirectory} + 1 + From 3f67382e7c9c3f6a8f6ce124ab3fcb1a9c1a264f Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Sat, 2 Aug 2014 13:07:17 -0700 Subject: [PATCH 117/170] [SPARK-2478] [mllib] DecisionTree Python API Added experimental Python API for Decision Trees. API: * class DecisionTreeModel ** predict() for single examples and RDDs, taking both feature vectors and LabeledPoints ** numNodes() ** depth() ** __str__() * class DecisionTree ** trainClassifier() ** trainRegressor() ** train() Examples and testing: * Added example testing classification and regression with batch prediction: examples/src/main/python/mllib/tree.py * Have also tested example usage in doc of python/pyspark/mllib/tree.py which tests single-example prediction with dense and sparse vectors Also: Small bug fix in python/pyspark/mllib/_common.py: In _linear_predictor_typecheck, changed check for RDD to use isinstance() instead of type() in order to catch RDD subclasses. CC mengxr manishamde Author: Joseph K. Bradley Closes #1727 from jkbradley/decisiontree-python-new and squashes the following commits: 3744488 [Joseph K. Bradley] Renamed test tree.py to decision_tree_runner.py Small updates based on github review. 6b86a9d [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into decisiontree-python-new affceb9 [Joseph K. Bradley] * Fixed bug in doc tests in pyspark/mllib/util.py caused by change in loadLibSVMFile behavior. (It used to threshold labels at 0 to make them 0/1, but it now leaves them as they are.) * Fixed small bug in loadLibSVMFile: If a data file had no features, then loadLibSVMFile would create a single all-zero feature. 67a29bc [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into decisiontree-python-new cf46ad7 [Joseph K. Bradley] Python DecisionTreeModel * predict(empty RDD) returns an empty RDD instead of an error. * Removed support for calling predict() on LabeledPoint and RDD[LabeledPoint] * predict() does not cache serialized RDD any more. aa29873 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into decisiontree-python-new bf21be4 [Joseph K. Bradley] removed old run() func from DecisionTree fa10ea7 [Joseph K. Bradley] Small style update 7968692 [Joseph K. Bradley] small braces typo fix e34c263 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into decisiontree-python-new 4801b40 [Joseph K. Bradley] Small style update to DecisionTreeSuite db0eab2 [Joseph K. Bradley] Merge branch 'decisiontree-bugfix2' into decisiontree-python-new 6873fa9 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into decisiontree-python-new 225822f [Joseph K. Bradley] Bug: In DecisionTree, the method sequentialBinSearchForOrderedCategoricalFeatureInClassification() indexed bins from 0 to (math.pow(2, featureCategories.toInt - 1) - 1). This upper bound is the bound for unordered categorical features, not ordered ones. The upper bound should be the arity (i.e., max value) of the feature. 93953f1 [Joseph K. Bradley] Likely done with Python API. 6df89a9 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into decisiontree-python-new 4562c08 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into decisiontree-python-new 665ba78 [Joseph K. Bradley] Small updates towards Python DecisionTree API 188cb0d [Joseph K. Bradley] Merge branch 'decisiontree-bugfix' into decisiontree-python-new 6622247 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into decisiontree-python-new b8fac57 [Joseph K. Bradley] Finished Python DecisionTree API and example but need to test a bit more. 2b20c61 [Joseph K. Bradley] Small doc and style updates 1b29c13 [Joseph K. Bradley] Merge branch 'decisiontree-bugfix' into decisiontree-python-new 584449a [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into decisiontree-python-new dab0b67 [Joseph K. Bradley] Added documentation for DecisionTree internals 8bb8aa0 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into decisiontree-bugfix 978cfcf [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into decisiontree-bugfix 6eed482 [Joseph K. Bradley] In DecisionTree: Changed from using procedural syntax for functions returning Unit to explicitly writing Unit return type. 376dca2 [Joseph K. Bradley] Updated meaning of maxDepth by 1 to fit scikit-learn and rpart. * In code, replaced usages of maxDepth <-- maxDepth + 1 * In params, replace settings of maxDepth <-- maxDepth - 1 e06e423 [Joseph K. Bradley] Merge branch 'decisiontree-bugfix' into decisiontree-python-new bab3f19 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into decisiontree-python-new 59750f8 [Joseph K. Bradley] * Updated Strategy to check numClassesForClassification only if algo=Classification. * Updates based on comments: ** DecisionTreeRunner *** Made dataFormat arg default to libsvm ** Small cleanups ** tree.Node: Made recursive helper methods private, and renamed them. 52e17c5 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into decisiontree-bugfix f5a036c [Joseph K. Bradley] Merge branch 'decisiontree-bugfix' into decisiontree-python-new da50db7 [Joseph K. Bradley] Added one more test to DecisionTreeSuite: stump with 2 continuous variables for binary classification. Caused problems in past, but fixed now. 8e227ea [Joseph K. Bradley] Changed Strategy so it only requires numClassesForClassification >= 2 for classification cd1d933 [Joseph K. Bradley] Merge branch 'decisiontree-bugfix' into decisiontree-python-new 8ea8750 [Joseph K. Bradley] Bug fix: Off-by-1 when finding thresholds for splits for continuous features. 8a758db [Joseph K. Bradley] Merge branch 'decisiontree-bugfix' into decisiontree-python-new 5fe44ed [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into decisiontree-python-new 2283df8 [Joseph K. Bradley] 2 bug fixes. 73fbea2 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into decisiontree-bugfix 5f920a1 [Joseph K. Bradley] Demonstration of bug before submitting fix: Updated DecisionTreeSuite so that 3 tests fail. Will describe bug in next commit. f825352 [Joseph K. Bradley] Wrote Python API and example for DecisionTree. Also added toString, depth, and numNodes methods to DecisionTreeModel. --- .../main/python/mllib/decision_tree_runner.py | 133 +++++++++++ .../main/python/mllib/logistic_regression.py | 4 +- .../mllib/api/python/PythonMLLibAPI.scala | 78 ++++++ .../mllib/tree/configuration/Strategy.scala | 3 +- .../spark/mllib/tree/DecisionTreeSuite.scala | 3 +- python/pyspark/mllib/_common.py | 33 ++- python/pyspark/mllib/tests.py | 36 +++ python/pyspark/mllib/tree.py | 225 ++++++++++++++++++ python/pyspark/mllib/util.py | 14 +- python/run-tests | 1 + 10 files changed, 509 insertions(+), 21 deletions(-) create mode 100755 examples/src/main/python/mllib/decision_tree_runner.py create mode 100644 python/pyspark/mllib/tree.py diff --git a/examples/src/main/python/mllib/decision_tree_runner.py b/examples/src/main/python/mllib/decision_tree_runner.py new file mode 100755 index 0000000000000..8efadb5223f56 --- /dev/null +++ b/examples/src/main/python/mllib/decision_tree_runner.py @@ -0,0 +1,133 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +Decision tree classification and regression using MLlib. +""" + +import numpy, os, sys + +from operator import add + +from pyspark import SparkContext +from pyspark.mllib.regression import LabeledPoint +from pyspark.mllib.tree import DecisionTree +from pyspark.mllib.util import MLUtils + + +def getAccuracy(dtModel, data): + """ + Return accuracy of DecisionTreeModel on the given RDD[LabeledPoint]. + """ + seqOp = (lambda acc, x: acc + (x[0] == x[1])) + predictions = dtModel.predict(data.map(lambda x: x.features)) + truth = data.map(lambda p: p.label) + trainCorrect = predictions.zip(truth).aggregate(0, seqOp, add) + if data.count() == 0: + return 0 + return trainCorrect / (0.0 + data.count()) + + +def getMSE(dtModel, data): + """ + Return mean squared error (MSE) of DecisionTreeModel on the given + RDD[LabeledPoint]. + """ + seqOp = (lambda acc, x: acc + numpy.square(x[0] - x[1])) + predictions = dtModel.predict(data.map(lambda x: x.features)) + truth = data.map(lambda p: p.label) + trainMSE = predictions.zip(truth).aggregate(0, seqOp, add) + if data.count() == 0: + return 0 + return trainMSE / (0.0 + data.count()) + + +def reindexClassLabels(data): + """ + Re-index class labels in a dataset to the range {0,...,numClasses-1}. + If all labels in that range already appear at least once, + then the returned RDD is the same one (without a mapping). + Note: If a label simply does not appear in the data, + the index will not include it. + Be aware of this when reindexing subsampled data. + :param data: RDD of LabeledPoint where labels are integer values + denoting labels for a classification problem. + :return: Pair (reindexedData, origToNewLabels) where + reindexedData is an RDD of LabeledPoint with labels in + the range {0,...,numClasses-1}, and + origToNewLabels is a dictionary mapping original labels + to new labels. + """ + # classCounts: class --> # examples in class + classCounts = data.map(lambda x: x.label).countByValue() + numExamples = sum(classCounts.values()) + sortedClasses = sorted(classCounts.keys()) + numClasses = len(classCounts) + # origToNewLabels: class --> index in 0,...,numClasses-1 + if (numClasses < 2): + print >> sys.stderr, \ + "Dataset for classification should have at least 2 classes." + \ + " The given dataset had only %d classes." % numClasses + exit(1) + origToNewLabels = dict([(sortedClasses[i], i) for i in range(0, numClasses)]) + + print "numClasses = %d" % numClasses + print "Per-class example fractions, counts:" + print "Class\tFrac\tCount" + for c in sortedClasses: + frac = classCounts[c] / (numExamples + 0.0) + print "%g\t%g\t%d" % (c, frac, classCounts[c]) + + if (sortedClasses[0] == 0 and sortedClasses[-1] == numClasses - 1): + return (data, origToNewLabels) + else: + reindexedData = \ + data.map(lambda x: LabeledPoint(origToNewLabels[x.label], x.features)) + return (reindexedData, origToNewLabels) + + +def usage(): + print >> sys.stderr, \ + "Usage: decision_tree_runner [libsvm format data filepath]\n" + \ + " Note: This only supports binary classification." + exit(1) + + +if __name__ == "__main__": + if len(sys.argv) > 2: + usage() + sc = SparkContext(appName="PythonDT") + + # Load data. + dataPath = 'data/mllib/sample_libsvm_data.txt' + if len(sys.argv) == 2: + dataPath = sys.argv[1] + if not os.path.isfile(dataPath): + usage() + points = MLUtils.loadLibSVMFile(sc, dataPath) + + # Re-index class labels if needed. + (reindexedData, origToNewLabels) = reindexClassLabels(points) + + # Train a classifier. + model = DecisionTree.trainClassifier(reindexedData, numClasses=2) + # Print learned tree and stats. + print "Trained DecisionTree for classification:" + print " Model numNodes: %d\n" % model.numNodes() + print " Model depth: %d\n" % model.depth() + print " Training accuracy: %g\n" % getAccuracy(model, reindexedData) + print model diff --git a/examples/src/main/python/mllib/logistic_regression.py b/examples/src/main/python/mllib/logistic_regression.py index 6e0f7a4ee5a81..9d547ff77c984 100755 --- a/examples/src/main/python/mllib/logistic_regression.py +++ b/examples/src/main/python/mllib/logistic_regression.py @@ -30,8 +30,10 @@ from pyspark.mllib.classification import LogisticRegressionWithSGD -# Parse a line of text into an MLlib LabeledPoint object def parsePoint(line): + """ + Parse a line of text into an MLlib LabeledPoint object. + """ values = [float(s) for s in line.split(' ')] if values[0] == -1: # Convert -1 labels to 0 for MLlib values[0] = 0 diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala index 7d912737b8f0b..1d5d3762ed8e9 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala @@ -19,6 +19,8 @@ package org.apache.spark.mllib.api.python import java.nio.{ByteBuffer, ByteOrder} +import scala.collection.JavaConverters._ + import org.apache.spark.annotation.DeveloperApi import org.apache.spark.api.java.{JavaRDD, JavaSparkContext} import org.apache.spark.mllib.classification._ @@ -29,6 +31,11 @@ import org.apache.spark.mllib.linalg.{Matrix, SparseVector, Vector, Vectors} import org.apache.spark.mllib.random.{RandomRDDGenerators => RG} import org.apache.spark.mllib.recommendation._ import org.apache.spark.mllib.regression._ +import org.apache.spark.mllib.tree.configuration.Algo._ +import org.apache.spark.mllib.tree.configuration.Strategy +import org.apache.spark.mllib.tree.DecisionTree +import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Impurity, Variance} +import org.apache.spark.mllib.tree.model.DecisionTreeModel import org.apache.spark.mllib.stat.Statistics import org.apache.spark.mllib.stat.correlation.CorrelationNames import org.apache.spark.mllib.util.MLUtils @@ -472,6 +479,76 @@ class PythonMLLibAPI extends Serializable { ALS.trainImplicit(ratings, rank, iterations, lambda, blocks, alpha) } + /** + * Java stub for Python mllib DecisionTree.train(). + * This stub returns a handle to the Java object instead of the content of the Java object. + * Extra care needs to be taken in the Python code to ensure it gets freed on exit; + * see the Py4J documentation. + * @param dataBytesJRDD Training data + * @param categoricalFeaturesInfoJMap Categorical features info, as Java map + */ + def trainDecisionTreeModel( + dataBytesJRDD: JavaRDD[Array[Byte]], + algoStr: String, + numClasses: Int, + categoricalFeaturesInfoJMap: java.util.Map[Int, Int], + impurityStr: String, + maxDepth: Int, + maxBins: Int): DecisionTreeModel = { + + val data = dataBytesJRDD.rdd.map(deserializeLabeledPoint) + + val algo: Algo = algoStr match { + case "classification" => Classification + case "regression" => Regression + case _ => throw new IllegalArgumentException(s"Bad algoStr parameter: $algoStr") + } + val impurity: Impurity = impurityStr match { + case "gini" => Gini + case "entropy" => Entropy + case "variance" => Variance + case _ => throw new IllegalArgumentException(s"Bad impurityStr parameter: $impurityStr") + } + + val strategy = new Strategy( + algo = algo, + impurity = impurity, + maxDepth = maxDepth, + numClassesForClassification = numClasses, + maxBins = maxBins, + categoricalFeaturesInfo = categoricalFeaturesInfoJMap.asScala.toMap) + + DecisionTree.train(data, strategy) + } + + /** + * Predict the label of the given data point. + * This is a Java stub for python DecisionTreeModel.predict() + * + * @param featuresBytes Serialized feature vector for data point + * @return predicted label + */ + def predictDecisionTreeModel( + model: DecisionTreeModel, + featuresBytes: Array[Byte]): Double = { + val features: Vector = deserializeDoubleVector(featuresBytes) + model.predict(features) + } + + /** + * Predict the labels of the given data points. + * This is a Java stub for python DecisionTreeModel.predict() + * + * @param dataJRDD A JavaRDD with serialized feature vectors + * @return JavaRDD of serialized predictions + */ + def predictDecisionTreeModel( + model: DecisionTreeModel, + dataJRDD: JavaRDD[Array[Byte]]): JavaRDD[Array[Byte]] = { + val data = dataJRDD.rdd.map(xBytes => deserializeDoubleVector(xBytes)) + model.predict(data).map(serializeDouble) + } + /** * Java stub for mllib Statistics.corr(X: RDD[Vector], method: String). * Returns the correlation matrix serialized into a byte array understood by deserializers in @@ -597,4 +674,5 @@ class PythonMLLibAPI extends Serializable { val s = getSeedOrDefault(seed) RG.poissonVectorRDD(jsc.sc, mean, numRows, numCols, parts, s).map(serializeDoubleVector) } + } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala index 5c65b537b6867..fdad4f029aa99 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala @@ -56,7 +56,8 @@ class Strategy ( if (algo == Classification) { require(numClassesForClassification >= 2) } - val isMulticlassClassification = numClassesForClassification > 2 + val isMulticlassClassification = + algo == Classification && numClassesForClassification > 2 val isMulticlassWithCategoricalFeatures = isMulticlassClassification && (categoricalFeaturesInfo.size > 0) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala index 546a132559326..8665a00f3b356 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala @@ -48,7 +48,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { requiredMSE: Double) { val predictions = input.map(x => model.predict(x.features)) val squaredError = predictions.zip(input).map { case (prediction, expected) => - (prediction - expected.label) * (prediction - expected.label) + val err = prediction - expected.label + err * err }.sum val mse = squaredError / input.length assert(mse <= requiredMSE) diff --git a/python/pyspark/mllib/_common.py b/python/pyspark/mllib/_common.py index c6ca6a75df746..9c1565affbdac 100644 --- a/python/pyspark/mllib/_common.py +++ b/python/pyspark/mllib/_common.py @@ -343,22 +343,35 @@ def _copyto(array, buffer, offset, shape, dtype): temp_array[...] = array -def _get_unmangled_rdd(data, serializer): +def _get_unmangled_rdd(data, serializer, cache=True): + """ + :param cache: If True, the serialized RDD is cached. (default = True) + WARNING: Users should unpersist() this later! + """ dataBytes = data.map(serializer) dataBytes._bypass_serializer = True - dataBytes.cache() # TODO: users should unpersist() this later! + if cache: + dataBytes.cache() return dataBytes -# Map a pickled Python RDD of Python dense or sparse vectors to a Java RDD of -# _serialized_double_vectors -def _get_unmangled_double_vector_rdd(data): - return _get_unmangled_rdd(data, _serialize_double_vector) +def _get_unmangled_double_vector_rdd(data, cache=True): + """ + Map a pickled Python RDD of Python dense or sparse vectors to a Java RDD of + _serialized_double_vectors. + :param cache: If True, the serialized RDD is cached. (default = True) + WARNING: Users should unpersist() this later! + """ + return _get_unmangled_rdd(data, _serialize_double_vector, cache) -# Map a pickled Python RDD of LabeledPoint to a Java RDD of _serialized_labeled_points -def _get_unmangled_labeled_point_rdd(data): - return _get_unmangled_rdd(data, _serialize_labeled_point) +def _get_unmangled_labeled_point_rdd(data, cache=True): + """ + Map a pickled Python RDD of LabeledPoint to a Java RDD of _serialized_labeled_points. + :param cache: If True, the serialized RDD is cached. (default = True) + WARNING: Users should unpersist() this later! + """ + return _get_unmangled_rdd(data, _serialize_labeled_point, cache) # Common functions for dealing with and training linear models @@ -380,7 +393,7 @@ def _linear_predictor_typecheck(x, coeffs): if x.size != coeffs.shape[0]: raise RuntimeError("Got sparse vector of size %d; wanted %d" % ( x.size, coeffs.shape[0])) - elif (type(x) == RDD): + elif isinstance(x, RDD): raise RuntimeError("Bulk predict not yet supported.") else: raise TypeError("Argument of type " + type(x).__name__ + " unsupported") diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py index 37ccf1d590743..9d1e5be637a9a 100644 --- a/python/pyspark/mllib/tests.py +++ b/python/pyspark/mllib/tests.py @@ -100,6 +100,7 @@ def test_clustering(self): def test_classification(self): from pyspark.mllib.classification import LogisticRegressionWithSGD, SVMWithSGD, NaiveBayes + from pyspark.mllib.tree import DecisionTree data = [ LabeledPoint(0.0, [1, 0, 0]), LabeledPoint(1.0, [0, 1, 1]), @@ -127,9 +128,19 @@ def test_classification(self): self.assertTrue(nb_model.predict(features[2]) <= 0) self.assertTrue(nb_model.predict(features[3]) > 0) + categoricalFeaturesInfo = {0: 3} # feature 0 has 3 categories + dt_model = \ + DecisionTree.trainClassifier(rdd, numClasses=2, + categoricalFeaturesInfo=categoricalFeaturesInfo) + self.assertTrue(dt_model.predict(features[0]) <= 0) + self.assertTrue(dt_model.predict(features[1]) > 0) + self.assertTrue(dt_model.predict(features[2]) <= 0) + self.assertTrue(dt_model.predict(features[3]) > 0) + def test_regression(self): from pyspark.mllib.regression import LinearRegressionWithSGD, LassoWithSGD, \ RidgeRegressionWithSGD + from pyspark.mllib.tree import DecisionTree data = [ LabeledPoint(-1.0, [0, -1]), LabeledPoint(1.0, [0, 1]), @@ -157,6 +168,14 @@ def test_regression(self): self.assertTrue(rr_model.predict(features[2]) <= 0) self.assertTrue(rr_model.predict(features[3]) > 0) + categoricalFeaturesInfo = {0: 2} # feature 0 has 2 categories + dt_model = \ + DecisionTree.trainRegressor(rdd, categoricalFeaturesInfo=categoricalFeaturesInfo) + self.assertTrue(dt_model.predict(features[0]) <= 0) + self.assertTrue(dt_model.predict(features[1]) > 0) + self.assertTrue(dt_model.predict(features[2]) <= 0) + self.assertTrue(dt_model.predict(features[3]) > 0) + @unittest.skipIf(not _have_scipy, "SciPy not installed") class SciPyTests(PySparkTestCase): @@ -229,6 +248,7 @@ def test_clustering(self): def test_classification(self): from pyspark.mllib.classification import LogisticRegressionWithSGD, SVMWithSGD, NaiveBayes + from pyspark.mllib.tree import DecisionTree data = [ LabeledPoint(0.0, self.scipy_matrix(2, {0: 1.0})), LabeledPoint(1.0, self.scipy_matrix(2, {1: 1.0})), @@ -256,9 +276,18 @@ def test_classification(self): self.assertTrue(nb_model.predict(features[2]) <= 0) self.assertTrue(nb_model.predict(features[3]) > 0) + categoricalFeaturesInfo = {0: 3} # feature 0 has 3 categories + dt_model = DecisionTree.trainClassifier(rdd, numClasses=2, + categoricalFeaturesInfo=categoricalFeaturesInfo) + self.assertTrue(dt_model.predict(features[0]) <= 0) + self.assertTrue(dt_model.predict(features[1]) > 0) + self.assertTrue(dt_model.predict(features[2]) <= 0) + self.assertTrue(dt_model.predict(features[3]) > 0) + def test_regression(self): from pyspark.mllib.regression import LinearRegressionWithSGD, LassoWithSGD, \ RidgeRegressionWithSGD + from pyspark.mllib.tree import DecisionTree data = [ LabeledPoint(-1.0, self.scipy_matrix(2, {1: -1.0})), LabeledPoint(1.0, self.scipy_matrix(2, {1: 1.0})), @@ -286,6 +315,13 @@ def test_regression(self): self.assertTrue(rr_model.predict(features[2]) <= 0) self.assertTrue(rr_model.predict(features[3]) > 0) + categoricalFeaturesInfo = {0: 2} # feature 0 has 2 categories + dt_model = DecisionTree.trainRegressor(rdd, categoricalFeaturesInfo=categoricalFeaturesInfo) + self.assertTrue(dt_model.predict(features[0]) <= 0) + self.assertTrue(dt_model.predict(features[1]) > 0) + self.assertTrue(dt_model.predict(features[2]) <= 0) + self.assertTrue(dt_model.predict(features[3]) > 0) + if __name__ == "__main__": if not _have_scipy: diff --git a/python/pyspark/mllib/tree.py b/python/pyspark/mllib/tree.py new file mode 100644 index 0000000000000..1e0006df75ac6 --- /dev/null +++ b/python/pyspark/mllib/tree.py @@ -0,0 +1,225 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from py4j.java_collections import MapConverter + +from pyspark import SparkContext, RDD +from pyspark.mllib._common import \ + _get_unmangled_rdd, _get_unmangled_double_vector_rdd, _serialize_double_vector, \ + _deserialize_labeled_point, _get_unmangled_labeled_point_rdd, \ + _deserialize_double +from pyspark.mllib.regression import LabeledPoint +from pyspark.serializers import NoOpSerializer + +class DecisionTreeModel(object): + """ + A decision tree model for classification or regression. + + EXPERIMENTAL: This is an experimental API. + It will probably be modified for Spark v1.2. + """ + + def __init__(self, sc, java_model): + """ + :param sc: Spark context + :param java_model: Handle to Java model object + """ + self._sc = sc + self._java_model = java_model + + def __del__(self): + self._sc._gateway.detach(self._java_model) + + def predict(self, x): + """ + Predict the label of one or more examples. + :param x: Data point (feature vector), + or an RDD of data points (feature vectors). + """ + pythonAPI = self._sc._jvm.PythonMLLibAPI() + if isinstance(x, RDD): + # Bulk prediction + if x.count() == 0: + return self._sc.parallelize([]) + dataBytes = _get_unmangled_double_vector_rdd(x, cache=False) + jSerializedPreds = \ + pythonAPI.predictDecisionTreeModel(self._java_model, + dataBytes._jrdd) + serializedPreds = RDD(jSerializedPreds, self._sc, NoOpSerializer()) + return serializedPreds.map(lambda bytes: _deserialize_double(bytearray(bytes))) + else: + # Assume x is a single data point. + x_ = _serialize_double_vector(x) + return pythonAPI.predictDecisionTreeModel(self._java_model, x_) + + def numNodes(self): + return self._java_model.numNodes() + + def depth(self): + return self._java_model.depth() + + def __str__(self): + return self._java_model.toString() + + +class DecisionTree(object): + """ + Learning algorithm for a decision tree model + for classification or regression. + + EXPERIMENTAL: This is an experimental API. + It will probably be modified for Spark v1.2. + + Example usage: + >>> from numpy import array, ndarray + >>> from pyspark.mllib.regression import LabeledPoint + >>> from pyspark.mllib.tree import DecisionTree + >>> from pyspark.mllib.linalg import SparseVector + >>> + >>> data = [ + ... LabeledPoint(0.0, [0.0]), + ... LabeledPoint(1.0, [1.0]), + ... LabeledPoint(1.0, [2.0]), + ... LabeledPoint(1.0, [3.0]) + ... ] + >>> + >>> model = DecisionTree.trainClassifier(sc.parallelize(data), numClasses=2) + >>> print(model) + DecisionTreeModel classifier + If (feature 0 <= 0.5) + Predict: 0.0 + Else (feature 0 > 0.5) + Predict: 1.0 + + >>> model.predict(array([1.0])) > 0 + True + >>> model.predict(array([0.0])) == 0 + True + >>> sparse_data = [ + ... LabeledPoint(0.0, SparseVector(2, {0: 0.0})), + ... LabeledPoint(1.0, SparseVector(2, {1: 1.0})), + ... LabeledPoint(0.0, SparseVector(2, {0: 0.0})), + ... LabeledPoint(1.0, SparseVector(2, {1: 2.0})) + ... ] + >>> + >>> model = DecisionTree.trainRegressor(sc.parallelize(sparse_data)) + >>> model.predict(array([0.0, 1.0])) == 1 + True + >>> model.predict(array([0.0, 0.0])) == 0 + True + >>> model.predict(SparseVector(2, {1: 1.0})) == 1 + True + >>> model.predict(SparseVector(2, {1: 0.0})) == 0 + True + """ + + @staticmethod + def trainClassifier(data, numClasses, categoricalFeaturesInfo={}, + impurity="gini", maxDepth=4, maxBins=100): + """ + Train a DecisionTreeModel for classification. + + :param data: Training data: RDD of LabeledPoint. + Labels are integers {0,1,...,numClasses}. + :param numClasses: Number of classes for classification. + :param categoricalFeaturesInfo: Map from categorical feature index + to number of categories. + Any feature not in this map + is treated as continuous. + :param impurity: Supported values: "entropy" or "gini" + :param maxDepth: Max depth of tree. + E.g., depth 0 means 1 leaf node. + Depth 1 means 1 internal node + 2 leaf nodes. + :param maxBins: Number of bins used for finding splits at each node. + :return: DecisionTreeModel + """ + return DecisionTree.train(data, "classification", numClasses, + categoricalFeaturesInfo, + impurity, maxDepth, maxBins) + + @staticmethod + def trainRegressor(data, categoricalFeaturesInfo={}, + impurity="variance", maxDepth=4, maxBins=100): + """ + Train a DecisionTreeModel for regression. + + :param data: Training data: RDD of LabeledPoint. + Labels are real numbers. + :param categoricalFeaturesInfo: Map from categorical feature index + to number of categories. + Any feature not in this map + is treated as continuous. + :param impurity: Supported values: "variance" + :param maxDepth: Max depth of tree. + E.g., depth 0 means 1 leaf node. + Depth 1 means 1 internal node + 2 leaf nodes. + :param maxBins: Number of bins used for finding splits at each node. + :return: DecisionTreeModel + """ + return DecisionTree.train(data, "regression", 0, + categoricalFeaturesInfo, + impurity, maxDepth, maxBins) + + + @staticmethod + def train(data, algo, numClasses, categoricalFeaturesInfo, + impurity, maxDepth, maxBins=100): + """ + Train a DecisionTreeModel for classification or regression. + + :param data: Training data: RDD of LabeledPoint. + For classification, labels are integers + {0,1,...,numClasses}. + For regression, labels are real numbers. + :param algo: "classification" or "regression" + :param numClasses: Number of classes for classification. + :param categoricalFeaturesInfo: Map from categorical feature index + to number of categories. + Any feature not in this map + is treated as continuous. + :param impurity: For classification: "entropy" or "gini". + For regression: "variance". + :param maxDepth: Max depth of tree. + E.g., depth 0 means 1 leaf node. + Depth 1 means 1 internal node + 2 leaf nodes. + :param maxBins: Number of bins used for finding splits at each node. + :return: DecisionTreeModel + """ + sc = data.context + dataBytes = _get_unmangled_labeled_point_rdd(data) + categoricalFeaturesInfoJMap = \ + MapConverter().convert(categoricalFeaturesInfo, + sc._gateway._gateway_client) + model = sc._jvm.PythonMLLibAPI().trainDecisionTreeModel( + dataBytes._jrdd, algo, + numClasses, categoricalFeaturesInfoJMap, + impurity, maxDepth, maxBins) + dataBytes.unpersist() + return DecisionTreeModel(sc, model) + + +def _test(): + import doctest + globs = globals().copy() + globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2) + (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) + globs['sc'].stop() + if failure_count: + exit(-1) + +if __name__ == "__main__": + _test() diff --git a/python/pyspark/mllib/util.py b/python/pyspark/mllib/util.py index d94900cefdb77..639cda6350229 100644 --- a/python/pyspark/mllib/util.py +++ b/python/pyspark/mllib/util.py @@ -16,6 +16,7 @@ # import numpy as np +import warnings from pyspark.mllib.linalg import Vectors, SparseVector from pyspark.mllib.regression import LabeledPoint @@ -29,9 +30,9 @@ class MLUtils: Helper methods to load, save and pre-process data used in MLlib. """ - @deprecated @staticmethod def _parse_libsvm_line(line, multiclass): + warnings.warn("deprecated", DeprecationWarning) return _parse_libsvm_line(line) @staticmethod @@ -67,9 +68,9 @@ def _convert_labeled_point_to_libsvm(p): " but got " % type(v)) return " ".join(items) - @deprecated @staticmethod def loadLibSVMFile(sc, path, multiclass=False, numFeatures=-1, minPartitions=None): + warnings.warn("deprecated", DeprecationWarning) return loadLibSVMFile(sc, path, numFeatures, minPartitions) @staticmethod @@ -106,7 +107,6 @@ def loadLibSVMFile(sc, path, numFeatures=-1, minPartitions=None): >>> tempFile.write("+1 1:1.0 3:2.0 5:3.0\\n-1\\n-1 2:4.0 4:5.0 6:6.0") >>> tempFile.flush() >>> examples = MLUtils.loadLibSVMFile(sc, tempFile.name).collect() - >>> multiclass_examples = MLUtils.loadLibSVMFile(sc, tempFile.name).collect() >>> tempFile.close() >>> type(examples[0]) == LabeledPoint True @@ -115,20 +115,18 @@ def loadLibSVMFile(sc, path, numFeatures=-1, minPartitions=None): >>> type(examples[1]) == LabeledPoint True >>> print examples[1] - (0.0,(6,[],[])) + (-1.0,(6,[],[])) >>> type(examples[2]) == LabeledPoint True >>> print examples[2] - (0.0,(6,[1,3,5],[4.0,5.0,6.0])) - >>> multiclass_examples[1].label - -1.0 + (-1.0,(6,[1,3,5],[4.0,5.0,6.0])) """ lines = sc.textFile(path, minPartitions) parsed = lines.map(lambda l: MLUtils._parse_libsvm_line(l)) if numFeatures <= 0: parsed.cache() - numFeatures = parsed.map(lambda x: 0 if x[1].size == 0 else x[1][-1]).reduce(max) + 1 + numFeatures = parsed.map(lambda x: -1 if x[1].size == 0 else x[1][-1]).reduce(max) + 1 return parsed.map(lambda x: LabeledPoint(x[0], Vectors.sparse(numFeatures, x[1], x[2]))) @staticmethod diff --git a/python/run-tests b/python/run-tests index 5049e15ce5f8a..48feba2f5bd63 100755 --- a/python/run-tests +++ b/python/run-tests @@ -71,6 +71,7 @@ run_test "pyspark/mllib/random.py" run_test "pyspark/mllib/recommendation.py" run_test "pyspark/mllib/regression.py" run_test "pyspark/mllib/tests.py" +run_test "pyspark/mllib/util.py" if [[ $FAILED == 0 ]]; then echo -en "\033[32m" # Green From 67bd8e3c217a80c3117a6e3853aa60fe13d08c91 Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Sat, 2 Aug 2014 13:16:41 -0700 Subject: [PATCH 118/170] [SQL] Set outputPartitioning of BroadcastHashJoin correctly. I think we will not generate the plan triggering this bug at this moment. But, let me explain it... Right now, we are using `left.outputPartitioning` as the `outputPartitioning` of a `BroadcastHashJoin`. We may have a wrong physical plan for cases like... ```sql SELECT l.key, count(*) FROM (SELECT key, count(*) as cnt FROM src GROUP BY key) l // This is buildPlan JOIN r // This is the streamedPlan ON (l.cnt = r.value) GROUP BY l.key ``` Let's say we have a `BroadcastHashJoin` on `l` and `r`. For this case, we will pick `l`'s `outputPartitioning` for the `outputPartitioning`of the `BroadcastHashJoin` on `l` and `r`. Also, because the last `GROUP BY` is using `l.key` as the key, we will not introduce an `Exchange` for this aggregation. However, `r`'s outputPartitioning may not match the required distribution of the last `GROUP BY` and we fail to group data correctly. JIRA is being reindexed. I will create a JIRA ticket once it is back online. Author: Yin Huai Closes #1735 from yhuai/BroadcastHashJoin and squashes the following commits: 96d9cb3 [Yin Huai] Set outputPartitioning correctly. --- .../src/main/scala/org/apache/spark/sql/execution/joins.scala | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala index cc138c749949d..51bb61530744c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala @@ -405,8 +405,7 @@ case class BroadcastHashJoin( left: SparkPlan, right: SparkPlan) extends BinaryNode with HashJoin { - - override def outputPartitioning: Partitioning = left.outputPartitioning + override def outputPartitioning: Partitioning = streamedPlan.outputPartitioning override def requiredChildDistribution = UnspecifiedDistribution :: UnspecifiedDistribution :: Nil From 91f9504e6086fac05b40545099f9818949c24bca Mon Sep 17 00:00:00 2001 From: Chris Fregly Date: Sat, 2 Aug 2014 13:35:35 -0700 Subject: [PATCH 119/170] [SPARK-1981] Add AWS Kinesis streaming support Author: Chris Fregly Closes #1434 from cfregly/master and squashes the following commits: 4774581 [Chris Fregly] updated docs, renamed retry to retryRandom to be more clear, removed retries around store() method 0393795 [Chris Fregly] moved Kinesis examples out of examples/ and back into extras/kinesis-asl 691a6be [Chris Fregly] fixed tests and formatting, fixed a bug with JavaKinesisWordCount during union of streams 0e1c67b [Chris Fregly] Merge remote-tracking branch 'upstream/master' 74e5c7c [Chris Fregly] updated per TD's feedback. simplified examples, updated docs e33cbeb [Chris Fregly] Merge remote-tracking branch 'upstream/master' bf614e9 [Chris Fregly] per matei's feedback: moved the kinesis examples into the examples/ dir d17ca6d [Chris Fregly] per TD's feedback: updated docs, simplified the KinesisUtils api 912640c [Chris Fregly] changed the foundKinesis class to be a publically-avail class db3eefd [Chris Fregly] Merge remote-tracking branch 'upstream/master' 21de67f [Chris Fregly] Merge remote-tracking branch 'upstream/master' 6c39561 [Chris Fregly] parameterized the versions of the aws java sdk and kinesis client 338997e [Chris Fregly] improve build docs for kinesis 828f8ae [Chris Fregly] more cleanup e7c8978 [Chris Fregly] Merge remote-tracking branch 'upstream/master' cd68c0d [Chris Fregly] fixed typos and backward compatibility d18e680 [Chris Fregly] Merge remote-tracking branch 'upstream/master' b3b0ff1 [Chris Fregly] [SPARK-1981] Add AWS Kinesis streaming support --- bin/run-example | 3 +- bin/run-example2.cmd | 3 +- dev/audit-release/audit_release.py | 4 +- .../src/main/scala/SparkApp.scala | 7 + dev/audit-release/sbt_app_kinesis/build.sbt | 28 ++ .../src/main/scala/SparkApp.scala | 33 +++ dev/create-release/create-release.sh | 4 +- dev/run-tests | 3 + docs/streaming-custom-receivers.md | 4 +- docs/streaming-kinesis.md | 58 ++++ docs/streaming-programming-guide.md | 12 +- examples/pom.xml | 13 + extras/kinesis-asl/pom.xml | 96 ++++++ .../streaming/JavaKinesisWordCountASL.java | 180 ++++++++++++ .../src/main/resources/log4j.properties | 37 +++ .../streaming/KinesisWordCountASL.scala | 251 ++++++++++++++++ .../kinesis/KinesisCheckpointState.scala | 56 ++++ .../streaming/kinesis/KinesisReceiver.scala | 149 ++++++++++ .../kinesis/KinesisRecordProcessor.scala | 212 ++++++++++++++ .../streaming/kinesis/KinesisUtils.scala | 96 ++++++ .../kinesis/JavaKinesisStreamSuite.java | 41 +++ .../src/test/resources/log4j.properties | 26 ++ .../kinesis/KinesisReceiverSuite.scala | 275 ++++++++++++++++++ pom.xml | 10 + project/SparkBuild.scala | 6 +- 25 files changed, 1592 insertions(+), 15 deletions(-) create mode 100644 dev/audit-release/sbt_app_kinesis/build.sbt create mode 100644 dev/audit-release/sbt_app_kinesis/src/main/scala/SparkApp.scala create mode 100644 docs/streaming-kinesis.md create mode 100644 extras/kinesis-asl/pom.xml create mode 100644 extras/kinesis-asl/src/main/java/org/apache/spark/examples/streaming/JavaKinesisWordCountASL.java create mode 100644 extras/kinesis-asl/src/main/resources/log4j.properties create mode 100644 extras/kinesis-asl/src/main/scala/org/apache/spark/examples/streaming/KinesisWordCountASL.scala create mode 100644 extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointState.scala create mode 100644 extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala create mode 100644 extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisRecordProcessor.scala create mode 100644 extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala create mode 100644 extras/kinesis-asl/src/test/java/org/apache/spark/streaming/kinesis/JavaKinesisStreamSuite.java create mode 100644 extras/kinesis-asl/src/test/resources/log4j.properties create mode 100644 extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala diff --git a/bin/run-example b/bin/run-example index 942706d733122..68a35702eddd3 100755 --- a/bin/run-example +++ b/bin/run-example @@ -29,7 +29,8 @@ if [ -n "$1" ]; then else echo "Usage: ./bin/run-example [example-args]" 1>&2 echo " - set MASTER=XX to use a specific master" 1>&2 - echo " - can use abbreviated example class name (e.g. SparkPi, mllib.LinearRegression)" 1>&2 + echo " - can use abbreviated example class name relative to com.apache.spark.examples" 1>&2 + echo " (e.g. SparkPi, mllib.LinearRegression, streaming.KinesisWordCountASL)" 1>&2 exit 1 fi diff --git a/bin/run-example2.cmd b/bin/run-example2.cmd index eadedd7fa61ff..b29bf90c64e90 100644 --- a/bin/run-example2.cmd +++ b/bin/run-example2.cmd @@ -32,7 +32,8 @@ rem Test that an argument was given if not "x%1"=="x" goto arg_given echo Usage: run-example ^ [example-args] echo - set MASTER=XX to use a specific master - echo - can use abbreviated example class name (e.g. SparkPi, mllib.LinearRegression) + echo - can use abbreviated example class name relative to com.apache.spark.examples + echo (e.g. SparkPi, mllib.LinearRegression, streaming.KinesisWordCountASL) goto exit :arg_given diff --git a/dev/audit-release/audit_release.py b/dev/audit-release/audit_release.py index 230e900ecd4de..16ea1a71290dc 100755 --- a/dev/audit-release/audit_release.py +++ b/dev/audit-release/audit_release.py @@ -105,7 +105,7 @@ def get_url(url): "spark-core", "spark-bagel", "spark-mllib", "spark-streaming", "spark-repl", "spark-graphx", "spark-streaming-flume", "spark-streaming-kafka", "spark-streaming-mqtt", "spark-streaming-twitter", "spark-streaming-zeromq", - "spark-catalyst", "spark-sql", "spark-hive" + "spark-catalyst", "spark-sql", "spark-hive", "spark-streaming-kinesis-asl" ] modules = map(lambda m: "%s_%s" % (m, SCALA_BINARY_VERSION), modules) @@ -136,7 +136,7 @@ def ensure_path_not_present(x): os.chdir(original_dir) # SBT application tests -for app in ["sbt_app_core", "sbt_app_graphx", "sbt_app_streaming", "sbt_app_sql", "sbt_app_hive"]: +for app in ["sbt_app_core", "sbt_app_graphx", "sbt_app_streaming", "sbt_app_sql", "sbt_app_hive", "sbt_app_kinesis"]: os.chdir(app) ret = run_cmd("sbt clean run", exit_on_failure=False) test(ret == 0, "sbt application (%s)" % app) diff --git a/dev/audit-release/sbt_app_core/src/main/scala/SparkApp.scala b/dev/audit-release/sbt_app_core/src/main/scala/SparkApp.scala index 77bbd167b199a..fc03fec9866a6 100644 --- a/dev/audit-release/sbt_app_core/src/main/scala/SparkApp.scala +++ b/dev/audit-release/sbt_app_core/src/main/scala/SparkApp.scala @@ -50,5 +50,12 @@ object SimpleApp { println("Ganglia sink was loaded via spark-core") System.exit(-1) } + + // Remove kinesis from default build due to ASL license issue + val foundKinesis = Try(Class.forName("org.apache.spark.streaming.kinesis.KinesisUtils")).isSuccess + if (foundKinesis) { + println("Kinesis was loaded via spark-core") + System.exit(-1) + } } } diff --git a/dev/audit-release/sbt_app_kinesis/build.sbt b/dev/audit-release/sbt_app_kinesis/build.sbt new file mode 100644 index 0000000000000..981bc7957b5ed --- /dev/null +++ b/dev/audit-release/sbt_app_kinesis/build.sbt @@ -0,0 +1,28 @@ +// +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +name := "Kinesis Test" + +version := "1.0" + +scalaVersion := System.getenv.get("SCALA_VERSION") + +libraryDependencies += "org.apache.spark" %% "spark-streaming-kinesis-asl" % System.getenv.get("SPARK_VERSION") + +resolvers ++= Seq( + "Spark Release Repository" at System.getenv.get("SPARK_RELEASE_REPOSITORY"), + "Spray Repository" at "http://repo.spray.cc/") diff --git a/dev/audit-release/sbt_app_kinesis/src/main/scala/SparkApp.scala b/dev/audit-release/sbt_app_kinesis/src/main/scala/SparkApp.scala new file mode 100644 index 0000000000000..9f85066501472 --- /dev/null +++ b/dev/audit-release/sbt_app_kinesis/src/main/scala/SparkApp.scala @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package main.scala + +import scala.util.Try + +import org.apache.spark.SparkContext +import org.apache.spark.SparkContext._ + +object SimpleApp { + def main(args: Array[String]) { + val foundKinesis = Try(Class.forName("org.apache.spark.streaming.kinesis.KinesisUtils")).isSuccess + if (!foundKinesis) { + println("Kinesis not loaded via kinesis-asl") + System.exit(-1) + } + } +} diff --git a/dev/create-release/create-release.sh b/dev/create-release/create-release.sh index af46572e6602b..42473629d4f15 100755 --- a/dev/create-release/create-release.sh +++ b/dev/create-release/create-release.sh @@ -53,15 +53,15 @@ if [[ ! "$@" =~ --package-only ]]; then -Dusername=$GIT_USERNAME -Dpassword=$GIT_PASSWORD \ -Dmaven.javadoc.skip=true \ -Dhadoop.version=2.2.0 -Dyarn.version=2.2.0 \ - -Pyarn -Phive -Phive-thriftserver -Phadoop-2.2 -Pspark-ganglia-lgpl\ -Dtag=$GIT_TAG -DautoVersionSubmodules=true \ + -Pyarn -Phive -Phive-thriftserver -Phadoop-2.2 -Pspark-ganglia-lgpl -Pkinesis-asl \ --batch-mode release:prepare mvn -DskipTests \ -Darguments="-DskipTests=true -Dmaven.javadoc.skip=true -Dhadoop.version=2.2.0 -Dyarn.version=2.2.0 -Dgpg.passphrase=${GPG_PASSPHRASE}" \ -Dhadoop.version=2.2.0 -Dyarn.version=2.2.0 \ -Dmaven.javadoc.skip=true \ - -Pyarn -Phive -Phive-thriftserver -Phadoop-2.2 -Pspark-ganglia-lgpl\ + -Pyarn -Phive -Phive-thriftserver -Phadoop-2.2 -Pspark-ganglia-lgpl -Pkinesis-asl \ release:perform cd .. diff --git a/dev/run-tests b/dev/run-tests index daa85bc750c07..d401c90f41d7b 100755 --- a/dev/run-tests +++ b/dev/run-tests @@ -36,6 +36,9 @@ fi if [ -z "$SBT_MAVEN_PROFILES_ARGS" ]; then export SBT_MAVEN_PROFILES_ARGS="-Pyarn -Phadoop-2.3 -Dhadoop.version=2.3.0" fi + +export SBT_MAVEN_PROFILES_ARGS="$SBT_MAVEN_PROFILES_ARGS -Pkinesis-asl" + echo "SBT_MAVEN_PROFILES_ARGS=\"$SBT_MAVEN_PROFILES_ARGS\"" # Remove work directory diff --git a/docs/streaming-custom-receivers.md b/docs/streaming-custom-receivers.md index a2dc3a8961dfc..1e045a3dd0ca9 100644 --- a/docs/streaming-custom-receivers.md +++ b/docs/streaming-custom-receivers.md @@ -4,7 +4,7 @@ title: Spark Streaming Custom Receivers --- Spark Streaming can receive streaming data from any arbitrary data source beyond -the one's for which it has in-built support (that is, beyond Flume, Kafka, files, sockets, etc.). +the one's for which it has in-built support (that is, beyond Flume, Kafka, Kinesis, files, sockets, etc.). This requires the developer to implement a *receiver* that is customized for receiving data from the concerned data source. This guide walks through the process of implementing a custom receiver and using it in a Spark Streaming application. @@ -174,7 +174,7 @@ val words = lines.flatMap(_.split(" ")) ... {% endhighlight %} -The full source code is in the example [CustomReceiver.scala](https://github.com/apache/spark/blob/master/examples/src/main/scala/org/apache/spark/streaming/examples/CustomReceiver.scala). +The full source code is in the example [CustomReceiver.scala](https://github.com/apache/spark/blob/master/examples/src/main/scala/org/apache/spark/examples/streaming/CustomReceiver.scala).
    diff --git a/docs/streaming-kinesis.md b/docs/streaming-kinesis.md new file mode 100644 index 0000000000000..801c905c88df8 --- /dev/null +++ b/docs/streaming-kinesis.md @@ -0,0 +1,58 @@ +--- +layout: global +title: Spark Streaming Kinesis Receiver +--- + +### Kinesis +Build notes: +
  • Spark supports a Kinesis Streaming Receiver which is not included in the default build due to licensing restrictions.
  • +
  • _**Note that by embedding this library you will include [ASL](https://aws.amazon.com/asl/)-licensed code in your Spark package**_.
  • +
  • The Spark Kinesis Streaming Receiver source code, examples, tests, and artifacts live in $SPARK_HOME/extras/kinesis-asl.
  • +
  • To build with Kinesis, you must run the maven or sbt builds with -Pkinesis-asl`.
  • +
  • Applications will need to link to the 'spark-streaming-kinesis-asl` artifact.
  • + +Kinesis examples notes: +
  • To build the Kinesis examples, you must run the maven or sbt builds with -Pkinesis-asl`.
  • +
  • These examples automatically determine the number of local threads and KinesisReceivers to spin up based on the number of shards for the stream.
  • +
  • KinesisWordCountProducerASL will generate random data to put onto the Kinesis stream for testing.
  • +
  • Checkpointing is disabled (no checkpoint dir is set). The examples as written will not recover from a driver failure.
  • + +Deployment and runtime notes: +
  • A single KinesisReceiver can process many shards of a stream.
  • +
  • Each shard of a stream is processed by one or more KinesisReceiver's managed by the Kinesis Client Library (KCL) Worker.
  • +
  • You never need more KinesisReceivers than the number of shards in your stream.
  • +
  • You can horizontally scale the receiving by creating more KinesisReceiver/DStreams (up to the number of shards for a given stream)
  • +
  • The Kinesis libraries must be present on all worker nodes, as they will need access to the Kinesis Client Library.
  • +
  • This code uses the DefaultAWSCredentialsProviderChain and searches for credentials in the following order of precedence:
    + 1) Environment Variables - AWS_ACCESS_KEY_ID and AWS_SECRET_KEY
    + 2) Java System Properties - aws.accessKeyId and aws.secretKey
    + 3) Credential profiles file - default location (~/.aws/credentials) shared by all AWS SDKs
    + 4) Instance profile credentials - delivered through the Amazon EC2 metadata service
    +
  • +
  • You need to setup a Kinesis stream with 1 or more shards per the following:
    + http://docs.aws.amazon.com/kinesis/latest/dev/step-one-create-stream.html
  • +
  • Valid Kinesis endpoint urls can be found here: Valid endpoint urls: http://docs.aws.amazon.com/general/latest/gr/rande.html#ak_region
  • +
  • When you first start up the KinesisReceiver, the Kinesis Client Library (KCL) needs ~30s to establish connectivity with the AWS Kinesis service, +retrieve any checkpoint data, and negotiate with other KCL's reading from the same stream.
  • +
  • Be careful when changing the app name. Kinesis maintains a mapping table in DynamoDB based on this app name (http://docs.aws.amazon.com/kinesis/latest/dev/kinesis-record-processor-implementation-app.html#kinesis-record-processor-initialization). +Changing the app name could lead to Kinesis errors as only 1 logical application can process a stream. In order to start fresh, +it's always best to delete the DynamoDB table that matches your app name. This DynamoDB table lives in us-east-1 regardless of the Kinesis endpoint URL.
  • + +Failure recovery notes: +
  • The combination of Spark Streaming and Kinesis creates 3 different checkpoints as follows:
    + 1) RDD data checkpoint (Spark Streaming) - frequency is configurable with DStream.checkpoint(Duration)
    + 2) RDD metadata checkpoint (Spark Streaming) - frequency is every DStream batch
    + 3) Kinesis checkpointing (Kinesis) - frequency is controlled by the developer calling ICheckpointer.checkpoint() directly
    +
  • +
  • Checkpointing too frequently will cause excess load on the AWS checkpoint storage layer and may lead to AWS throttling
  • +
  • Upon startup, a KinesisReceiver will begin processing records with sequence numbers greater than the last checkpoint sequence number recorded per shard.
  • +
  • If no checkpoint info exists, the worker will start either from the oldest record available (InitialPositionInStream.TRIM_HORIZON) +or from the tip/latest (InitialPostitionInStream.LATEST). This is configurable.
  • +
  • When pulling from the stream tip (InitialPositionInStream.LATEST), only new stream data will be picked up after the KinesisReceiver starts.
  • +
  • InitialPositionInStream.LATEST could lead to missed records if data is added to the stream while no KinesisReceivers are running.
  • +
  • In production, you'll want to switch to InitialPositionInStream.TRIM_HORIZON which will read up to 24 hours (Kinesis limit) of previous stream data +depending on the checkpoint frequency.
  • +
  • InitialPositionInStream.TRIM_HORIZON may lead to duplicate processing of records depending on the checkpoint frequency.
  • +
  • Record processing should be idempotent when possible.
  • +
  • Failed or latent KinesisReceivers will be detected and automatically shutdown/load-balanced by the KCL.
  • +
  • If possible, explicitly shutdown the worker if a failure occurs in order to trigger the final checkpoint.
  • diff --git a/docs/streaming-programming-guide.md b/docs/streaming-programming-guide.md index 7b8b7933434c4..9f331ed50d2a4 100644 --- a/docs/streaming-programming-guide.md +++ b/docs/streaming-programming-guide.md @@ -9,7 +9,7 @@ title: Spark Streaming Programming Guide # Overview Spark Streaming is an extension of the core Spark API that allows enables high-throughput, fault-tolerant stream processing of live data streams. Data can be ingested from many sources -like Kafka, Flume, Twitter, ZeroMQ or plain old TCP sockets and be processed using complex +like Kafka, Flume, Twitter, ZeroMQ, Kinesis or plain old TCP sockets and be processed using complex algorithms expressed with high-level functions like `map`, `reduce`, `join` and `window`. Finally, processed data can be pushed out to filesystems, databases, and live dashboards. In fact, you can apply Spark's in-built @@ -38,7 +38,7 @@ stream of results in batches. Spark Streaming provides a high-level abstraction called *discretized stream* or *DStream*, which represents a continuous stream of data. DStreams can be created either from input data -stream from sources such as Kafka and Flume, or by applying high-level +stream from sources such as Kafka, Flume, and Kinesis, or by applying high-level operations on other DStreams. Internally, a DStream is represented as a sequence of [RDDs](api/scala/index.html#org.apache.spark.rdd.RDD). @@ -313,7 +313,7 @@ To write your own Spark Streaming program, you will have to add the following de artifactId = spark-streaming_{{site.SCALA_BINARY_VERSION}} version = {{site.SPARK_VERSION}} -For ingesting data from sources like Kafka and Flume that are not present in the Spark +For ingesting data from sources like Kafka, Flume, and Kinesis that are not present in the Spark Streaming core API, you will have to add the corresponding artifact `spark-streaming-xyz_{{site.SCALA_BINARY_VERSION}}` to the dependencies. For example, @@ -327,6 +327,7 @@ some of the common ones are as follows. Twitter spark-streaming-twitter_{{site.SCALA_BINARY_VERSION}} ZeroMQ spark-streaming-zeromq_{{site.SCALA_BINARY_VERSION}} MQTT spark-streaming-mqtt_{{site.SCALA_BINARY_VERSION}} + Kinesis
    (built separately) kinesis-asl_{{site.SCALA_BINARY_VERSION}} @@ -442,7 +443,7 @@ see the API documentations of the relevant functions in Scala and [JavaStreamingContext](api/scala/index.html#org.apache.spark.streaming.api.java.JavaStreamingContext) for Java. -Additional functionality for creating DStreams from sources such as Kafka, Flume, and Twitter +Additional functionality for creating DStreams from sources such as Kafka, Flume, Kinesis, and Twitter can be imported by adding the right dependencies as explained in an [earlier](#linking) section. To take the case of Kafka, after adding the artifact `spark-streaming-kafka_{{site.SCALA_BINARY_VERSION}}` to the @@ -467,6 +468,9 @@ For more details on these additional sources, see the corresponding [API documen Furthermore, you can also implement your own custom receiver for your sources. See the [Custom Receiver Guide](streaming-custom-receivers.html). +### Kinesis +[Kinesis](streaming-kinesis.html) + ## Operations There are two kinds of DStream operations - _transformations_ and _output operations_. Similar to RDD transformations, DStream transformations operate on one or more DStreams to create new DStreams diff --git a/examples/pom.xml b/examples/pom.xml index c4ed0f5a6a02b..8c4c128bb484d 100644 --- a/examples/pom.xml +++ b/examples/pom.xml @@ -34,6 +34,19 @@ Spark Project Examples http://spark.apache.org/ + + + kinesis-asl + + + org.apache.spark + spark-streaming-kinesis-asl_${scala.binary.version} + ${project.version} + + + + + org.apache.spark diff --git a/extras/kinesis-asl/pom.xml b/extras/kinesis-asl/pom.xml new file mode 100644 index 0000000000000..a54b34235dfb4 --- /dev/null +++ b/extras/kinesis-asl/pom.xml @@ -0,0 +1,96 @@ + + + + 4.0.0 + + org.apache.spark + spark-parent + 1.1.0-SNAPSHOT + ../../pom.xml + + + + org.apache.spark + spark-streaming-kinesis-asl_2.10 + jar + Spark Kinesis Integration + + + kinesis-asl + + + + + org.apache.spark + spark-streaming_${scala.binary.version} + ${project.version} + + + org.apache.spark + spark-streaming_${scala.binary.version} + ${project.version} + test-jar + test + + + com.amazonaws + amazon-kinesis-client + ${aws.kinesis.client.version} + + + com.amazonaws + aws-java-sdk + ${aws.java.sdk.version} + + + org.scalatest + scalatest_${scala.binary.version} + test + + + org.mockito + mockito-all + test + + + org.scalacheck + scalacheck_${scala.binary.version} + test + + + org.easymock + easymockclassextension + test + + + com.novocode + junit-interface + test + + + + target/scala-${scala.binary.version}/classes + target/scala-${scala.binary.version}/test-classes + + + org.scalatest + scalatest-maven-plugin + + + + diff --git a/extras/kinesis-asl/src/main/java/org/apache/spark/examples/streaming/JavaKinesisWordCountASL.java b/extras/kinesis-asl/src/main/java/org/apache/spark/examples/streaming/JavaKinesisWordCountASL.java new file mode 100644 index 0000000000000..a8b907b241893 --- /dev/null +++ b/extras/kinesis-asl/src/main/java/org/apache/spark/examples/streaming/JavaKinesisWordCountASL.java @@ -0,0 +1,180 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.examples.streaming; + +import java.util.ArrayList; +import java.util.List; +import java.util.regex.Pattern; + +import org.apache.log4j.Logger; +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.function.FlatMapFunction; +import org.apache.spark.api.java.function.Function2; +import org.apache.spark.api.java.function.PairFunction; +import org.apache.spark.storage.StorageLevel; +import org.apache.spark.streaming.Duration; +import org.apache.spark.streaming.api.java.JavaDStream; +import org.apache.spark.streaming.api.java.JavaPairDStream; +import org.apache.spark.streaming.api.java.JavaStreamingContext; +import org.apache.spark.streaming.kinesis.KinesisUtils; + +import scala.Tuple2; + +import com.amazonaws.auth.DefaultAWSCredentialsProviderChain; +import com.amazonaws.services.kinesis.AmazonKinesisClient; +import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream; +import com.google.common.collect.Lists; + +/** + * Java-friendly Kinesis Spark Streaming WordCount example + * + * See http://spark.apache.org/docs/latest/streaming-kinesis.html for more details + * on the Kinesis Spark Streaming integration. + * + * This example spins up 1 Kinesis Worker (Spark Streaming Receiver) per shard + * for the given stream. + * It then starts pulling from the last checkpointed sequence number of the given + * and . + * + * Valid endpoint urls: http://docs.aws.amazon.com/general/latest/gr/rande.html#ak_region + * + * This code uses the DefaultAWSCredentialsProviderChain and searches for credentials + * in the following order of precedence: + * Environment Variables - AWS_ACCESS_KEY_ID and AWS_SECRET_KEY + * Java System Properties - aws.accessKeyId and aws.secretKey + * Credential profiles file - default location (~/.aws/credentials) shared by all AWS SDKs + * Instance profile credentials - delivered through the Amazon EC2 metadata service + * + * Usage: JavaKinesisWordCountASL + * is the name of the Kinesis stream (ie. mySparkStream) + * is the endpoint of the Kinesis service + * (ie. https://kinesis.us-east-1.amazonaws.com) + * + * Example: + * $ export AWS_ACCESS_KEY_ID= + * $ export AWS_SECRET_KEY= + * $ $SPARK_HOME/bin/run-example \ + * org.apache.spark.examples.streaming.JavaKinesisWordCountASL mySparkStream \ + * https://kinesis.us-east-1.amazonaws.com + * + * There is a companion helper class called KinesisWordCountProducerASL which puts dummy data + * onto the Kinesis stream. + * Usage instructions for KinesisWordCountProducerASL are provided in the class definition. + */ +public final class JavaKinesisWordCountASL { + private static final Pattern WORD_SEPARATOR = Pattern.compile(" "); + private static final Logger logger = Logger.getLogger(JavaKinesisWordCountASL.class); + + /* Make the constructor private to enforce singleton */ + private JavaKinesisWordCountASL() { + } + + public static void main(String[] args) { + /* Check that all required args were passed in. */ + if (args.length < 2) { + System.err.println( + "|Usage: KinesisWordCount \n" + + "| is the name of the Kinesis stream\n" + + "| is the endpoint of the Kinesis service\n" + + "| (e.g. https://kinesis.us-east-1.amazonaws.com)\n"); + System.exit(1); + } + + StreamingExamples.setStreamingLogLevels(); + + /* Populate the appropriate variables from the given args */ + String streamName = args[0]; + String endpointUrl = args[1]; + /* Set the batch interval to a fixed 2000 millis (2 seconds) */ + Duration batchInterval = new Duration(2000); + + /* Create a Kinesis client in order to determine the number of shards for the given stream */ + AmazonKinesisClient kinesisClient = new AmazonKinesisClient( + new DefaultAWSCredentialsProviderChain()); + kinesisClient.setEndpoint(endpointUrl); + + /* Determine the number of shards from the stream */ + int numShards = kinesisClient.describeStream(streamName) + .getStreamDescription().getShards().size(); + + /* In this example, we're going to create 1 Kinesis Worker/Receiver/DStream for each shard */ + int numStreams = numShards; + + /* Must add 1 more thread than the number of receivers or the output won't show properly from the driver */ + int numSparkThreads = numStreams + 1; + + /* Setup the Spark config. */ + SparkConf sparkConfig = new SparkConf().setAppName("KinesisWordCount").setMaster( + "local[" + numSparkThreads + "]"); + + /* Kinesis checkpoint interval. Same as batchInterval for this example. */ + Duration checkpointInterval = batchInterval; + + /* Setup the StreamingContext */ + JavaStreamingContext jssc = new JavaStreamingContext(sparkConfig, batchInterval); + + /* Create the same number of Kinesis DStreams/Receivers as Kinesis stream's shards */ + List> streamsList = new ArrayList>(numStreams); + for (int i = 0; i < numStreams; i++) { + streamsList.add( + KinesisUtils.createStream(jssc, streamName, endpointUrl, checkpointInterval, + InitialPositionInStream.LATEST, StorageLevel.MEMORY_AND_DISK_2()) + ); + } + + /* Union all the streams if there is more than 1 stream */ + JavaDStream unionStreams; + if (streamsList.size() > 1) { + unionStreams = jssc.union(streamsList.get(0), streamsList.subList(1, streamsList.size())); + } else { + /* Otherwise, just use the 1 stream */ + unionStreams = streamsList.get(0); + } + + /* + * Split each line of the union'd DStreams into multiple words using flatMap to produce the collection. + * Convert lines of byte[] to multiple Strings by first converting to String, then splitting on WORD_SEPARATOR. + */ + JavaDStream words = unionStreams.flatMap(new FlatMapFunction() { + @Override + public Iterable call(byte[] line) { + return Lists.newArrayList(WORD_SEPARATOR.split(new String(line))); + } + }); + + /* Map each word to a (word, 1) tuple, then reduce/aggregate by word. */ + JavaPairDStream wordCounts = words.mapToPair( + new PairFunction() { + @Override + public Tuple2 call(String s) { + return new Tuple2(s, 1); + } + }).reduceByKey(new Function2() { + @Override + public Integer call(Integer i1, Integer i2) { + return i1 + i2; + } + }); + + /* Print the first 10 wordCounts */ + wordCounts.print(); + + /* Start the streaming context and await termination */ + jssc.start(); + jssc.awaitTermination(); + } +} diff --git a/extras/kinesis-asl/src/main/resources/log4j.properties b/extras/kinesis-asl/src/main/resources/log4j.properties new file mode 100644 index 0000000000000..97348fb5b6123 --- /dev/null +++ b/extras/kinesis-asl/src/main/resources/log4j.properties @@ -0,0 +1,37 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +log4j.rootCategory=WARN, console + +# File appender +log4j.appender.file=org.apache.log4j.FileAppender +log4j.appender.file.append=false +log4j.appender.file.file=target/unit-tests.log +log4j.appender.file.layout=org.apache.log4j.PatternLayout +log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %p %c{1}: %m%n + +# Console appender +log4j.appender.console=org.apache.log4j.ConsoleAppender +log4j.appender.console.target=System.out +log4j.appender.console.layout=org.apache.log4j.PatternLayout +log4j.appender.console.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss} %p %c{1}: %m%n + +# Settings to quiet third party logs that are too verbose +log4j.logger.org.eclipse.jetty=WARN +log4j.logger.org.eclipse.jetty.util.component.AbstractLifeCycle=ERROR +log4j.logger.org.apache.spark.repl.SparkIMain$exprTyper=INFO +log4j.logger.org.apache.spark.repl.SparkILoop$SparkILoopInterpreter=INFO \ No newline at end of file diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/examples/streaming/KinesisWordCountASL.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/examples/streaming/KinesisWordCountASL.scala new file mode 100644 index 0000000000000..d03edf8b30a9f --- /dev/null +++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/examples/streaming/KinesisWordCountASL.scala @@ -0,0 +1,251 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.streaming + +import java.nio.ByteBuffer +import scala.util.Random +import org.apache.spark.Logging +import org.apache.spark.SparkConf +import org.apache.spark.storage.StorageLevel +import org.apache.spark.streaming.Milliseconds +import org.apache.spark.streaming.StreamingContext +import org.apache.spark.streaming.StreamingContext.toPairDStreamFunctions +import org.apache.spark.streaming.kinesis.KinesisUtils +import com.amazonaws.auth.DefaultAWSCredentialsProviderChain +import com.amazonaws.services.kinesis.AmazonKinesisClient +import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream +import com.amazonaws.services.kinesis.model.PutRecordRequest +import org.apache.log4j.Logger +import org.apache.log4j.Level + +/** + * Kinesis Spark Streaming WordCount example. + * + * See http://spark.apache.org/docs/latest/streaming-kinesis.html for more details on + * the Kinesis Spark Streaming integration. + * + * This example spins up 1 Kinesis Worker (Spark Streaming Receiver) per shard + * for the given stream. + * It then starts pulling from the last checkpointed sequence number of the given + * and . + * + * Valid endpoint urls: http://docs.aws.amazon.com/general/latest/gr/rande.html#ak_region + * + * This code uses the DefaultAWSCredentialsProviderChain and searches for credentials + * in the following order of precedence: + * Environment Variables - AWS_ACCESS_KEY_ID and AWS_SECRET_KEY + * Java System Properties - aws.accessKeyId and aws.secretKey + * Credential profiles file - default location (~/.aws/credentials) shared by all AWS SDKs + * Instance profile credentials - delivered through the Amazon EC2 metadata service + * + * Usage: KinesisWordCountASL + * is the name of the Kinesis stream (ie. mySparkStream) + * is the endpoint of the Kinesis service + * (ie. https://kinesis.us-east-1.amazonaws.com) + * + * Example: + * $ export AWS_ACCESS_KEY_ID= + * $ export AWS_SECRET_KEY= + * $ $SPARK_HOME/bin/run-example \ + * org.apache.spark.examples.streaming.KinesisWordCountASL mySparkStream \ + * https://kinesis.us-east-1.amazonaws.com + * + * There is a companion helper class below called KinesisWordCountProducerASL which puts + * dummy data onto the Kinesis stream. + * Usage instructions for KinesisWordCountProducerASL are provided in that class definition. + */ +object KinesisWordCountASL extends Logging { + def main(args: Array[String]) { + /* Check that all required args were passed in. */ + if (args.length < 2) { + System.err.println( + """ + |Usage: KinesisWordCount + | is the name of the Kinesis stream + | is the endpoint of the Kinesis service + | (e.g. https://kinesis.us-east-1.amazonaws.com) + """.stripMargin) + System.exit(1) + } + + StreamingExamples.setStreamingLogLevels() + + /* Populate the appropriate variables from the given args */ + val Array(streamName, endpointUrl) = args + + /* Determine the number of shards from the stream */ + val kinesisClient = new AmazonKinesisClient(new DefaultAWSCredentialsProviderChain()) + kinesisClient.setEndpoint(endpointUrl) + val numShards = kinesisClient.describeStream(streamName).getStreamDescription().getShards() + .size() + + /* In this example, we're going to create 1 Kinesis Worker/Receiver/DStream for each shard. */ + val numStreams = numShards + + /* + * numSparkThreads should be 1 more thread than the number of receivers. + * This leaves one thread available for actually processing the data. + */ + val numSparkThreads = numStreams + 1 + + /* Setup the and SparkConfig and StreamingContext */ + /* Spark Streaming batch interval */ + val batchInterval = Milliseconds(2000) + val sparkConfig = new SparkConf().setAppName("KinesisWordCount") + .setMaster(s"local[$numSparkThreads]") + val ssc = new StreamingContext(sparkConfig, batchInterval) + + /* Kinesis checkpoint interval. Same as batchInterval for this example. */ + val kinesisCheckpointInterval = batchInterval + + /* Create the same number of Kinesis DStreams/Receivers as Kinesis stream's shards */ + val kinesisStreams = (0 until numStreams).map { i => + KinesisUtils.createStream(ssc, streamName, endpointUrl, kinesisCheckpointInterval, + InitialPositionInStream.LATEST, StorageLevel.MEMORY_AND_DISK_2) + } + + /* Union all the streams */ + val unionStreams = ssc.union(kinesisStreams) + + /* Convert each line of Array[Byte] to String, split into words, and count them */ + val words = unionStreams.flatMap(byteArray => new String(byteArray) + .split(" ")) + + /* Map each word to a (word, 1) tuple so we can reduce/aggregate by key. */ + val wordCounts = words.map(word => (word, 1)).reduceByKey(_ + _) + + /* Print the first 10 wordCounts */ + wordCounts.print() + + /* Start the streaming context and await termination */ + ssc.start() + ssc.awaitTermination() + } +} + +/** + * Usage: KinesisWordCountProducerASL + * + * is the name of the Kinesis stream (ie. mySparkStream) + * is the endpoint of the Kinesis service + * (ie. https://kinesis.us-east-1.amazonaws.com) + * is the rate of records per second to put onto the stream + * is the rate of records per second to put onto the stream + * + * Example: + * $ export AWS_ACCESS_KEY_ID= + * $ export AWS_SECRET_KEY= + * $ $SPARK_HOME/bin/run-example \ + * org.apache.spark.examples.streaming.KinesisWordCountProducerASL mySparkStream \ + * https://kinesis.us-east-1.amazonaws.com 10 5 + */ +object KinesisWordCountProducerASL { + def main(args: Array[String]) { + if (args.length < 4) { + System.err.println("Usage: KinesisWordCountProducerASL " + + " ") + System.exit(1) + } + + StreamingExamples.setStreamingLogLevels() + + /* Populate the appropriate variables from the given args */ + val Array(stream, endpoint, recordsPerSecond, wordsPerRecord) = args + + /* Generate the records and return the totals */ + val totals = generate(stream, endpoint, recordsPerSecond.toInt, wordsPerRecord.toInt) + + /* Print the array of (index, total) tuples */ + println("Totals") + totals.foreach(total => println(total.toString())) + } + + def generate(stream: String, + endpoint: String, + recordsPerSecond: Int, + wordsPerRecord: Int): Seq[(Int, Int)] = { + + val MaxRandomInts = 10 + + /* Create the Kinesis client */ + val kinesisClient = new AmazonKinesisClient(new DefaultAWSCredentialsProviderChain()) + kinesisClient.setEndpoint(endpoint) + + println(s"Putting records onto stream $stream and endpoint $endpoint at a rate of" + + s" $recordsPerSecond records per second and $wordsPerRecord words per record"); + + val totals = new Array[Int](MaxRandomInts) + /* Put String records onto the stream per the given recordPerSec and wordsPerRecord */ + for (i <- 1 to 5) { + + /* Generate recordsPerSec records to put onto the stream */ + val records = (1 to recordsPerSecond.toInt).map { recordNum => + /* + * Randomly generate each wordsPerRec words between 0 (inclusive) + * and MAX_RANDOM_INTS (exclusive) + */ + val data = (1 to wordsPerRecord.toInt).map(x => { + /* Generate the random int */ + val randomInt = Random.nextInt(MaxRandomInts) + + /* Keep track of the totals */ + totals(randomInt) += 1 + + randomInt.toString() + }).mkString(" ") + + /* Create a partitionKey based on recordNum */ + val partitionKey = s"partitionKey-$recordNum" + + /* Create a PutRecordRequest with an Array[Byte] version of the data */ + val putRecordRequest = new PutRecordRequest().withStreamName(stream) + .withPartitionKey(partitionKey) + .withData(ByteBuffer.wrap(data.getBytes())); + + /* Put the record onto the stream and capture the PutRecordResult */ + val putRecordResult = kinesisClient.putRecord(putRecordRequest); + } + + /* Sleep for a second */ + Thread.sleep(1000) + println("Sent " + recordsPerSecond + " records") + } + + /* Convert the totals to (index, total) tuple */ + (0 to (MaxRandomInts - 1)).zip(totals) + } +} + +/** + * Utility functions for Spark Streaming examples. + * This has been lifted from the examples/ project to remove the circular dependency. + */ +object StreamingExamples extends Logging { + + /** Set reasonable logging levels for streaming if the user has not configured log4j. */ + def setStreamingLogLevels() { + val log4jInitialized = Logger.getRootLogger.getAllAppenders.hasMoreElements + if (!log4jInitialized) { + // We first log something to initialize Spark's default logging, then we override the + // logging level. + logInfo("Setting log level to [WARN] for streaming example." + + " To override add a custom log4j.properties to the classpath.") + Logger.getRootLogger.setLevel(Level.WARN) + } + } +} diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointState.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointState.scala new file mode 100644 index 0000000000000..0b80b611cdce7 --- /dev/null +++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointState.scala @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.streaming.kinesis + +import org.apache.spark.Logging +import org.apache.spark.streaming.Duration +import org.apache.spark.streaming.util.Clock +import org.apache.spark.streaming.util.ManualClock +import org.apache.spark.streaming.util.SystemClock + +/** + * This is a helper class for managing checkpoint clocks. + * + * @param checkpointInterval + * @param currentClock. Default to current SystemClock if none is passed in (mocking purposes) + */ +private[kinesis] class KinesisCheckpointState( + checkpointInterval: Duration, + currentClock: Clock = new SystemClock()) + extends Logging { + + /* Initialize the checkpoint clock using the given currentClock + checkpointInterval millis */ + val checkpointClock = new ManualClock() + checkpointClock.setTime(currentClock.currentTime() + checkpointInterval.milliseconds) + + /** + * Check if it's time to checkpoint based on the current time and the derived time + * for the next checkpoint + * + * @return true if it's time to checkpoint + */ + def shouldCheckpoint(): Boolean = { + new SystemClock().currentTime() > checkpointClock.currentTime() + } + + /** + * Advance the checkpoint clock by the checkpoint interval. + */ + def advanceCheckpoint() = { + checkpointClock.addToTime(checkpointInterval.milliseconds) + } +} diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala new file mode 100644 index 0000000000000..1bd1f324298e7 --- /dev/null +++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala @@ -0,0 +1,149 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.streaming.kinesis + +import java.net.InetAddress +import java.util.UUID + +import org.apache.spark.Logging +import org.apache.spark.storage.StorageLevel +import org.apache.spark.streaming.Duration +import org.apache.spark.streaming.receiver.Receiver + +import com.amazonaws.auth.AWSCredentialsProvider +import com.amazonaws.auth.DefaultAWSCredentialsProviderChain +import com.amazonaws.services.kinesis.clientlibrary.interfaces.IRecordProcessor +import com.amazonaws.services.kinesis.clientlibrary.interfaces.IRecordProcessorFactory +import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream +import com.amazonaws.services.kinesis.clientlibrary.lib.worker.KinesisClientLibConfiguration +import com.amazonaws.services.kinesis.clientlibrary.lib.worker.Worker + +/** + * Custom AWS Kinesis-specific implementation of Spark Streaming's Receiver. + * This implementation relies on the Kinesis Client Library (KCL) Worker as described here: + * https://github.com/awslabs/amazon-kinesis-client + * This is a custom receiver used with StreamingContext.receiverStream(Receiver) + * as described here: + * http://spark.apache.org/docs/latest/streaming-custom-receivers.html + * Instances of this class will get shipped to the Spark Streaming Workers + * to run within a Spark Executor. + * + * @param appName Kinesis application name. Kinesis Apps are mapped to Kinesis Streams + * by the Kinesis Client Library. If you change the App name or Stream name, + * the KCL will throw errors. This usually requires deleting the backing + * DynamoDB table with the same name this Kinesis application. + * @param streamName Kinesis stream name + * @param endpointUrl Url of Kinesis service (e.g., https://kinesis.us-east-1.amazonaws.com) + * @param checkpointInterval Checkpoint interval for Kinesis checkpointing. + * See the Kinesis Spark Streaming documentation for more + * details on the different types of checkpoints. + * @param initialPositionInStream In the absence of Kinesis checkpoint info, this is the + * worker's initial starting position in the stream. + * The values are either the beginning of the stream + * per Kinesis' limit of 24 hours + * (InitialPositionInStream.TRIM_HORIZON) or + * the tip of the stream (InitialPositionInStream.LATEST). + * @param storageLevel Storage level to use for storing the received objects + * + * @return ReceiverInputDStream[Array[Byte]] + */ +private[kinesis] class KinesisReceiver( + appName: String, + streamName: String, + endpointUrl: String, + checkpointInterval: Duration, + initialPositionInStream: InitialPositionInStream, + storageLevel: StorageLevel) + extends Receiver[Array[Byte]](storageLevel) with Logging { receiver => + + /* + * The following vars are built in the onStart() method which executes in the Spark Worker after + * this code is serialized and shipped remotely. + */ + + /* + * workerId should be based on the ip address of the actual Spark Worker where this code runs + * (not the Driver's ip address.) + */ + var workerId: String = null + + /* + * This impl uses the DefaultAWSCredentialsProviderChain and searches for credentials + * in the following order of precedence: + * Environment Variables - AWS_ACCESS_KEY_ID and AWS_SECRET_KEY + * Java System Properties - aws.accessKeyId and aws.secretKey + * Credential profiles file at the default location (~/.aws/credentials) shared by all + * AWS SDKs and the AWS CLI + * Instance profile credentials delivered through the Amazon EC2 metadata service + */ + var credentialsProvider: AWSCredentialsProvider = null + + /* KCL config instance. */ + var kinesisClientLibConfiguration: KinesisClientLibConfiguration = null + + /* + * RecordProcessorFactory creates impls of IRecordProcessor. + * IRecordProcessor adapts the KCL to our Spark KinesisReceiver via the + * IRecordProcessor.processRecords() method. + * We're using our custom KinesisRecordProcessor in this case. + */ + var recordProcessorFactory: IRecordProcessorFactory = null + + /* + * Create a Kinesis Worker. + * This is the core client abstraction from the Kinesis Client Library (KCL). + * We pass the RecordProcessorFactory from above as well as the KCL config instance. + * A Kinesis Worker can process 1..* shards from the given stream - each with its + * own RecordProcessor. + */ + var worker: Worker = null + + /** + * This is called when the KinesisReceiver starts and must be non-blocking. + * The KCL creates and manages the receiving/processing thread pool through the Worker.run() + * method. + */ + override def onStart() { + workerId = InetAddress.getLocalHost.getHostAddress() + ":" + UUID.randomUUID() + credentialsProvider = new DefaultAWSCredentialsProviderChain() + kinesisClientLibConfiguration = new KinesisClientLibConfiguration(appName, streamName, + credentialsProvider, workerId).withKinesisEndpoint(endpointUrl) + .withInitialPositionInStream(initialPositionInStream).withTaskBackoffTimeMillis(500) + recordProcessorFactory = new IRecordProcessorFactory { + override def createProcessor: IRecordProcessor = new KinesisRecordProcessor(receiver, + workerId, new KinesisCheckpointState(checkpointInterval)) + } + worker = new Worker(recordProcessorFactory, kinesisClientLibConfiguration) + worker.run() + logInfo(s"Started receiver with workerId $workerId") + } + + /** + * This is called when the KinesisReceiver stops. + * The KCL worker.shutdown() method stops the receiving/processing threads. + * The KCL will do its best to drain and checkpoint any in-flight records upon shutdown. + */ + override def onStop() { + worker.shutdown() + logInfo(s"Shut down receiver with workerId $workerId") + workerId = null + credentialsProvider = null + kinesisClientLibConfiguration = null + recordProcessorFactory = null + worker = null + } +} diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisRecordProcessor.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisRecordProcessor.scala new file mode 100644 index 0000000000000..8ecc2d90160b1 --- /dev/null +++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisRecordProcessor.scala @@ -0,0 +1,212 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.streaming.kinesis + +import java.util.List + +import scala.collection.JavaConversions.asScalaBuffer +import scala.util.Random + +import org.apache.spark.Logging + +import com.amazonaws.services.kinesis.clientlibrary.exceptions.InvalidStateException +import com.amazonaws.services.kinesis.clientlibrary.exceptions.KinesisClientLibDependencyException +import com.amazonaws.services.kinesis.clientlibrary.exceptions.ShutdownException +import com.amazonaws.services.kinesis.clientlibrary.exceptions.ThrottlingException +import com.amazonaws.services.kinesis.clientlibrary.interfaces.IRecordProcessor +import com.amazonaws.services.kinesis.clientlibrary.interfaces.IRecordProcessorCheckpointer +import com.amazonaws.services.kinesis.clientlibrary.types.ShutdownReason +import com.amazonaws.services.kinesis.model.Record + +/** + * Kinesis-specific implementation of the Kinesis Client Library (KCL) IRecordProcessor. + * This implementation operates on the Array[Byte] from the KinesisReceiver. + * The Kinesis Worker creates an instance of this KinesisRecordProcessor upon startup. + * + * @param receiver Kinesis receiver + * @param workerId for logging purposes + * @param checkpointState represents the checkpoint state including the next checkpoint time. + * It's injected here for mocking purposes. + */ +private[kinesis] class KinesisRecordProcessor( + receiver: KinesisReceiver, + workerId: String, + checkpointState: KinesisCheckpointState) extends IRecordProcessor with Logging { + + /* shardId to be populated during initialize() */ + var shardId: String = _ + + /** + * The Kinesis Client Library calls this method during IRecordProcessor initialization. + * + * @param shardId assigned by the KCL to this particular RecordProcessor. + */ + override def initialize(shardId: String) { + logInfo(s"Initialize: Initializing workerId $workerId with shardId $shardId") + this.shardId = shardId + } + + /** + * This method is called by the KCL when a batch of records is pulled from the Kinesis stream. + * This is the record-processing bridge between the KCL's IRecordProcessor.processRecords() + * and Spark Streaming's Receiver.store(). + * + * @param batch list of records from the Kinesis stream shard + * @param checkpointer used to update Kinesis when this batch has been processed/stored + * in the DStream + */ + override def processRecords(batch: List[Record], checkpointer: IRecordProcessorCheckpointer) { + if (!receiver.isStopped()) { + try { + /* + * Note: If we try to store the raw ByteBuffer from record.getData(), the Spark Streaming + * Receiver.store(ByteBuffer) attempts to deserialize the ByteBuffer using the + * internally-configured Spark serializer (kryo, etc). + * This is not desirable, so we instead store a raw Array[Byte] and decouple + * ourselves from Spark's internal serialization strategy. + */ + batch.foreach(record => receiver.store(record.getData().array())) + + logDebug(s"Stored: Worker $workerId stored ${batch.size} records for shardId $shardId") + + /* + * Checkpoint the sequence number of the last record successfully processed/stored + * in the batch. + * In this implementation, we're checkpointing after the given checkpointIntervalMillis. + * Note that this logic requires that processRecords() be called AND that it's time to + * checkpoint. I point this out because there is no background thread running the + * checkpointer. Checkpointing is tested and trigger only when a new batch comes in. + * If the worker is shutdown cleanly, checkpoint will happen (see shutdown() below). + * However, if the worker dies unexpectedly, a checkpoint may not happen. + * This could lead to records being processed more than once. + */ + if (checkpointState.shouldCheckpoint()) { + /* Perform the checkpoint */ + KinesisRecordProcessor.retryRandom(checkpointer.checkpoint(), 4, 100) + + /* Update the next checkpoint time */ + checkpointState.advanceCheckpoint() + + logDebug(s"Checkpoint: WorkerId $workerId completed checkpoint of ${batch.size}" + + s" records for shardId $shardId") + logDebug(s"Checkpoint: Next checkpoint is at " + + s" ${checkpointState.checkpointClock.currentTime()} for shardId $shardId") + } + } catch { + case e: Throwable => { + /* + * If there is a failure within the batch, the batch will not be checkpointed. + * This will potentially cause records since the last checkpoint to be processed + * more than once. + */ + logError(s"Exception: WorkerId $workerId encountered and exception while storing " + + " or checkpointing a batch for workerId $workerId and shardId $shardId.", e) + + /* Rethrow the exception to the Kinesis Worker that is managing this RecordProcessor.*/ + throw e + } + } + } else { + /* RecordProcessor has been stopped. */ + logInfo(s"Stopped: The Spark KinesisReceiver has stopped for workerId $workerId" + + s" and shardId $shardId. No more records will be processed.") + } + } + + /** + * Kinesis Client Library is shutting down this Worker for 1 of 2 reasons: + * 1) the stream is resharding by splitting or merging adjacent shards + * (ShutdownReason.TERMINATE) + * 2) the failed or latent Worker has stopped sending heartbeats for whatever reason + * (ShutdownReason.ZOMBIE) + * + * @param checkpointer used to perform a Kinesis checkpoint for ShutdownReason.TERMINATE + * @param reason for shutdown (ShutdownReason.TERMINATE or ShutdownReason.ZOMBIE) + */ + override def shutdown(checkpointer: IRecordProcessorCheckpointer, reason: ShutdownReason) { + logInfo(s"Shutdown: Shutting down workerId $workerId with reason $reason") + reason match { + /* + * TERMINATE Use Case. Checkpoint. + * Checkpoint to indicate that all records from the shard have been drained and processed. + * It's now OK to read from the new shards that resulted from a resharding event. + */ + case ShutdownReason.TERMINATE => + KinesisRecordProcessor.retryRandom(checkpointer.checkpoint(), 4, 100) + + /* + * ZOMBIE Use Case. NoOp. + * No checkpoint because other workers may have taken over and already started processing + * the same records. + * This may lead to records being processed more than once. + */ + case ShutdownReason.ZOMBIE => + + /* Unknown reason. NoOp */ + case _ => + } + } +} + +private[kinesis] object KinesisRecordProcessor extends Logging { + /** + * Retry the given amount of times with a random backoff time (millis) less than the + * given maxBackOffMillis + * + * @param expression expression to evalute + * @param numRetriesLeft number of retries left + * @param maxBackOffMillis: max millis between retries + * + * @return evaluation of the given expression + * @throws Unretryable exception, unexpected exception, + * or any exception that persists after numRetriesLeft reaches 0 + */ + @annotation.tailrec + def retryRandom[T](expression: => T, numRetriesLeft: Int, maxBackOffMillis: Int): T = { + util.Try { expression } match { + /* If the function succeeded, evaluate to x. */ + case util.Success(x) => x + /* If the function failed, either retry or throw the exception */ + case util.Failure(e) => e match { + /* Retry: Throttling or other Retryable exception has occurred */ + case _: ThrottlingException | _: KinesisClientLibDependencyException if numRetriesLeft > 1 + => { + val backOffMillis = Random.nextInt(maxBackOffMillis) + Thread.sleep(backOffMillis) + logError(s"Retryable Exception: Random backOffMillis=${backOffMillis}", e) + retryRandom(expression, numRetriesLeft - 1, maxBackOffMillis) + } + /* Throw: Shutdown has been requested by the Kinesis Client Library.*/ + case _: ShutdownException => { + logError(s"ShutdownException: Caught shutdown exception, skipping checkpoint.", e) + throw e + } + /* Throw: Non-retryable exception has occurred with the Kinesis Client Library */ + case _: InvalidStateException => { + logError(s"InvalidStateException: Cannot save checkpoint to the DynamoDB table used" + + s" by the Amazon Kinesis Client Library. Table likely doesn't exist.", e) + throw e + } + /* Throw: Unexpected exception has occurred */ + case _ => { + logError(s"Unexpected, non-retryable exception.", e) + throw e + } + } + } + } +} diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala new file mode 100644 index 0000000000000..713cac0e293c0 --- /dev/null +++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala @@ -0,0 +1,96 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.streaming.kinesis + +import org.apache.spark.annotation.Experimental +import org.apache.spark.storage.StorageLevel +import org.apache.spark.streaming.Duration +import org.apache.spark.streaming.StreamingContext +import org.apache.spark.streaming.api.java.JavaReceiverInputDStream +import org.apache.spark.streaming.api.java.JavaStreamingContext +import org.apache.spark.streaming.dstream.ReceiverInputDStream + +import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream + + +/** + * Helper class to create Amazon Kinesis Input Stream + * :: Experimental :: + */ +@Experimental +object KinesisUtils { + /** + * Create an InputDStream that pulls messages from a Kinesis stream. + * + * @param ssc StreamingContext object + * @param streamName Kinesis stream name + * @param endpointUrl Url of Kinesis service (e.g., https://kinesis.us-east-1.amazonaws.com) + * @param checkpointInterval Checkpoint interval for Kinesis checkpointing. + * See the Kinesis Spark Streaming documentation for more + * details on the different types of checkpoints. + * @param initialPositionInStream In the absence of Kinesis checkpoint info, this is the + * worker's initial starting position in the stream. + * The values are either the beginning of the stream + * per Kinesis' limit of 24 hours + * (InitialPositionInStream.TRIM_HORIZON) or + * the tip of the stream (InitialPositionInStream.LATEST). + * @param storageLevel Storage level to use for storing the received objects + * + * @return ReceiverInputDStream[Array[Byte]] + */ + def createStream( + ssc: StreamingContext, + streamName: String, + endpointUrl: String, + checkpointInterval: Duration, + initialPositionInStream: InitialPositionInStream, + storageLevel: StorageLevel): ReceiverInputDStream[Array[Byte]] = { + ssc.receiverStream(new KinesisReceiver(ssc.sc.appName, streamName, endpointUrl, + checkpointInterval, initialPositionInStream, storageLevel)) + } + + /** + * Create a Java-friendly InputDStream that pulls messages from a Kinesis stream. + * + * @param jssc Java StreamingContext object + * @param ssc StreamingContext object + * @param streamName Kinesis stream name + * @param endpointUrl Url of Kinesis service (e.g., https://kinesis.us-east-1.amazonaws.com) + * @param checkpointInterval Checkpoint interval for Kinesis checkpointing. + * See the Kinesis Spark Streaming documentation for more + * details on the different types of checkpoints. + * @param initialPositionInStream In the absence of Kinesis checkpoint info, this is the + * worker's initial starting position in the stream. + * The values are either the beginning of the stream + * per Kinesis' limit of 24 hours + * (InitialPositionInStream.TRIM_HORIZON) or + * the tip of the stream (InitialPositionInStream.LATEST). + * @param storageLevel Storage level to use for storing the received objects + * + * @return JavaReceiverInputDStream[Array[Byte]] + */ + def createStream( + jssc: JavaStreamingContext, + streamName: String, + endpointUrl: String, + checkpointInterval: Duration, + initialPositionInStream: InitialPositionInStream, + storageLevel: StorageLevel): JavaReceiverInputDStream[Array[Byte]] = { + jssc.receiverStream(new KinesisReceiver(jssc.ssc.sc.appName, streamName, + endpointUrl, checkpointInterval, initialPositionInStream, storageLevel)) + } +} diff --git a/extras/kinesis-asl/src/test/java/org/apache/spark/streaming/kinesis/JavaKinesisStreamSuite.java b/extras/kinesis-asl/src/test/java/org/apache/spark/streaming/kinesis/JavaKinesisStreamSuite.java new file mode 100644 index 0000000000000..87954a31f60ce --- /dev/null +++ b/extras/kinesis-asl/src/test/java/org/apache/spark/streaming/kinesis/JavaKinesisStreamSuite.java @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.kinesis; + +import org.apache.spark.storage.StorageLevel; +import org.apache.spark.streaming.Duration; +import org.apache.spark.streaming.LocalJavaStreamingContext; +import org.apache.spark.streaming.api.java.JavaDStream; +import org.junit.Test; + +import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream; + +/** + * Demonstrate the use of the KinesisUtils Java API + */ +public class JavaKinesisStreamSuite extends LocalJavaStreamingContext { + @Test + public void testKinesisStream() { + // Tests the API, does not actually test data receiving + JavaDStream kinesisStream = KinesisUtils.createStream(ssc, "mySparkStream", + "https://kinesis.us-west-2.amazonaws.com", new Duration(2000), + InitialPositionInStream.LATEST, StorageLevel.MEMORY_AND_DISK_2()); + + ssc.stop(); + } +} diff --git a/extras/kinesis-asl/src/test/resources/log4j.properties b/extras/kinesis-asl/src/test/resources/log4j.properties new file mode 100644 index 0000000000000..e01e049595475 --- /dev/null +++ b/extras/kinesis-asl/src/test/resources/log4j.properties @@ -0,0 +1,26 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +log4j.rootCategory=INFO, file +# log4j.appender.file=org.apache.log4j.FileAppender +log4j.appender.file=org.apache.log4j.FileAppender +log4j.appender.file.append=false +log4j.appender.file.file=target/unit-tests.log +log4j.appender.file.layout=org.apache.log4j.PatternLayout +log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %p %c{1}: %m%n + +# Ignore messages below warning level from Jetty, because it's a bit verbose +log4j.logger.org.eclipse.jetty=WARN diff --git a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala new file mode 100644 index 0000000000000..41dbd64c2b1fa --- /dev/null +++ b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala @@ -0,0 +1,275 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.streaming.kinesis + +import java.nio.ByteBuffer + +import scala.collection.JavaConversions.seqAsJavaList + +import org.apache.spark.annotation.Experimental +import org.apache.spark.storage.StorageLevel +import org.apache.spark.streaming.Milliseconds +import org.apache.spark.streaming.Seconds +import org.apache.spark.streaming.StreamingContext +import org.apache.spark.streaming.TestSuiteBase +import org.apache.spark.streaming.util.Clock +import org.apache.spark.streaming.util.ManualClock +import org.scalatest.BeforeAndAfter +import org.scalatest.Matchers +import org.scalatest.mock.EasyMockSugar + +import com.amazonaws.services.kinesis.clientlibrary.exceptions.InvalidStateException +import com.amazonaws.services.kinesis.clientlibrary.exceptions.KinesisClientLibDependencyException +import com.amazonaws.services.kinesis.clientlibrary.exceptions.ShutdownException +import com.amazonaws.services.kinesis.clientlibrary.exceptions.ThrottlingException +import com.amazonaws.services.kinesis.clientlibrary.interfaces.IRecordProcessorCheckpointer +import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream +import com.amazonaws.services.kinesis.clientlibrary.types.ShutdownReason +import com.amazonaws.services.kinesis.model.Record + +/** + * Suite of Kinesis streaming receiver tests focusing mostly on the KinesisRecordProcessor + */ +class KinesisReceiverSuite extends TestSuiteBase with Matchers with BeforeAndAfter + with EasyMockSugar { + + val app = "TestKinesisReceiver" + val stream = "mySparkStream" + val endpoint = "endpoint-url" + val workerId = "dummyWorkerId" + val shardId = "dummyShardId" + + val record1 = new Record() + record1.setData(ByteBuffer.wrap("Spark In Action".getBytes())) + val record2 = new Record() + record2.setData(ByteBuffer.wrap("Learning Spark".getBytes())) + val batch = List[Record](record1, record2) + + var receiverMock: KinesisReceiver = _ + var checkpointerMock: IRecordProcessorCheckpointer = _ + var checkpointClockMock: ManualClock = _ + var checkpointStateMock: KinesisCheckpointState = _ + var currentClockMock: Clock = _ + + override def beforeFunction() = { + receiverMock = mock[KinesisReceiver] + checkpointerMock = mock[IRecordProcessorCheckpointer] + checkpointClockMock = mock[ManualClock] + checkpointStateMock = mock[KinesisCheckpointState] + currentClockMock = mock[Clock] + } + + test("kinesis utils api") { + val ssc = new StreamingContext(master, framework, batchDuration) + // Tests the API, does not actually test data receiving + val kinesisStream = KinesisUtils.createStream(ssc, "mySparkStream", + "https://kinesis.us-west-2.amazonaws.com", Seconds(2), + InitialPositionInStream.LATEST, StorageLevel.MEMORY_AND_DISK_2); + ssc.stop() + } + + test("process records including store and checkpoint") { + val expectedCheckpointIntervalMillis = 10 + expecting { + receiverMock.isStopped().andReturn(false).once() + receiverMock.store(record1.getData().array()).once() + receiverMock.store(record2.getData().array()).once() + checkpointStateMock.shouldCheckpoint().andReturn(true).once() + checkpointerMock.checkpoint().once() + checkpointStateMock.advanceCheckpoint().once() + } + whenExecuting(receiverMock, checkpointerMock, checkpointStateMock) { + val recordProcessor = new KinesisRecordProcessor(receiverMock, workerId, + checkpointStateMock) + recordProcessor.processRecords(batch, checkpointerMock) + } + } + + test("shouldn't store and checkpoint when receiver is stopped") { + expecting { + receiverMock.isStopped().andReturn(true).once() + } + whenExecuting(receiverMock, checkpointerMock, checkpointStateMock) { + val recordProcessor = new KinesisRecordProcessor(receiverMock, workerId, + checkpointStateMock) + recordProcessor.processRecords(batch, checkpointerMock) + } + } + + test("shouldn't checkpoint when exception occurs during store") { + expecting { + receiverMock.isStopped().andReturn(false).once() + receiverMock.store(record1.getData().array()).andThrow(new RuntimeException()).once() + } + whenExecuting(receiverMock, checkpointerMock, checkpointStateMock) { + intercept[RuntimeException] { + val recordProcessor = new KinesisRecordProcessor(receiverMock, workerId, + checkpointStateMock) + recordProcessor.processRecords(batch, checkpointerMock) + } + } + } + + test("should set checkpoint time to currentTime + checkpoint interval upon instantiation") { + expecting { + currentClockMock.currentTime().andReturn(0).once() + } + whenExecuting(currentClockMock) { + val checkpointIntervalMillis = 10 + val checkpointState = new KinesisCheckpointState(Milliseconds(checkpointIntervalMillis), currentClockMock) + assert(checkpointState.checkpointClock.currentTime() == checkpointIntervalMillis) + } + } + + test("should checkpoint if we have exceeded the checkpoint interval") { + expecting { + currentClockMock.currentTime().andReturn(0).once() + } + whenExecuting(currentClockMock) { + val checkpointState = new KinesisCheckpointState(Milliseconds(Long.MinValue), currentClockMock) + assert(checkpointState.shouldCheckpoint()) + } + } + + test("shouldn't checkpoint if we have not exceeded the checkpoint interval") { + expecting { + currentClockMock.currentTime().andReturn(0).once() + } + whenExecuting(currentClockMock) { + val checkpointState = new KinesisCheckpointState(Milliseconds(Long.MaxValue), currentClockMock) + assert(!checkpointState.shouldCheckpoint()) + } + } + + test("should add to time when advancing checkpoint") { + expecting { + currentClockMock.currentTime().andReturn(0).once() + } + whenExecuting(currentClockMock) { + val checkpointIntervalMillis = 10 + val checkpointState = new KinesisCheckpointState(Milliseconds(checkpointIntervalMillis), currentClockMock) + assert(checkpointState.checkpointClock.currentTime() == checkpointIntervalMillis) + checkpointState.advanceCheckpoint() + assert(checkpointState.checkpointClock.currentTime() == (2 * checkpointIntervalMillis)) + } + } + + test("shutdown should checkpoint if the reason is TERMINATE") { + expecting { + checkpointerMock.checkpoint().once() + } + whenExecuting(checkpointerMock, checkpointStateMock) { + val recordProcessor = new KinesisRecordProcessor(receiverMock, workerId, + checkpointStateMock) + val reason = ShutdownReason.TERMINATE + recordProcessor.shutdown(checkpointerMock, reason) + } + } + + test("shutdown should not checkpoint if the reason is something other than TERMINATE") { + expecting { + } + whenExecuting(checkpointerMock, checkpointStateMock) { + val recordProcessor = new KinesisRecordProcessor(receiverMock, workerId, + checkpointStateMock) + recordProcessor.shutdown(checkpointerMock, ShutdownReason.ZOMBIE) + recordProcessor.shutdown(checkpointerMock, null) + } + } + + test("retry success on first attempt") { + val expectedIsStopped = false + expecting { + receiverMock.isStopped().andReturn(expectedIsStopped).once() + } + whenExecuting(receiverMock) { + val actualVal = KinesisRecordProcessor.retryRandom(receiverMock.isStopped(), 2, 100) + assert(actualVal == expectedIsStopped) + } + } + + test("retry success on second attempt after a Kinesis throttling exception") { + val expectedIsStopped = false + expecting { + receiverMock.isStopped().andThrow(new ThrottlingException("error message")) + .andReturn(expectedIsStopped).once() + } + whenExecuting(receiverMock) { + val actualVal = KinesisRecordProcessor.retryRandom(receiverMock.isStopped(), 2, 100) + assert(actualVal == expectedIsStopped) + } + } + + test("retry success on second attempt after a Kinesis dependency exception") { + val expectedIsStopped = false + expecting { + receiverMock.isStopped().andThrow(new KinesisClientLibDependencyException("error message")) + .andReturn(expectedIsStopped).once() + } + whenExecuting(receiverMock) { + val actualVal = KinesisRecordProcessor.retryRandom(receiverMock.isStopped(), 2, 100) + assert(actualVal == expectedIsStopped) + } + } + + test("retry failed after a shutdown exception") { + expecting { + checkpointerMock.checkpoint().andThrow(new ShutdownException("error message")).once() + } + whenExecuting(checkpointerMock) { + intercept[ShutdownException] { + KinesisRecordProcessor.retryRandom(checkpointerMock.checkpoint(), 2, 100) + } + } + } + + test("retry failed after an invalid state exception") { + expecting { + checkpointerMock.checkpoint().andThrow(new InvalidStateException("error message")).once() + } + whenExecuting(checkpointerMock) { + intercept[InvalidStateException] { + KinesisRecordProcessor.retryRandom(checkpointerMock.checkpoint(), 2, 100) + } + } + } + + test("retry failed after unexpected exception") { + expecting { + checkpointerMock.checkpoint().andThrow(new RuntimeException("error message")).once() + } + whenExecuting(checkpointerMock) { + intercept[RuntimeException] { + KinesisRecordProcessor.retryRandom(checkpointerMock.checkpoint(), 2, 100) + } + } + } + + test("retry failed after exhausing all retries") { + val expectedErrorMessage = "final try error message" + expecting { + checkpointerMock.checkpoint().andThrow(new ThrottlingException("error message")) + .andThrow(new ThrottlingException(expectedErrorMessage)).once() + } + whenExecuting(checkpointerMock) { + val exception = intercept[RuntimeException] { + KinesisRecordProcessor.retryRandom(checkpointerMock.checkpoint(), 2, 100) + } + exception.getMessage().shouldBe(expectedErrorMessage) + } + } +} diff --git a/pom.xml b/pom.xml index 99ae4b8b33f94..a42759169149b 100644 --- a/pom.xml +++ b/pom.xml @@ -134,6 +134,8 @@ 3.0.0 1.7.6 0.7.1 + 1.8.3 + 1.1.0 64m 512m @@ -1011,6 +1013,14 @@ + + + kinesis-asl + + extras/kinesis-asl + + + java8-tests diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 1d7cc6dd6aef3..aac621fe53938 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -37,8 +37,8 @@ object BuildCommons { "spark", "sql", "streaming", "streaming-flume-sink", "streaming-flume", "streaming-kafka", "streaming-mqtt", "streaming-twitter", "streaming-zeromq").map(ProjectRef(buildLocation, _)) - val optionallyEnabledProjects@Seq(yarn, yarnStable, yarnAlpha, java8Tests, sparkGangliaLgpl) = - Seq("yarn", "yarn-stable", "yarn-alpha", "java8-tests", "ganglia-lgpl") + val optionallyEnabledProjects@Seq(yarn, yarnStable, yarnAlpha, java8Tests, sparkGangliaLgpl, sparkKinesisAsl) = + Seq("yarn", "yarn-stable", "yarn-alpha", "java8-tests", "ganglia-lgpl", "kinesis-asl") .map(ProjectRef(buildLocation, _)) val assemblyProjects@Seq(assembly, examples) = Seq("assembly", "examples") @@ -62,7 +62,7 @@ object SparkBuild extends PomBuild { var isAlphaYarn = false var profiles: mutable.Seq[String] = mutable.Seq.empty if (Properties.envOrNone("SPARK_GANGLIA_LGPL").isDefined) { - println("NOTE: SPARK_GANGLIA_LGPL is deprecated, please use -Pganglia-lgpl flag.") + println("NOTE: SPARK_GANGLIA_LGPL is deprecated, please use -Pspark-ganglia-lgpl flag.") profiles ++= Seq("spark-ganglia-lgpl") } if (Properties.envOrNone("SPARK_HIVE").isDefined) { From 4c477117bb1ffef463776c86f925d35036f96b7a Mon Sep 17 00:00:00 2001 From: GuoQiang Li Date: Sat, 2 Aug 2014 13:55:28 -0700 Subject: [PATCH 120/170] SPARK-2804: Remove scalalogging-slf4j dependency This also Closes #1701. Author: GuoQiang Li Closes #1208 from witgo/SPARK-1470 and squashes the following commits: 422646b [GuoQiang Li] Remove scalalogging-slf4j dependency --- .../main/scala/org/apache/spark/Logging.scala | 10 ++++++--- sql/catalyst/pom.xml | 5 ----- .../sql/catalyst/analysis/Analyzer.scala | 4 ++-- .../catalyst/analysis/HiveTypeCoercion.scala | 8 +++---- .../catalyst/expressions/BoundAttribute.scala | 2 +- .../codegen/GenerateOrdering.scala | 4 ++-- .../apache/spark/sql/catalyst/package.scala | 1 - .../sql/catalyst/planning/QueryPlanner.scala | 2 +- .../sql/catalyst/planning/patterns.scala | 6 ++--- .../spark/sql/catalyst/rules/Rule.scala | 2 +- .../sql/catalyst/rules/RuleExecutor.scala | 12 +++++----- .../spark/sql/catalyst/trees/package.scala | 8 ++++--- .../org/apache/spark/sql/SQLContext.scala | 2 +- .../CompressibleColumnBuilder.scala | 5 +++-- .../apache/spark/sql/execution/Exchange.scala | 2 +- .../org/apache/spark/sql/json/JsonRDD.scala | 2 +- .../scala/org/apache/spark/sql/package.scala | 2 -- .../spark/sql/columnar/ColumnTypeSuite.scala | 4 ++-- .../hive/thriftserver/HiveThriftServer2.scala | 12 +++++----- .../hive/thriftserver/SparkSQLCLIDriver.scala | 2 +- .../hive/thriftserver/SparkSQLDriver.scala | 6 ++--- .../sql/hive/thriftserver/SparkSQLEnv.scala | 6 ++--- .../server/SparkSQLOperationManager.scala | 13 ++++++----- .../thriftserver/HiveThriftServer2Suite.scala | 2 +- .../apache/spark/sql/hive/HiveContext.scala | 2 +- .../spark/sql/hive/HiveMetastoreCatalog.scala | 3 ++- .../org/apache/spark/sql/hive/TestHive.scala | 10 ++++----- .../org/apache/spark/sql/hive/hiveUdfs.scala | 4 ++-- .../hive/execution/HiveComparisonTest.scala | 22 +++++++++---------- .../hive/execution/HiveQueryFileTest.scala | 2 +- 30 files changed, 83 insertions(+), 82 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/Logging.scala b/core/src/main/scala/org/apache/spark/Logging.scala index 807ef3e9c9d60..d4f2624061e35 100644 --- a/core/src/main/scala/org/apache/spark/Logging.scala +++ b/core/src/main/scala/org/apache/spark/Logging.scala @@ -39,13 +39,17 @@ trait Logging { // be serialized and used on another machine @transient private var log_ : Logger = null + // Method to get the logger name for this object + protected def logName = { + // Ignore trailing $'s in the class names for Scala objects + this.getClass.getName.stripSuffix("$") + } + // Method to get or create the logger for this object protected def log: Logger = { if (log_ == null) { initializeIfNecessary() - var className = this.getClass.getName - // Ignore trailing $'s in the class names for Scala objects - log_ = LoggerFactory.getLogger(className.stripSuffix("$")) + log_ = LoggerFactory.getLogger(logName) } log_ } diff --git a/sql/catalyst/pom.xml b/sql/catalyst/pom.xml index 54fa96baa1e18..58d44e7923bee 100644 --- a/sql/catalyst/pom.xml +++ b/sql/catalyst/pom.xml @@ -54,11 +54,6 @@ spark-core_${scala.binary.version} ${project.version} - - com.typesafe - scalalogging-slf4j_${scala.binary.version} - 1.0.1 - org.scalatest scalatest_${scala.binary.version} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 74c0104e5b17f..2ba68cab115fb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -109,12 +109,12 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool object ResolveReferences extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { case q: LogicalPlan if q.childrenResolved => - logger.trace(s"Attempting to resolve ${q.simpleString}") + logTrace(s"Attempting to resolve ${q.simpleString}") q transformExpressions { case u @ UnresolvedAttribute(name) => // Leave unchanged if resolution fails. Hopefully will be resolved next round. val result = q.resolve(name).getOrElse(u) - logger.debug(s"Resolving $u to $result") + logDebug(s"Resolving $u to $result") result } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index 47c7ad076ad07..e94f2a3bea63e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -75,7 +75,7 @@ trait HiveTypeCoercion { // Leave the same if the dataTypes match. case Some(newType) if a.dataType == newType.dataType => a case Some(newType) => - logger.debug(s"Promoting $a to $newType in ${q.simpleString}}") + logDebug(s"Promoting $a to $newType in ${q.simpleString}}") newType } } @@ -154,7 +154,7 @@ trait HiveTypeCoercion { (Alias(Cast(l, StringType), l.name)(), r) case (l, r) if l.dataType != r.dataType => - logger.debug(s"Resolving mismatched union input ${l.dataType}, ${r.dataType}") + logDebug(s"Resolving mismatched union input ${l.dataType}, ${r.dataType}") findTightestCommonType(l.dataType, r.dataType).map { widestType => val newLeft = if (l.dataType == widestType) l else Alias(Cast(l, widestType), l.name)() @@ -170,7 +170,7 @@ trait HiveTypeCoercion { val newLeft = if (castedLeft.map(_.dataType) != left.output.map(_.dataType)) { - logger.debug(s"Widening numeric types in union $castedLeft ${left.output}") + logDebug(s"Widening numeric types in union $castedLeft ${left.output}") Project(castedLeft, left) } else { left @@ -178,7 +178,7 @@ trait HiveTypeCoercion { val newRight = if (castedRight.map(_.dataType) != right.output.map(_.dataType)) { - logger.debug(s"Widening numeric types in union $castedRight ${right.output}") + logDebug(s"Widening numeric types in union $castedRight ${right.output}") Project(castedRight, right) } else { right diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala index f38f99569f207..0913f15888780 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.expressions -import org.apache.spark.sql.catalyst.Logging +import org.apache.spark.Logging import org.apache.spark.sql.catalyst.errors.attachTree import org.apache.spark.sql.catalyst.types._ import org.apache.spark.sql.catalyst.trees diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala index 4211998f7511a..094ff14552283 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.expressions.codegen -import com.typesafe.scalalogging.slf4j.Logging +import org.apache.spark.Logging import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.types.{StringType, NumericType} @@ -92,7 +92,7 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[Row]] wit } new $orderingName() """ - logger.debug(s"Generated Ordering: $code") + logDebug(s"Generated Ordering: $code") toolBox.eval(code).asInstanceOf[Ordering[Row]] } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/package.scala index ca9642954eb27..bdd07bbeb2230 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/package.scala @@ -25,5 +25,4 @@ package object catalyst { */ protected[catalyst] object ScalaReflectionLock - protected[catalyst] type Logging = com.typesafe.scalalogging.slf4j.Logging } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/QueryPlanner.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/QueryPlanner.scala index 781ba489b44c6..5839c9f7c43ef 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/QueryPlanner.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/QueryPlanner.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.planning -import org.apache.spark.sql.catalyst.Logging +import org.apache.spark.Logging import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.trees.TreeNode diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala index bc763a4e06e67..90923fe31a063 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.planning import scala.annotation.tailrec import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.Logging +import org.apache.spark.Logging import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ @@ -184,7 +184,7 @@ object ExtractEquiJoinKeys extends Logging with PredicateHelper { def unapply(plan: LogicalPlan): Option[ReturnType] = plan match { case join @ Join(left, right, joinType, condition) => - logger.debug(s"Considering join on: $condition") + logDebug(s"Considering join on: $condition") // Find equi-join predicates that can be evaluated before the join, and thus can be used // as join keys. val (joinPredicates, otherPredicates) = @@ -202,7 +202,7 @@ object ExtractEquiJoinKeys extends Logging with PredicateHelper { val rightKeys = joinKeys.map(_._2) if (joinKeys.nonEmpty) { - logger.debug(s"leftKeys:${leftKeys} | rightKeys:${rightKeys}") + logDebug(s"leftKeys:${leftKeys} | rightKeys:${rightKeys}") Some((joinType, leftKeys, rightKeys, otherPredicates.reduceOption(And), left, right)) } else { None diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/Rule.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/Rule.scala index f8960b3fe7a17..03414b2301e81 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/Rule.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/Rule.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.rules -import org.apache.spark.sql.catalyst.Logging +import org.apache.spark.Logging import org.apache.spark.sql.catalyst.trees.TreeNode abstract class Rule[TreeType <: TreeNode[_]] extends Logging { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala index 6aa407c836aec..d192b151ac1c3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.rules -import org.apache.spark.sql.catalyst.Logging +import org.apache.spark.Logging import org.apache.spark.sql.catalyst.trees.TreeNode import org.apache.spark.sql.catalyst.util.sideBySide @@ -60,7 +60,7 @@ abstract class RuleExecutor[TreeType <: TreeNode[_]] extends Logging { case (plan, rule) => val result = rule(plan) if (!result.fastEquals(plan)) { - logger.trace( + logTrace( s""" |=== Applying Rule ${rule.ruleName} === |${sideBySide(plan.treeString, result.treeString).mkString("\n")} @@ -73,26 +73,26 @@ abstract class RuleExecutor[TreeType <: TreeNode[_]] extends Logging { if (iteration > batch.strategy.maxIterations) { // Only log if this is a rule that is supposed to run more than once. if (iteration != 2) { - logger.info(s"Max iterations (${iteration - 1}) reached for batch ${batch.name}") + logInfo(s"Max iterations (${iteration - 1}) reached for batch ${batch.name}") } continue = false } if (curPlan.fastEquals(lastPlan)) { - logger.trace(s"Fixed point reached for batch ${batch.name} after $iteration iterations.") + logTrace(s"Fixed point reached for batch ${batch.name} after $iteration iterations.") continue = false } lastPlan = curPlan } if (!batchStartPlan.fastEquals(curPlan)) { - logger.debug( + logDebug( s""" |=== Result of Batch ${batch.name} === |${sideBySide(plan.treeString, curPlan.treeString).mkString("\n")} """.stripMargin) } else { - logger.trace(s"Batch ${batch.name} has no effect.") + logTrace(s"Batch ${batch.name} has no effect.") } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/package.scala index 9a28d035a10a3..d725a92c06f7b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/package.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst +import org.apache.spark.Logging + /** * A library for easily manipulating trees of operators. Operators that extend TreeNode are * granted the following interface: @@ -31,8 +33,8 @@ package org.apache.spark.sql.catalyst *
  • debugging support - pretty printing, easy splicing of trees, etc.
  • * */ -package object trees { +package object trees extends Logging { // Since we want tree nodes to be lightweight, we create one logger for all treenode instances. - protected val logger = - com.typesafe.scalalogging.slf4j.Logger(org.slf4j.LoggerFactory.getLogger("catalyst.trees")) + protected override def logName = "catalyst.trees" + } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index dad71079c29b9..00dd34aabc389 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -36,7 +36,7 @@ import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.SparkStrategies import org.apache.spark.sql.json._ import org.apache.spark.sql.parquet.ParquetRelation -import org.apache.spark.SparkContext +import org.apache.spark.{Logging, SparkContext} /** * :: AlphaComponent :: diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnBuilder.scala index 4c6675c3c87bf..6ad12a0dcb64d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnBuilder.scala @@ -19,7 +19,8 @@ package org.apache.spark.sql.columnar.compression import java.nio.{ByteBuffer, ByteOrder} -import org.apache.spark.sql.{Logging, Row} +import org.apache.spark.Logging +import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.types.NativeType import org.apache.spark.sql.columnar.{ColumnBuilder, NativeColumnBuilder} @@ -101,7 +102,7 @@ private[sql] trait CompressibleColumnBuilder[T <: NativeType] copyColumnHeader(rawBuffer, compressedBuffer) - logger.info(s"Compressor for [$columnName]: $encoder, ratio: ${encoder.compressionRatio}") + logInfo(s"Compressor for [$columnName]: $encoder, ratio: ${encoder.compressionRatio}") encoder.compress(rawBuffer, compressedBuffer, columnType) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala index 30712f03cab4c..77dc2ad733215 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala @@ -101,7 +101,7 @@ private[sql] case class AddExchange(sqlContext: SQLContext) extends Rule[SparkPl !operator.requiredChildDistribution.zip(operator.children).map { case (required, child) => val valid = child.outputPartitioning.satisfies(required) - logger.debug( + logDebug( s"${if (valid) "Valid" else "Invalid"} distribution," + s"required: $required current: ${child.outputPartitioning}") valid diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala index 70db1ebd3a3e1..a3d2a1c7a51f8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.analysis.HiveTypeCoercion import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.types._ import org.apache.spark.sql.catalyst.ScalaReflection -import org.apache.spark.sql.Logging +import org.apache.spark.Logging private[sql] object JsonRDD extends Logging { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/package.scala index 0995a4eb6299f..f513eae9c2d13 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/package.scala @@ -32,8 +32,6 @@ import org.apache.spark.annotation.DeveloperApi */ package object sql { - protected[sql] type Logging = com.typesafe.scalalogging.slf4j.Logging - /** * :: DeveloperApi :: * diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala index 829342215e691..75f653f3280bd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala @@ -22,7 +22,7 @@ import java.sql.Timestamp import org.scalatest.FunSuite -import org.apache.spark.sql.Logging +import org.apache.spark.Logging import org.apache.spark.sql.catalyst.types._ import org.apache.spark.sql.columnar.ColumnarTestUtils._ import org.apache.spark.sql.execution.SparkSqlSerializer @@ -166,7 +166,7 @@ class ColumnTypeSuite extends FunSuite with Logging { buffer.rewind() seq.foreach { expected => - logger.info("buffer = " + buffer + ", expected = " + expected) + logInfo("buffer = " + buffer + ", expected = " + expected) val extracted = columnType.extract(buffer) assert( expected === extracted, diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala index ddbc2a79fb512..08d3f983d9e71 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala @@ -25,7 +25,7 @@ import org.apache.hadoop.hive.ql.session.SessionState import org.apache.hive.service.cli.thrift.ThriftBinaryCLIService import org.apache.hive.service.server.{HiveServer2, ServerOptionsProcessor} -import org.apache.spark.sql.Logging +import org.apache.spark.Logging import org.apache.spark.sql.hive.HiveContext import org.apache.spark.sql.hive.thriftserver.ReflectionUtils._ @@ -40,7 +40,7 @@ private[hive] object HiveThriftServer2 extends Logging { val optionsProcessor = new ServerOptionsProcessor("HiveThriftServer2") if (!optionsProcessor.process(args)) { - logger.warn("Error starting HiveThriftServer2 with given arguments") + logWarning("Error starting HiveThriftServer2 with given arguments") System.exit(-1) } @@ -49,12 +49,12 @@ private[hive] object HiveThriftServer2 extends Logging { // Set all properties specified via command line. val hiveConf: HiveConf = ss.getConf hiveConf.getAllProperties.toSeq.sortBy(_._1).foreach { case (k, v) => - logger.debug(s"HiveConf var: $k=$v") + logDebug(s"HiveConf var: $k=$v") } SessionState.start(ss) - logger.info("Starting SparkContext") + logInfo("Starting SparkContext") SparkSQLEnv.init() SessionState.start(ss) @@ -70,10 +70,10 @@ private[hive] object HiveThriftServer2 extends Logging { val server = new HiveThriftServer2(SparkSQLEnv.hiveContext) server.init(hiveConf) server.start() - logger.info("HiveThriftServer2 started") + logInfo("HiveThriftServer2 started") } catch { case e: Exception => - logger.error("Error starting HiveThriftServer2", e) + logError("Error starting HiveThriftServer2", e) System.exit(-1) } } diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala index cb17d7ce58ea0..4d0c506c5a397 100755 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala @@ -37,7 +37,7 @@ import org.apache.hadoop.hive.ql.session.SessionState import org.apache.hadoop.hive.shims.ShimLoader import org.apache.thrift.transport.TSocket -import org.apache.spark.sql.Logging +import org.apache.spark.Logging private[hive] object SparkSQLCLIDriver { private var prompt = "spark-sql" diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala index a56b19a4bcda0..d362d599d08ca 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala @@ -26,7 +26,7 @@ import org.apache.hadoop.hive.metastore.api.{FieldSchema, Schema} import org.apache.hadoop.hive.ql.Driver import org.apache.hadoop.hive.ql.processors.CommandProcessorResponse -import org.apache.spark.sql.Logging +import org.apache.spark.Logging import org.apache.spark.sql.hive.{HiveContext, HiveMetastoreTypes} private[hive] class SparkSQLDriver(val context: HiveContext = SparkSQLEnv.hiveContext) @@ -40,7 +40,7 @@ private[hive] class SparkSQLDriver(val context: HiveContext = SparkSQLEnv.hiveCo private def getResultSetSchema(query: context.QueryExecution): Schema = { val analyzed = query.analyzed - logger.debug(s"Result Schema: ${analyzed.output}") + logDebug(s"Result Schema: ${analyzed.output}") if (analyzed.output.size == 0) { new Schema(new FieldSchema("Response code", "string", "") :: Nil, null) } else { @@ -61,7 +61,7 @@ private[hive] class SparkSQLDriver(val context: HiveContext = SparkSQLEnv.hiveCo new CommandProcessorResponse(0) } catch { case cause: Throwable => - logger.error(s"Failed in [$command]", cause) + logError(s"Failed in [$command]", cause) new CommandProcessorResponse(-3, ExceptionUtils.getFullStackTrace(cause), null) } } diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala index 451c3bd7b9352..582264eb59f83 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala @@ -20,13 +20,13 @@ package org.apache.spark.sql.hive.thriftserver import org.apache.hadoop.hive.ql.session.SessionState import org.apache.spark.scheduler.{SplitInfo, StatsReportListener} -import org.apache.spark.sql.Logging +import org.apache.spark.Logging import org.apache.spark.sql.hive.HiveContext import org.apache.spark.{SparkConf, SparkContext} /** A singleton object for the master program. The slaves should not access this. */ private[hive] object SparkSQLEnv extends Logging { - logger.debug("Initializing SparkSQLEnv") + logDebug("Initializing SparkSQLEnv") var hiveContext: HiveContext = _ var sparkContext: SparkContext = _ @@ -47,7 +47,7 @@ private[hive] object SparkSQLEnv extends Logging { /** Cleans up and shuts down the Spark SQL environments. */ def stop() { - logger.debug("Shutting down Spark SQL Environment") + logDebug("Shutting down Spark SQL Environment") // Stop the SparkContext if (SparkSQLEnv.sparkContext != null) { sparkContext.stop() diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala index a4e1f3e762e89..d4dadfd21d13f 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala @@ -30,10 +30,11 @@ import org.apache.hive.service.cli._ import org.apache.hive.service.cli.operation.{ExecuteStatementOperation, Operation, OperationManager} import org.apache.hive.service.cli.session.HiveSession +import org.apache.spark.Logging import org.apache.spark.sql.catalyst.types._ import org.apache.spark.sql.hive.thriftserver.ReflectionUtils import org.apache.spark.sql.hive.{HiveContext, HiveMetastoreTypes} -import org.apache.spark.sql.{Logging, SchemaRDD, Row => SparkRow} +import org.apache.spark.sql.{SchemaRDD, Row => SparkRow} /** * Executes queries using Spark SQL, and maintains a list of handles to active queries. @@ -55,7 +56,7 @@ class SparkSQLOperationManager(hiveContext: HiveContext) extends OperationManage def close(): Unit = { // RDDs will be cleaned automatically upon garbage collection. - logger.debug("CLOSING") + logDebug("CLOSING") } def getNextRowSet(order: FetchOrientation, maxRowsL: Long): RowSet = { @@ -112,7 +113,7 @@ class SparkSQLOperationManager(hiveContext: HiveContext) extends OperationManage } def getResultSetSchema: TableSchema = { - logger.warn(s"Result Schema: ${result.queryExecution.analyzed.output}") + logWarning(s"Result Schema: ${result.queryExecution.analyzed.output}") if (result.queryExecution.analyzed.output.size == 0) { new TableSchema(new FieldSchema("Result", "string", "") :: Nil) } else { @@ -124,11 +125,11 @@ class SparkSQLOperationManager(hiveContext: HiveContext) extends OperationManage } def run(): Unit = { - logger.info(s"Running query '$statement'") + logInfo(s"Running query '$statement'") setState(OperationState.RUNNING) try { result = hiveContext.hql(statement) - logger.debug(result.queryExecution.toString()) + logDebug(result.queryExecution.toString()) val groupId = round(random * 1000000).toString hiveContext.sparkContext.setJobGroup(groupId, statement) iter = result.queryExecution.toRdd.toLocalIterator @@ -138,7 +139,7 @@ class SparkSQLOperationManager(hiveContext: HiveContext) extends OperationManage // Actually do need to catch Throwable as some failures don't inherit from Exception and // HiveServer will silently swallow them. case e: Throwable => - logger.error("Error executing query:",e) + logError("Error executing query:",e) throw new HiveSQLException(e.toString) } setState(OperationState.FINISHED) diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suite.scala index fe3403b3292ec..b7b7c9957ac34 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suite.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suite.scala @@ -27,7 +27,7 @@ import java.sql.{Connection, DriverManager, Statement} import org.scalatest.{BeforeAndAfterAll, FunSuite} -import org.apache.spark.sql.Logging +import org.apache.spark.Logging import org.apache.spark.sql.catalyst.util.getTempFilePath /** diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index 7e3b8727bebed..2c7270d9f83a9 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -207,7 +207,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { } } catch { case e: Exception => - logger.error( + logError( s""" |====================== |HIVE FAILURE OUTPUT diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index fa4e78439c26c..df3604439e483 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -28,7 +28,8 @@ import org.apache.hadoop.hive.ql.plan.TableDesc import org.apache.hadoop.hive.serde2.Deserializer import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.sql.{SQLContext, Logging} +import org.apache.spark.Logging +import org.apache.spark.sql.SQLContext import org.apache.spark.sql.catalyst.analysis.{EliminateAnalysisOperators, Catalog} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala index c50e8c4b5c5d3..728452a25a00e 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala @@ -148,7 +148,7 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { describedTables ++ logical.collect { case UnresolvedRelation(databaseName, name, _) => name } val referencedTestTables = referencedTables.filter(testTables.contains) - logger.debug(s"Query references test tables: ${referencedTestTables.mkString(", ")}") + logDebug(s"Query references test tables: ${referencedTestTables.mkString(", ")}") referencedTestTables.foreach(loadTestTable) // Proceed with analysis. analyzer(logical) @@ -273,7 +273,7 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { if (!(loadedTables contains name)) { // Marks the table as loaded first to prevent infite mutually recursive table loading. loadedTables += name - logger.info(s"Loading test table $name") + logInfo(s"Loading test table $name") val createCmds = testTables.get(name).map(_.commands).getOrElse(sys.error(s"Unknown test table $name")) createCmds.foreach(_()) @@ -312,7 +312,7 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { loadedTables.clear() catalog.client.getAllTables("default").foreach { t => - logger.debug(s"Deleting table $t") + logDebug(s"Deleting table $t") val table = catalog.client.getTable("default", t) catalog.client.getIndexes("default", t, 255).foreach { index => @@ -325,7 +325,7 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { } catalog.client.getAllDatabases.filterNot(_ == "default").foreach { db => - logger.debug(s"Dropping Database: $db") + logDebug(s"Dropping Database: $db") catalog.client.dropDatabase(db, true, false, true) } @@ -347,7 +347,7 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { loadTestTable("srcpart") } catch { case e: Exception => - logger.error(s"FATAL ERROR: Failed to reset TestDB state. $e") + logError(s"FATAL ERROR: Failed to reset TestDB state. $e") // At this point there is really no reason to continue, but the test framework traps exits. // So instead we just pause forever so that at least the developer can see where things // started to go wrong. diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala index 7582b4743d404..d181921269b56 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala @@ -25,7 +25,7 @@ import org.apache.hadoop.hive.ql.exec.{FunctionInfo, FunctionRegistry} import org.apache.hadoop.hive.ql.udf.{UDFType => HiveUDFType} import org.apache.hadoop.hive.ql.udf.generic._ -import org.apache.spark.sql.Logging +import org.apache.spark.Logging import org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.types._ @@ -119,7 +119,7 @@ private[hive] case class HiveSimpleUdf(functionClassName: String, children: Seq[ sys.error(s"No matching wrapper found, options: ${argClass.getConstructors.toSeq}.")) (a: Any) => { - logger.debug( + logDebug( s"Wrapping $a of type ${if (a == null) "null" else a.getClass.getName} using $constructor.") // We must make sure that primitives get boxed java style. if (a == null) { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala index 6c8fe4b196dea..83cfbc6b4a002 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala @@ -21,7 +21,7 @@ import java.io._ import org.scalatest.{BeforeAndAfterAll, FunSuite, GivenWhenThen} -import org.apache.spark.sql.Logging +import org.apache.spark.Logging import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.logical.{NativeCommand => LogicalNativeCommand} @@ -197,7 +197,7 @@ abstract class HiveComparisonTest // If test sharding is enable, skip tests that are not in the correct shard. shardInfo.foreach { case (shardId, numShards) if testCaseName.hashCode % numShards != shardId => return - case (shardId, _) => logger.debug(s"Shard $shardId includes test '$testCaseName'") + case (shardId, _) => logDebug(s"Shard $shardId includes test '$testCaseName'") } // Skip tests found in directories specified by user. @@ -213,13 +213,13 @@ abstract class HiveComparisonTest .map(new File(_, testCaseName)) .filter(_.exists) if (runOnlyDirectories.nonEmpty && runIndicators.isEmpty) { - logger.debug( + logDebug( s"Skipping test '$testCaseName' not found in ${runOnlyDirectories.map(_.getCanonicalPath)}") return } test(testCaseName) { - logger.debug(s"=== HIVE TEST: $testCaseName ===") + logDebug(s"=== HIVE TEST: $testCaseName ===") // Clear old output for this testcase. outputDirectories.map(new File(_, testCaseName)).filter(_.exists()).foreach(_.delete()) @@ -235,7 +235,7 @@ abstract class HiveComparisonTest .filterNot(_ contains "hive.outerjoin.supports.filters") if (allQueries != queryList) - logger.warn(s"Simplifications made on unsupported operations for test $testCaseName") + logWarning(s"Simplifications made on unsupported operations for test $testCaseName") lazy val consoleTestCase = { val quotes = "\"\"\"" @@ -257,11 +257,11 @@ abstract class HiveComparisonTest } val hiveCachedResults = hiveCacheFiles.flatMap { cachedAnswerFile => - logger.debug(s"Looking for cached answer file $cachedAnswerFile.") + logDebug(s"Looking for cached answer file $cachedAnswerFile.") if (cachedAnswerFile.exists) { Some(fileToString(cachedAnswerFile)) } else { - logger.debug(s"File $cachedAnswerFile not found") + logDebug(s"File $cachedAnswerFile not found") None } }.map { @@ -272,7 +272,7 @@ abstract class HiveComparisonTest val hiveResults: Seq[Seq[String]] = if (hiveCachedResults.size == queryList.size) { - logger.info(s"Using answer cache for test: $testCaseName") + logInfo(s"Using answer cache for test: $testCaseName") hiveCachedResults } else { @@ -287,7 +287,7 @@ abstract class HiveComparisonTest if (installHooksCommand.findAllMatchIn(queryString).nonEmpty) sys.error("hive exec hooks not supported for tests.") - logger.warn(s"Running query ${i+1}/${queryList.size} with hive.") + logWarning(s"Running query ${i+1}/${queryList.size} with hive.") // Analyze the query with catalyst to ensure test tables are loaded. val answer = hiveQuery.analyzed match { case _: ExplainCommand => Nil // No need to execute EXPLAIN queries as we don't check the output. @@ -351,7 +351,7 @@ abstract class HiveComparisonTest val resultComparison = sideBySide(hivePrintOut, catalystPrintOut).mkString("\n") if (recomputeCache) { - logger.warn(s"Clearing cache files for failed test $testCaseName") + logWarning(s"Clearing cache files for failed test $testCaseName") hiveCacheFiles.foreach(_.delete()) } @@ -380,7 +380,7 @@ abstract class HiveComparisonTest TestHive.runSqlHive("SELECT key FROM src") } catch { case e: Exception => - logger.error(s"FATAL ERROR: Canary query threw $e This implies that the testing environment has likely been corrupted.") + logError(s"FATAL ERROR: Canary query threw $e This implies that the testing environment has likely been corrupted.") // The testing setup traps exits so wait here for a long time so the developer can see when things started // to go wrong. Thread.sleep(1000000) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQueryFileTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQueryFileTest.scala index 50ab71a9003d3..02518d516261b 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQueryFileTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQueryFileTest.scala @@ -53,7 +53,7 @@ abstract class HiveQueryFileTest extends HiveComparisonTest { testCases.sorted.foreach { case (testCaseName, testCaseFile) => if (blackList.map(_.r.pattern.matcher(testCaseName).matches()).reduceLeft(_||_)) { - logger.debug(s"Blacklisted test skipped $testCaseName") + logDebug(s"Blacklisted test skipped $testCaseName") } else if (realWhiteList.map(_.r.pattern.matcher(testCaseName).matches()).reduceLeft(_||_) || runAll) { // Build a test case and submit it to scala test framework... val queriesString = fileToString(testCaseFile) From 158ad0bba9382fd494b4789b5628a9cec00cfa19 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Sat, 2 Aug 2014 16:33:48 -0700 Subject: [PATCH 121/170] [SPARK-2097][SQL] UDF Support This patch adds the ability to register lambda functions written in Python, Java or Scala as UDFs for use in SQL or HiveQL. Scala: ```scala registerFunction("strLenScala", (_: String).length) sql("SELECT strLenScala('test')") ``` Python: ```python sqlCtx.registerFunction("strLenPython", lambda x: len(x), IntegerType()) sqlCtx.sql("SELECT strLenPython('test')") ``` Java: ```java sqlContext.registerFunction("stringLengthJava", new UDF1() { Override public Integer call(String str) throws Exception { return str.length(); } }, DataType.IntegerType); sqlContext.sql("SELECT stringLengthJava('test')"); ``` Author: Michael Armbrust Closes #1063 from marmbrus/udfs and squashes the following commits: 9eda0fe [Michael Armbrust] newline 747c05e [Michael Armbrust] Add some scala UDF tests. d92727d [Michael Armbrust] Merge remote-tracking branch 'apache/master' into udfs 005d684 [Michael Armbrust] Fix naming and formatting. d14dac8 [Michael Armbrust] Fix last line of autogened java files. 8135c48 [Michael Armbrust] Move UDF unit tests to pyspark. 40b0ffd [Michael Armbrust] Merge remote-tracking branch 'apache/master' into udfs 6a36890 [Michael Armbrust] Switch logging so that SQLContext can be serializable. 7a83101 [Michael Armbrust] Drop toString 795fd15 [Michael Armbrust] Try to avoid capturing SQLContext. e54fb45 [Michael Armbrust] Docs and tests. 437cbe3 [Michael Armbrust] Update use of dataTypes, fix some python tests, address review comments. 01517d6 [Michael Armbrust] Merge remote-tracking branch 'origin/master' into udfs 8e6c932 [Michael Armbrust] WIP 3f96a52 [Michael Armbrust] Merge remote-tracking branch 'origin/master' into udfs 6237c8d [Michael Armbrust] WIP 2766f0b [Michael Armbrust] Move udfs support to SQL from hive. Add support for Java UDFs. 0f7d50c [Michael Armbrust] Draft of native Spark SQL UDFs for Scala and Python. --- python/pyspark/sql.py | 39 ++- .../catalyst/analysis/FunctionRegistry.scala | 32 ++ .../sql/catalyst/expressions/ScalaUdf.scala | 307 ++++++++++++++++++ .../org/apache/spark/sql/api/java/UDF1.java | 32 ++ .../org/apache/spark/sql/api/java/UDF10.java | 32 ++ .../org/apache/spark/sql/api/java/UDF11.java | 32 ++ .../org/apache/spark/sql/api/java/UDF12.java | 32 ++ .../org/apache/spark/sql/api/java/UDF13.java | 32 ++ .../org/apache/spark/sql/api/java/UDF14.java | 32 ++ .../org/apache/spark/sql/api/java/UDF15.java | 32 ++ .../org/apache/spark/sql/api/java/UDF16.java | 32 ++ .../org/apache/spark/sql/api/java/UDF17.java | 32 ++ .../org/apache/spark/sql/api/java/UDF18.java | 32 ++ .../org/apache/spark/sql/api/java/UDF19.java | 32 ++ .../org/apache/spark/sql/api/java/UDF2.java | 32 ++ .../org/apache/spark/sql/api/java/UDF20.java | 32 ++ .../org/apache/spark/sql/api/java/UDF21.java | 32 ++ .../org/apache/spark/sql/api/java/UDF22.java | 32 ++ .../org/apache/spark/sql/api/java/UDF3.java | 32 ++ .../org/apache/spark/sql/api/java/UDF4.java | 32 ++ .../org/apache/spark/sql/api/java/UDF5.java | 32 ++ .../org/apache/spark/sql/api/java/UDF6.java | 32 ++ .../org/apache/spark/sql/api/java/UDF7.java | 32 ++ .../org/apache/spark/sql/api/java/UDF8.java | 32 ++ .../org/apache/spark/sql/api/java/UDF9.java | 32 ++ .../org/apache/spark/sql/SQLContext.scala | 11 +- .../apache/spark/sql/UdfRegistration.scala | 196 +++++++++++ .../spark/sql/api/java/JavaSQLContext.scala | 5 +- .../spark/sql/api/java/UDFRegistration.scala | 252 ++++++++++++++ .../spark/sql/execution/SparkStrategies.scala | 2 + .../spark/sql/execution/pythonUdfs.scala | 177 ++++++++++ .../spark/sql/api/java/JavaAPISuite.java | 90 +++++ .../apache/spark/sql/InsertIntoSuite.scala | 2 +- .../scala/org/apache/spark/sql/UDFSuite.scala | 36 ++ .../apache/spark/sql/hive/HiveContext.scala | 13 +- .../org/apache/spark/sql/hive/TestHive.scala | 4 +- .../org/apache/spark/sql/hive/hiveUdfs.scala | 6 +- .../org/apache/spark/sql/QueryTest.scala | 4 +- 38 files changed, 1861 insertions(+), 19 deletions(-) create mode 100644 sql/core/src/main/java/org/apache/spark/sql/api/java/UDF1.java create mode 100644 sql/core/src/main/java/org/apache/spark/sql/api/java/UDF10.java create mode 100644 sql/core/src/main/java/org/apache/spark/sql/api/java/UDF11.java create mode 100644 sql/core/src/main/java/org/apache/spark/sql/api/java/UDF12.java create mode 100644 sql/core/src/main/java/org/apache/spark/sql/api/java/UDF13.java create mode 100644 sql/core/src/main/java/org/apache/spark/sql/api/java/UDF14.java create mode 100644 sql/core/src/main/java/org/apache/spark/sql/api/java/UDF15.java create mode 100644 sql/core/src/main/java/org/apache/spark/sql/api/java/UDF16.java create mode 100644 sql/core/src/main/java/org/apache/spark/sql/api/java/UDF17.java create mode 100644 sql/core/src/main/java/org/apache/spark/sql/api/java/UDF18.java create mode 100644 sql/core/src/main/java/org/apache/spark/sql/api/java/UDF19.java create mode 100644 sql/core/src/main/java/org/apache/spark/sql/api/java/UDF2.java create mode 100644 sql/core/src/main/java/org/apache/spark/sql/api/java/UDF20.java create mode 100644 sql/core/src/main/java/org/apache/spark/sql/api/java/UDF21.java create mode 100644 sql/core/src/main/java/org/apache/spark/sql/api/java/UDF22.java create mode 100644 sql/core/src/main/java/org/apache/spark/sql/api/java/UDF3.java create mode 100644 sql/core/src/main/java/org/apache/spark/sql/api/java/UDF4.java create mode 100644 sql/core/src/main/java/org/apache/spark/sql/api/java/UDF5.java create mode 100644 sql/core/src/main/java/org/apache/spark/sql/api/java/UDF6.java create mode 100644 sql/core/src/main/java/org/apache/spark/sql/api/java/UDF7.java create mode 100644 sql/core/src/main/java/org/apache/spark/sql/api/java/UDF8.java create mode 100644 sql/core/src/main/java/org/apache/spark/sql/api/java/UDF9.java create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/UdfRegistration.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/api/java/UDFRegistration.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala create mode 100644 sql/core/src/test/java/org/apache/spark/sql/api/java/JavaAPISuite.java create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py index f840475ffaf70..e7c35ac1ffe02 100644 --- a/python/pyspark/sql.py +++ b/python/pyspark/sql.py @@ -28,9 +28,13 @@ from operator import itemgetter from pyspark.rdd import RDD, PipelinedRDD -from pyspark.serializers import BatchedSerializer, PickleSerializer +from pyspark.serializers import BatchedSerializer, PickleSerializer, CloudPickleSerializer + +from itertools import chain, ifilter, imap from py4j.protocol import Py4JError +from py4j.java_collections import ListConverter, MapConverter + __all__ = [ "StringType", "BinaryType", "BooleanType", "TimestampType", "DecimalType", @@ -932,6 +936,39 @@ def _ssql_ctx(self): self._scala_SQLContext = self._jvm.SQLContext(self._jsc.sc()) return self._scala_SQLContext + def registerFunction(self, name, f, returnType=StringType()): + """Registers a lambda function as a UDF so it can be used in SQL statements. + + In addition to a name and the function itself, the return type can be optionally specified. + When the return type is not given it default to a string and conversion will automatically + be done. For any other return type, the produced object must match the specified type. + + >>> sqlCtx.registerFunction("stringLengthString", lambda x: len(x)) + >>> sqlCtx.sql("SELECT stringLengthString('test')").collect() + [Row(c0=u'4')] + >>> sqlCtx.registerFunction("stringLengthInt", lambda x: len(x), IntegerType()) + >>> sqlCtx.sql("SELECT stringLengthInt('test')").collect() + [Row(c0=4)] + >>> sqlCtx.registerFunction("twoArgs", lambda x, y: len(x) + y, IntegerType()) + >>> sqlCtx.sql("SELECT twoArgs('test', 1)").collect() + [Row(c0=5)] + """ + func = lambda _, it: imap(lambda x: f(*x), it) + command = (func, + BatchedSerializer(PickleSerializer(), 1024), + BatchedSerializer(PickleSerializer(), 1024)) + env = MapConverter().convert(self._sc.environment, + self._sc._gateway._gateway_client) + includes = ListConverter().convert(self._sc._python_includes, + self._sc._gateway._gateway_client) + self._ssql_ctx.registerPython(name, + bytearray(CloudPickleSerializer().dumps(command)), + env, + includes, + self._sc.pythonExec, + self._sc._javaAccumulator, + str(returnType)) + def inferSchema(self, rdd): """Infer and apply a schema to an RDD of L{Row}s. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index c0255701b7ba5..760c49fbca4a5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -18,17 +18,49 @@ package org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.catalyst.expressions.Expression +import scala.collection.mutable /** A catalog for looking up user defined functions, used by an [[Analyzer]]. */ trait FunctionRegistry { + type FunctionBuilder = Seq[Expression] => Expression + + def registerFunction(name: String, builder: FunctionBuilder): Unit + def lookupFunction(name: String, children: Seq[Expression]): Expression } +trait OverrideFunctionRegistry extends FunctionRegistry { + + val functionBuilders = new mutable.HashMap[String, FunctionBuilder]() + + def registerFunction(name: String, builder: FunctionBuilder) = { + functionBuilders.put(name, builder) + } + + abstract override def lookupFunction(name: String, children: Seq[Expression]): Expression = { + functionBuilders.get(name).map(_(children)).getOrElse(super.lookupFunction(name,children)) + } +} + +class SimpleFunctionRegistry extends FunctionRegistry { + val functionBuilders = new mutable.HashMap[String, FunctionBuilder]() + + def registerFunction(name: String, builder: FunctionBuilder) = { + functionBuilders.put(name, builder) + } + + override def lookupFunction(name: String, children: Seq[Expression]): Expression = { + functionBuilders(name)(children) + } +} + /** * A trivial catalog that returns an error when a function is requested. Used for testing when all * functions are already filled in and the analyser needs only to resolve attribute references. */ object EmptyFunctionRegistry extends FunctionRegistry { + def registerFunction(name: String, builder: FunctionBuilder) = ??? + def lookupFunction(name: String, children: Seq[Expression]): Expression = { throw new UnsupportedOperationException } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala index acddf5e9c7004..95633dd0c9870 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala @@ -27,6 +27,22 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi def references = children.flatMap(_.references).toSet def nullable = true + /** This method has been generated by this script + + (1 to 22).map { x => + val anys = (1 to x).map(x => "Any").reduce(_ + ", " + _) + val evals = (0 to x - 1).map(x => s"children($x).eval(input)").reduce(_ + ",\n " + _) + + s""" + case $x => + function.asInstanceOf[($anys) => Any]( + $evals) + """ + } + + */ + + // scalastyle:off override def eval(input: Row): Any = { children.size match { case 0 => function.asInstanceOf[() => Any]() @@ -35,6 +51,297 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi function.asInstanceOf[(Any, Any) => Any]( children(0).eval(input), children(1).eval(input)) + case 3 => + function.asInstanceOf[(Any, Any, Any) => Any]( + children(0).eval(input), + children(1).eval(input), + children(2).eval(input)) + case 4 => + function.asInstanceOf[(Any, Any, Any, Any) => Any]( + children(0).eval(input), + children(1).eval(input), + children(2).eval(input), + children(3).eval(input)) + case 5 => + function.asInstanceOf[(Any, Any, Any, Any, Any) => Any]( + children(0).eval(input), + children(1).eval(input), + children(2).eval(input), + children(3).eval(input), + children(4).eval(input)) + case 6 => + function.asInstanceOf[(Any, Any, Any, Any, Any, Any) => Any]( + children(0).eval(input), + children(1).eval(input), + children(2).eval(input), + children(3).eval(input), + children(4).eval(input), + children(5).eval(input)) + case 7 => + function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any) => Any]( + children(0).eval(input), + children(1).eval(input), + children(2).eval(input), + children(3).eval(input), + children(4).eval(input), + children(5).eval(input), + children(6).eval(input)) + case 8 => + function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any) => Any]( + children(0).eval(input), + children(1).eval(input), + children(2).eval(input), + children(3).eval(input), + children(4).eval(input), + children(5).eval(input), + children(6).eval(input), + children(7).eval(input)) + case 9 => + function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any]( + children(0).eval(input), + children(1).eval(input), + children(2).eval(input), + children(3).eval(input), + children(4).eval(input), + children(5).eval(input), + children(6).eval(input), + children(7).eval(input), + children(8).eval(input)) + case 10 => + function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any]( + children(0).eval(input), + children(1).eval(input), + children(2).eval(input), + children(3).eval(input), + children(4).eval(input), + children(5).eval(input), + children(6).eval(input), + children(7).eval(input), + children(8).eval(input), + children(9).eval(input)) + case 11 => + function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any]( + children(0).eval(input), + children(1).eval(input), + children(2).eval(input), + children(3).eval(input), + children(4).eval(input), + children(5).eval(input), + children(6).eval(input), + children(7).eval(input), + children(8).eval(input), + children(9).eval(input), + children(10).eval(input)) + case 12 => + function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any]( + children(0).eval(input), + children(1).eval(input), + children(2).eval(input), + children(3).eval(input), + children(4).eval(input), + children(5).eval(input), + children(6).eval(input), + children(7).eval(input), + children(8).eval(input), + children(9).eval(input), + children(10).eval(input), + children(11).eval(input)) + case 13 => + function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any]( + children(0).eval(input), + children(1).eval(input), + children(2).eval(input), + children(3).eval(input), + children(4).eval(input), + children(5).eval(input), + children(6).eval(input), + children(7).eval(input), + children(8).eval(input), + children(9).eval(input), + children(10).eval(input), + children(11).eval(input), + children(12).eval(input)) + case 14 => + function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any]( + children(0).eval(input), + children(1).eval(input), + children(2).eval(input), + children(3).eval(input), + children(4).eval(input), + children(5).eval(input), + children(6).eval(input), + children(7).eval(input), + children(8).eval(input), + children(9).eval(input), + children(10).eval(input), + children(11).eval(input), + children(12).eval(input), + children(13).eval(input)) + case 15 => + function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any]( + children(0).eval(input), + children(1).eval(input), + children(2).eval(input), + children(3).eval(input), + children(4).eval(input), + children(5).eval(input), + children(6).eval(input), + children(7).eval(input), + children(8).eval(input), + children(9).eval(input), + children(10).eval(input), + children(11).eval(input), + children(12).eval(input), + children(13).eval(input), + children(14).eval(input)) + case 16 => + function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any]( + children(0).eval(input), + children(1).eval(input), + children(2).eval(input), + children(3).eval(input), + children(4).eval(input), + children(5).eval(input), + children(6).eval(input), + children(7).eval(input), + children(8).eval(input), + children(9).eval(input), + children(10).eval(input), + children(11).eval(input), + children(12).eval(input), + children(13).eval(input), + children(14).eval(input), + children(15).eval(input)) + case 17 => + function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any]( + children(0).eval(input), + children(1).eval(input), + children(2).eval(input), + children(3).eval(input), + children(4).eval(input), + children(5).eval(input), + children(6).eval(input), + children(7).eval(input), + children(8).eval(input), + children(9).eval(input), + children(10).eval(input), + children(11).eval(input), + children(12).eval(input), + children(13).eval(input), + children(14).eval(input), + children(15).eval(input), + children(16).eval(input)) + case 18 => + function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any]( + children(0).eval(input), + children(1).eval(input), + children(2).eval(input), + children(3).eval(input), + children(4).eval(input), + children(5).eval(input), + children(6).eval(input), + children(7).eval(input), + children(8).eval(input), + children(9).eval(input), + children(10).eval(input), + children(11).eval(input), + children(12).eval(input), + children(13).eval(input), + children(14).eval(input), + children(15).eval(input), + children(16).eval(input), + children(17).eval(input)) + case 19 => + function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any]( + children(0).eval(input), + children(1).eval(input), + children(2).eval(input), + children(3).eval(input), + children(4).eval(input), + children(5).eval(input), + children(6).eval(input), + children(7).eval(input), + children(8).eval(input), + children(9).eval(input), + children(10).eval(input), + children(11).eval(input), + children(12).eval(input), + children(13).eval(input), + children(14).eval(input), + children(15).eval(input), + children(16).eval(input), + children(17).eval(input), + children(18).eval(input)) + case 20 => + function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any]( + children(0).eval(input), + children(1).eval(input), + children(2).eval(input), + children(3).eval(input), + children(4).eval(input), + children(5).eval(input), + children(6).eval(input), + children(7).eval(input), + children(8).eval(input), + children(9).eval(input), + children(10).eval(input), + children(11).eval(input), + children(12).eval(input), + children(13).eval(input), + children(14).eval(input), + children(15).eval(input), + children(16).eval(input), + children(17).eval(input), + children(18).eval(input), + children(19).eval(input)) + case 21 => + function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any]( + children(0).eval(input), + children(1).eval(input), + children(2).eval(input), + children(3).eval(input), + children(4).eval(input), + children(5).eval(input), + children(6).eval(input), + children(7).eval(input), + children(8).eval(input), + children(9).eval(input), + children(10).eval(input), + children(11).eval(input), + children(12).eval(input), + children(13).eval(input), + children(14).eval(input), + children(15).eval(input), + children(16).eval(input), + children(17).eval(input), + children(18).eval(input), + children(19).eval(input), + children(20).eval(input)) + case 22 => + function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any]( + children(0).eval(input), + children(1).eval(input), + children(2).eval(input), + children(3).eval(input), + children(4).eval(input), + children(5).eval(input), + children(6).eval(input), + children(7).eval(input), + children(8).eval(input), + children(9).eval(input), + children(10).eval(input), + children(11).eval(input), + children(12).eval(input), + children(13).eval(input), + children(14).eval(input), + children(15).eval(input), + children(16).eval(input), + children(17).eval(input), + children(18).eval(input), + children(19).eval(input), + children(20).eval(input), + children(21).eval(input)) } + // scalastyle:on } } diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF1.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF1.java new file mode 100644 index 0000000000000..ef959e35e1027 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF1.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.api.java; + +import java.io.Serializable; + +// ************************************************** +// THIS FILE IS AUTOGENERATED BY CODE IN +// org.apache.spark.sql.api.java.FunctionRegistration +// ************************************************** + +/** + * A Spark SQL UDF that has 1 arguments. + */ +public interface UDF1 extends Serializable { + public R call(T1 t1) throws Exception; +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF10.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF10.java new file mode 100644 index 0000000000000..96ab3a96c3d5e --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF10.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.api.java; + +import java.io.Serializable; + +// ************************************************** +// THIS FILE IS AUTOGENERATED BY CODE IN +// org.apache.spark.sql.api.java.FunctionRegistration +// ************************************************** + +/** + * A Spark SQL UDF that has 10 arguments. + */ +public interface UDF10 extends Serializable { + public R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6, T7 t7, T8 t8, T9 t9, T10 t10) throws Exception; +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF11.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF11.java new file mode 100644 index 0000000000000..58ae8edd6d817 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF11.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.api.java; + +import java.io.Serializable; + +// ************************************************** +// THIS FILE IS AUTOGENERATED BY CODE IN +// org.apache.spark.sql.api.java.FunctionRegistration +// ************************************************** + +/** + * A Spark SQL UDF that has 11 arguments. + */ +public interface UDF11 extends Serializable { + public R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6, T7 t7, T8 t8, T9 t9, T10 t10, T11 t11) throws Exception; +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF12.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF12.java new file mode 100644 index 0000000000000..d9da0f6eddd94 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF12.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.api.java; + +import java.io.Serializable; + +// ************************************************** +// THIS FILE IS AUTOGENERATED BY CODE IN +// org.apache.spark.sql.api.java.FunctionRegistration +// ************************************************** + +/** + * A Spark SQL UDF that has 12 arguments. + */ +public interface UDF12 extends Serializable { + public R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6, T7 t7, T8 t8, T9 t9, T10 t10, T11 t11, T12 t12) throws Exception; +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF13.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF13.java new file mode 100644 index 0000000000000..095fc1a8076b5 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF13.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.api.java; + +import java.io.Serializable; + +// ************************************************** +// THIS FILE IS AUTOGENERATED BY CODE IN +// org.apache.spark.sql.api.java.FunctionRegistration +// ************************************************** + +/** + * A Spark SQL UDF that has 13 arguments. + */ +public interface UDF13 extends Serializable { + public R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6, T7 t7, T8 t8, T9 t9, T10 t10, T11 t11, T12 t12, T13 t13) throws Exception; +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF14.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF14.java new file mode 100644 index 0000000000000..eb27eaa180086 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF14.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.api.java; + +import java.io.Serializable; + +// ************************************************** +// THIS FILE IS AUTOGENERATED BY CODE IN +// org.apache.spark.sql.api.java.FunctionRegistration +// ************************************************** + +/** + * A Spark SQL UDF that has 14 arguments. + */ +public interface UDF14 extends Serializable { + public R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6, T7 t7, T8 t8, T9 t9, T10 t10, T11 t11, T12 t12, T13 t13, T14 t14) throws Exception; +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF15.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF15.java new file mode 100644 index 0000000000000..1fbcff56332b6 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF15.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.api.java; + +import java.io.Serializable; + +// ************************************************** +// THIS FILE IS AUTOGENERATED BY CODE IN +// org.apache.spark.sql.api.java.FunctionRegistration +// ************************************************** + +/** + * A Spark SQL UDF that has 15 arguments. + */ +public interface UDF15 extends Serializable { + public R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6, T7 t7, T8 t8, T9 t9, T10 t10, T11 t11, T12 t12, T13 t13, T14 t14, T15 t15) throws Exception; +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF16.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF16.java new file mode 100644 index 0000000000000..1133561787a69 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF16.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.api.java; + +import java.io.Serializable; + +// ************************************************** +// THIS FILE IS AUTOGENERATED BY CODE IN +// org.apache.spark.sql.api.java.FunctionRegistration +// ************************************************** + +/** + * A Spark SQL UDF that has 16 arguments. + */ +public interface UDF16 extends Serializable { + public R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6, T7 t7, T8 t8, T9 t9, T10 t10, T11 t11, T12 t12, T13 t13, T14 t14, T15 t15, T16 t16) throws Exception; +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF17.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF17.java new file mode 100644 index 0000000000000..dfae7922c9b63 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF17.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.api.java; + +import java.io.Serializable; + +// ************************************************** +// THIS FILE IS AUTOGENERATED BY CODE IN +// org.apache.spark.sql.api.java.FunctionRegistration +// ************************************************** + +/** + * A Spark SQL UDF that has 17 arguments. + */ +public interface UDF17 extends Serializable { + public R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6, T7 t7, T8 t8, T9 t9, T10 t10, T11 t11, T12 t12, T13 t13, T14 t14, T15 t15, T16 t16, T17 t17) throws Exception; +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF18.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF18.java new file mode 100644 index 0000000000000..e9d1c6d52d4ea --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF18.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.api.java; + +import java.io.Serializable; + +// ************************************************** +// THIS FILE IS AUTOGENERATED BY CODE IN +// org.apache.spark.sql.api.java.FunctionRegistration +// ************************************************** + +/** + * A Spark SQL UDF that has 18 arguments. + */ +public interface UDF18 extends Serializable { + public R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6, T7 t7, T8 t8, T9 t9, T10 t10, T11 t11, T12 t12, T13 t13, T14 t14, T15 t15, T16 t16, T17 t17, T18 t18) throws Exception; +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF19.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF19.java new file mode 100644 index 0000000000000..46b9d2d3c9457 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF19.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.api.java; + +import java.io.Serializable; + +// ************************************************** +// THIS FILE IS AUTOGENERATED BY CODE IN +// org.apache.spark.sql.api.java.FunctionRegistration +// ************************************************** + +/** + * A Spark SQL UDF that has 19 arguments. + */ +public interface UDF19 extends Serializable { + public R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6, T7 t7, T8 t8, T9 t9, T10 t10, T11 t11, T12 t12, T13 t13, T14 t14, T15 t15, T16 t16, T17 t17, T18 t18, T19 t19) throws Exception; +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF2.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF2.java new file mode 100644 index 0000000000000..cd3fde8da419e --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF2.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.api.java; + +import java.io.Serializable; + +// ************************************************** +// THIS FILE IS AUTOGENERATED BY CODE IN +// org.apache.spark.sql.api.java.FunctionRegistration +// ************************************************** + +/** + * A Spark SQL UDF that has 2 arguments. + */ +public interface UDF2 extends Serializable { + public R call(T1 t1, T2 t2) throws Exception; +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF20.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF20.java new file mode 100644 index 0000000000000..113d3d26be4a7 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF20.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.api.java; + +import java.io.Serializable; + +// ************************************************** +// THIS FILE IS AUTOGENERATED BY CODE IN +// org.apache.spark.sql.api.java.FunctionRegistration +// ************************************************** + +/** + * A Spark SQL UDF that has 20 arguments. + */ +public interface UDF20 extends Serializable { + public R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6, T7 t7, T8 t8, T9 t9, T10 t10, T11 t11, T12 t12, T13 t13, T14 t14, T15 t15, T16 t16, T17 t17, T18 t18, T19 t19, T20 t20) throws Exception; +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF21.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF21.java new file mode 100644 index 0000000000000..74118f2cf8da7 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF21.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.api.java; + +import java.io.Serializable; + +// ************************************************** +// THIS FILE IS AUTOGENERATED BY CODE IN +// org.apache.spark.sql.api.java.FunctionRegistration +// ************************************************** + +/** + * A Spark SQL UDF that has 21 arguments. + */ +public interface UDF21 extends Serializable { + public R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6, T7 t7, T8 t8, T9 t9, T10 t10, T11 t11, T12 t12, T13 t13, T14 t14, T15 t15, T16 t16, T17 t17, T18 t18, T19 t19, T20 t20, T21 t21) throws Exception; +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF22.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF22.java new file mode 100644 index 0000000000000..0e7cc40be45ec --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF22.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.api.java; + +import java.io.Serializable; + +// ************************************************** +// THIS FILE IS AUTOGENERATED BY CODE IN +// org.apache.spark.sql.api.java.FunctionRegistration +// ************************************************** + +/** + * A Spark SQL UDF that has 22 arguments. + */ +public interface UDF22 extends Serializable { + public R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6, T7 t7, T8 t8, T9 t9, T10 t10, T11 t11, T12 t12, T13 t13, T14 t14, T15 t15, T16 t16, T17 t17, T18 t18, T19 t19, T20 t20, T21 t21, T22 t22) throws Exception; +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF3.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF3.java new file mode 100644 index 0000000000000..6a880f16be47a --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF3.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.api.java; + +import java.io.Serializable; + +// ************************************************** +// THIS FILE IS AUTOGENERATED BY CODE IN +// org.apache.spark.sql.api.java.FunctionRegistration +// ************************************************** + +/** + * A Spark SQL UDF that has 3 arguments. + */ +public interface UDF3 extends Serializable { + public R call(T1 t1, T2 t2, T3 t3) throws Exception; +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF4.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF4.java new file mode 100644 index 0000000000000..fcad2febb18e6 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF4.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.api.java; + +import java.io.Serializable; + +// ************************************************** +// THIS FILE IS AUTOGENERATED BY CODE IN +// org.apache.spark.sql.api.java.FunctionRegistration +// ************************************************** + +/** + * A Spark SQL UDF that has 4 arguments. + */ +public interface UDF4 extends Serializable { + public R call(T1 t1, T2 t2, T3 t3, T4 t4) throws Exception; +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF5.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF5.java new file mode 100644 index 0000000000000..ce0cef43a2144 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF5.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.api.java; + +import java.io.Serializable; + +// ************************************************** +// THIS FILE IS AUTOGENERATED BY CODE IN +// org.apache.spark.sql.api.java.FunctionRegistration +// ************************************************** + +/** + * A Spark SQL UDF that has 5 arguments. + */ +public interface UDF5 extends Serializable { + public R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5) throws Exception; +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF6.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF6.java new file mode 100644 index 0000000000000..f56b806684e61 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF6.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.api.java; + +import java.io.Serializable; + +// ************************************************** +// THIS FILE IS AUTOGENERATED BY CODE IN +// org.apache.spark.sql.api.java.FunctionRegistration +// ************************************************** + +/** + * A Spark SQL UDF that has 6 arguments. + */ +public interface UDF6 extends Serializable { + public R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6) throws Exception; +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF7.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF7.java new file mode 100644 index 0000000000000..25bd6d3241bd4 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF7.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.api.java; + +import java.io.Serializable; + +// ************************************************** +// THIS FILE IS AUTOGENERATED BY CODE IN +// org.apache.spark.sql.api.java.FunctionRegistration +// ************************************************** + +/** + * A Spark SQL UDF that has 7 arguments. + */ +public interface UDF7 extends Serializable { + public R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6, T7 t7) throws Exception; +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF8.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF8.java new file mode 100644 index 0000000000000..a3b7ac5f94ce7 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF8.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.api.java; + +import java.io.Serializable; + +// ************************************************** +// THIS FILE IS AUTOGENERATED BY CODE IN +// org.apache.spark.sql.api.java.FunctionRegistration +// ************************************************** + +/** + * A Spark SQL UDF that has 8 arguments. + */ +public interface UDF8 extends Serializable { + public R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6, T7 t7, T8 t8) throws Exception; +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF9.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF9.java new file mode 100644 index 0000000000000..205e72a1522fc --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF9.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.api.java; + +import java.io.Serializable; + +// ************************************************** +// THIS FILE IS AUTOGENERATED BY CODE IN +// org.apache.spark.sql.api.java.FunctionRegistration +// ************************************************** + +/** + * A Spark SQL UDF that has 9 arguments. + */ +public interface UDF9 extends Serializable { + public R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6, T7 t7, T8 t8, T9 t9) throws Exception; +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 00dd34aabc389..33931e5d996f5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -48,18 +48,23 @@ import org.apache.spark.{Logging, SparkContext} */ @AlphaComponent class SQLContext(@transient val sparkContext: SparkContext) - extends Logging + extends org.apache.spark.Logging with SQLConf with ExpressionConversions + with UDFRegistration with Serializable { self => @transient protected[sql] lazy val catalog: Catalog = new SimpleCatalog(true) + + @transient + protected[sql] lazy val functionRegistry: FunctionRegistry = new SimpleFunctionRegistry + @transient protected[sql] lazy val analyzer: Analyzer = - new Analyzer(catalog, EmptyFunctionRegistry, caseSensitive = true) + new Analyzer(catalog, functionRegistry, caseSensitive = true) @transient protected[sql] val optimizer = Optimizer @transient @@ -379,7 +384,7 @@ class SQLContext(@transient val sparkContext: SparkContext) protected abstract class QueryExecution { def logical: LogicalPlan - lazy val analyzed = analyzer(logical) + lazy val analyzed = ExtractPythonUdfs(analyzer(logical)) lazy val optimizedPlan = optimizer(analyzed) // TODO: Don't just pick the first one... lazy val sparkPlan = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UdfRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/UdfRegistration.scala new file mode 100644 index 0000000000000..0b48e9e659faa --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/UdfRegistration.scala @@ -0,0 +1,196 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import java.util.{List => JList, Map => JMap} + +import org.apache.spark.Accumulator +import org.apache.spark.sql.catalyst.ScalaReflection +import org.apache.spark.sql.catalyst.expressions.{Expression, ScalaUdf} +import org.apache.spark.sql.execution.PythonUDF + +import scala.reflect.runtime.universe.{TypeTag, typeTag} + +/** + * Functions for registering scala lambda functions as UDFs in a SQLContext. + */ +protected[sql] trait UDFRegistration { + self: SQLContext => + + private[spark] def registerPython( + name: String, + command: Array[Byte], + envVars: JMap[String, String], + pythonIncludes: JList[String], + pythonExec: String, + accumulator: Accumulator[JList[Array[Byte]]], + stringDataType: String): Unit = { + log.debug( + s""" + | Registering new PythonUDF: + | name: $name + | command: ${command.toSeq} + | envVars: $envVars + | pythonIncludes: $pythonIncludes + | pythonExec: $pythonExec + | dataType: $stringDataType + """.stripMargin) + + + val dataType = parseDataType(stringDataType) + + def builder(e: Seq[Expression]) = + PythonUDF( + name, + command, + envVars, + pythonIncludes, + pythonExec, + accumulator, + dataType, + e) + + functionRegistry.registerFunction(name, builder) + } + + /** registerFunction 1-22 were generated by this script + + (1 to 22).map { x => + val types = (1 to x).map(x => "_").reduce(_ + ", " + _) + s""" + def registerFunction[T: TypeTag](name: String, func: Function$x[$types, T]): Unit = { + def builder(e: Seq[Expression]) = + ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + functionRegistry.registerFunction(name, builder) + } + """ + } + */ + + // scalastyle:off + def registerFunction[T: TypeTag](name: String, func: Function1[_, T]): Unit = { + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + functionRegistry.registerFunction(name, builder) + } + + def registerFunction[T: TypeTag](name: String, func: Function2[_, _, T]): Unit = { + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + functionRegistry.registerFunction(name, builder) + } + + def registerFunction[T: TypeTag](name: String, func: Function3[_, _, _, T]): Unit = { + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + functionRegistry.registerFunction(name, builder) + } + + def registerFunction[T: TypeTag](name: String, func: Function4[_, _, _, _, T]): Unit = { + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + functionRegistry.registerFunction(name, builder) + } + + def registerFunction[T: TypeTag](name: String, func: Function5[_, _, _, _, _, T]): Unit = { + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + functionRegistry.registerFunction(name, builder) + } + + def registerFunction[T: TypeTag](name: String, func: Function6[_, _, _, _, _, _, T]): Unit = { + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + functionRegistry.registerFunction(name, builder) + } + + def registerFunction[T: TypeTag](name: String, func: Function7[_, _, _, _, _, _, _, T]): Unit = { + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + functionRegistry.registerFunction(name, builder) + } + + def registerFunction[T: TypeTag](name: String, func: Function8[_, _, _, _, _, _, _, _, T]): Unit = { + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + functionRegistry.registerFunction(name, builder) + } + + def registerFunction[T: TypeTag](name: String, func: Function9[_, _, _, _, _, _, _, _, _, T]): Unit = { + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + functionRegistry.registerFunction(name, builder) + } + + def registerFunction[T: TypeTag](name: String, func: Function10[_, _, _, _, _, _, _, _, _, _, T]): Unit = { + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + functionRegistry.registerFunction(name, builder) + } + + def registerFunction[T: TypeTag](name: String, func: Function11[_, _, _, _, _, _, _, _, _, _, _, T]): Unit = { + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + functionRegistry.registerFunction(name, builder) + } + + def registerFunction[T: TypeTag](name: String, func: Function12[_, _, _, _, _, _, _, _, _, _, _, _, T]): Unit = { + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + functionRegistry.registerFunction(name, builder) + } + + def registerFunction[T: TypeTag](name: String, func: Function13[_, _, _, _, _, _, _, _, _, _, _, _, _, T]): Unit = { + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + functionRegistry.registerFunction(name, builder) + } + + def registerFunction[T: TypeTag](name: String, func: Function14[_, _, _, _, _, _, _, _, _, _, _, _, _, _, T]): Unit = { + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + functionRegistry.registerFunction(name, builder) + } + + def registerFunction[T: TypeTag](name: String, func: Function15[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]): Unit = { + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + functionRegistry.registerFunction(name, builder) + } + + def registerFunction[T: TypeTag](name: String, func: Function16[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]): Unit = { + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + functionRegistry.registerFunction(name, builder) + } + + def registerFunction[T: TypeTag](name: String, func: Function17[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]): Unit = { + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + functionRegistry.registerFunction(name, builder) + } + + def registerFunction[T: TypeTag](name: String, func: Function18[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]): Unit = { + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + functionRegistry.registerFunction(name, builder) + } + + def registerFunction[T: TypeTag](name: String, func: Function19[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]): Unit = { + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + functionRegistry.registerFunction(name, builder) + } + + def registerFunction[T: TypeTag](name: String, func: Function20[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]): Unit = { + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + functionRegistry.registerFunction(name, builder) + } + + def registerFunction[T: TypeTag](name: String, func: Function21[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]): Unit = { + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + functionRegistry.registerFunction(name, builder) + } + + def registerFunction[T: TypeTag](name: String, func: Function22[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]): Unit = { + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + functionRegistry.registerFunction(name, builder) + } + // scalastyle:on +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala index 809dd038f94aa..ae45193ed15d3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala @@ -28,14 +28,13 @@ import org.apache.spark.sql.{SQLContext, StructType => SStructType} import org.apache.spark.sql.catalyst.expressions.{AttributeReference, GenericRow, Row => ScalaRow} import org.apache.spark.sql.parquet.ParquetRelation import org.apache.spark.sql.execution.{ExistingRdd, SparkLogicalPlan} -import org.apache.spark.sql.types.util.DataTypeConversions -import DataTypeConversions.asScalaDataType; +import org.apache.spark.sql.types.util.DataTypeConversions.asScalaDataType import org.apache.spark.util.Utils /** * The entry point for executing Spark SQL queries from a Java program. */ -class JavaSQLContext(val sqlContext: SQLContext) { +class JavaSQLContext(val sqlContext: SQLContext) extends UDFRegistration { def this(sparkContext: JavaSparkContext) = this(new SQLContext(sparkContext.sc)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/java/UDFRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/java/UDFRegistration.scala new file mode 100644 index 0000000000000..158f26e3d445f --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/java/UDFRegistration.scala @@ -0,0 +1,252 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql.api.java + +import org.apache.spark.sql.catalyst.expressions.{Expression, ScalaUdf} +import org.apache.spark.sql.types.util.DataTypeConversions._ + +/** + * A collection of functions that allow Java users to register UDFs. In order to handle functions + * of varying airities with minimal boilerplate for our users, we generate classes and functions + * for each airity up to 22. The code for this generation can be found in comments in this trait. + */ +private[java] trait UDFRegistration { + self: JavaSQLContext => + + /* The following functions and required interfaces are generated with these code fragments: + + (1 to 22).foreach { i => + val extTypeArgs = (1 to i).map(_ => "_").mkString(", ") + val anyTypeArgs = (1 to i).map(_ => "Any").mkString(", ") + val anyCast = s".asInstanceOf[UDF$i[$anyTypeArgs, Any]]" + val anyParams = (1 to i).map(_ => "_: Any").mkString(", ") + println(s""" + |def registerFunction( + | name: String, f: UDF$i[$extTypeArgs, _], @transient dataType: DataType) = { + | val scalaType = asScalaDataType(dataType) + | sqlContext.functionRegistry.registerFunction( + | name, + | (e: Seq[Expression]) => ScalaUdf(f$anyCast.call($anyParams), scalaType, e)) + |} + """.stripMargin) + } + + import java.io.File + import org.apache.spark.sql.catalyst.util.stringToFile + val directory = new File("sql/core/src/main/java/org/apache/spark/sql/api/java/") + (1 to 22).foreach { i => + val typeArgs = (1 to i).map(i => s"T$i").mkString(", ") + val args = (1 to i).map(i => s"T$i t$i").mkString(", ") + + val contents = + s"""/* + | * Licensed to the Apache Software Foundation (ASF) under one or more + | * contributor license agreements. See the NOTICE file distributed with + | * this work for additional information regarding copyright ownership. + | * The ASF licenses this file to You under the Apache License, Version 2.0 + | * (the "License"); you may not use this file except in compliance with + | * the License. You may obtain a copy of the License at + | * + | * http://www.apache.org/licenses/LICENSE-2.0 + | * + | * Unless required by applicable law or agreed to in writing, software + | * distributed under the License is distributed on an "AS IS" BASIS, + | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + | * See the License for the specific language governing permissions and + | * limitations under the License. + | */ + | + |package org.apache.spark.sql.api.java; + | + |import java.io.Serializable; + | + |// ************************************************** + |// THIS FILE IS AUTOGENERATED BY CODE IN + |// org.apache.spark.sql.api.java.FunctionRegistration + |// ************************************************** + | + |/** + | * A Spark SQL UDF that has $i arguments. + | */ + |public interface UDF$i<$typeArgs, R> extends Serializable { + | public R call($args) throws Exception; + |} + |""".stripMargin + + stringToFile(new File(directory, s"UDF$i.java"), contents) + } + + */ + + // scalastyle:off + def registerFunction(name: String, f: UDF1[_, _], dataType: DataType) = { + val scalaType = asScalaDataType(dataType) + sqlContext.functionRegistry.registerFunction( + name, + (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF1[Any, Any]].call(_: Any), scalaType, e)) + } + + def registerFunction(name: String, f: UDF2[_, _, _], dataType: DataType) = { + val scalaType = asScalaDataType(dataType) + sqlContext.functionRegistry.registerFunction( + name, + (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF2[Any, Any, Any]].call(_: Any, _: Any), scalaType, e)) + } + + def registerFunction(name: String, f: UDF3[_, _, _, _], dataType: DataType) = { + val scalaType = asScalaDataType(dataType) + sqlContext.functionRegistry.registerFunction( + name, + (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF3[Any, Any, Any, Any]].call(_: Any, _: Any, _: Any), scalaType, e)) + } + + def registerFunction(name: String, f: UDF4[_, _, _, _, _], dataType: DataType) = { + val scalaType = asScalaDataType(dataType) + sqlContext.functionRegistry.registerFunction( + name, + (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF4[Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any), scalaType, e)) + } + + def registerFunction(name: String, f: UDF5[_, _, _, _, _, _], dataType: DataType) = { + val scalaType = asScalaDataType(dataType) + sqlContext.functionRegistry.registerFunction( + name, + (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF5[Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any), scalaType, e)) + } + + def registerFunction(name: String, f: UDF6[_, _, _, _, _, _, _], dataType: DataType) = { + val scalaType = asScalaDataType(dataType) + sqlContext.functionRegistry.registerFunction( + name, + (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF6[Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any), scalaType, e)) + } + + def registerFunction(name: String, f: UDF7[_, _, _, _, _, _, _, _], dataType: DataType) = { + val scalaType = asScalaDataType(dataType) + sqlContext.functionRegistry.registerFunction( + name, + (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF7[Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), scalaType, e)) + } + + def registerFunction(name: String, f: UDF8[_, _, _, _, _, _, _, _, _], dataType: DataType) = { + val scalaType = asScalaDataType(dataType) + sqlContext.functionRegistry.registerFunction( + name, + (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF8[Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), scalaType, e)) + } + + def registerFunction(name: String, f: UDF9[_, _, _, _, _, _, _, _, _, _], dataType: DataType) = { + val scalaType = asScalaDataType(dataType) + sqlContext.functionRegistry.registerFunction( + name, + (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF9[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), scalaType, e)) + } + + def registerFunction(name: String, f: UDF10[_, _, _, _, _, _, _, _, _, _, _], dataType: DataType) = { + val scalaType = asScalaDataType(dataType) + sqlContext.functionRegistry.registerFunction( + name, + (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF10[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), scalaType, e)) + } + + def registerFunction(name: String, f: UDF11[_, _, _, _, _, _, _, _, _, _, _, _], dataType: DataType) = { + val scalaType = asScalaDataType(dataType) + sqlContext.functionRegistry.registerFunction( + name, + (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF11[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), scalaType, e)) + } + + def registerFunction(name: String, f: UDF12[_, _, _, _, _, _, _, _, _, _, _, _, _], dataType: DataType) = { + val scalaType = asScalaDataType(dataType) + sqlContext.functionRegistry.registerFunction( + name, + (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF12[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), scalaType, e)) + } + + def registerFunction(name: String, f: UDF13[_, _, _, _, _, _, _, _, _, _, _, _, _, _], dataType: DataType) = { + val scalaType = asScalaDataType(dataType) + sqlContext.functionRegistry.registerFunction( + name, + (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF13[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), scalaType, e)) + } + + def registerFunction(name: String, f: UDF14[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _], dataType: DataType) = { + val scalaType = asScalaDataType(dataType) + sqlContext.functionRegistry.registerFunction( + name, + (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF14[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), scalaType, e)) + } + + def registerFunction(name: String, f: UDF15[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], dataType: DataType) = { + val scalaType = asScalaDataType(dataType) + sqlContext.functionRegistry.registerFunction( + name, + (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF15[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), scalaType, e)) + } + + def registerFunction(name: String, f: UDF16[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], dataType: DataType) = { + val scalaType = asScalaDataType(dataType) + sqlContext.functionRegistry.registerFunction( + name, + (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF16[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), scalaType, e)) + } + + def registerFunction(name: String, f: UDF17[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], dataType: DataType) = { + val scalaType = asScalaDataType(dataType) + sqlContext.functionRegistry.registerFunction( + name, + (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF17[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), scalaType, e)) + } + + def registerFunction(name: String, f: UDF18[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], dataType: DataType) = { + val scalaType = asScalaDataType(dataType) + sqlContext.functionRegistry.registerFunction( + name, + (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF18[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), scalaType, e)) + } + + def registerFunction(name: String, f: UDF19[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], dataType: DataType) = { + val scalaType = asScalaDataType(dataType) + sqlContext.functionRegistry.registerFunction( + name, + (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF19[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), scalaType, e)) + } + + def registerFunction(name: String, f: UDF20[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], dataType: DataType) = { + val scalaType = asScalaDataType(dataType) + sqlContext.functionRegistry.registerFunction( + name, + (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF20[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), scalaType, e)) + } + + def registerFunction(name: String, f: UDF21[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], dataType: DataType) = { + val scalaType = asScalaDataType(dataType) + sqlContext.functionRegistry.registerFunction( + name, + (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF21[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), scalaType, e)) + } + + def registerFunction(name: String, f: UDF22[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], dataType: DataType) = { + val scalaType = asScalaDataType(dataType) + sqlContext.functionRegistry.registerFunction( + name, + (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF22[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), scalaType, e)) + } + + // scalastyle:on +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 8bec015c7b465..f0c958fdb537f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -286,6 +286,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { execution.ExistingRdd(Nil, singleRowRdd) :: Nil case logical.Repartition(expressions, child) => execution.Exchange(HashPartitioning(expressions, numPartitions), planLater(child)) :: Nil + case e @ EvaluatePython(udf, child) => + BatchPythonEvaluation(udf, e.output, planLater(child)) :: Nil case SparkLogicalPlan(existingPlan) => existingPlan :: Nil case _ => Nil } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala new file mode 100644 index 0000000000000..b92091b560b1c --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala @@ -0,0 +1,177 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql.execution + +import java.util.{List => JList, Map => JMap} + +import net.razorvine.pickle.{Pickler, Unpickler} +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.api.python.PythonRDD +import org.apache.spark.broadcast.Broadcast +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.logical +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.catalyst.types._ +import org.apache.spark.{Accumulator, Logging => SparkLogging} + +import scala.collection.JavaConversions._ + +/** + * A serialized version of a Python lambda function. Suitable for use in a [[PythonRDD]]. + */ +private[spark] case class PythonUDF( + name: String, + command: Array[Byte], + envVars: JMap[String, String], + pythonIncludes: JList[String], + pythonExec: String, + accumulator: Accumulator[JList[Array[Byte]]], + dataType: DataType, + children: Seq[Expression]) extends Expression with SparkLogging { + + override def toString = s"PythonUDF#$name(${children.mkString(",")})" + + def nullable: Boolean = true + def references: Set[Attribute] = children.flatMap(_.references).toSet + + override def eval(input: Row) = sys.error("PythonUDFs can not be directly evaluated.") +} + +/** + * Extracts PythonUDFs from operators, rewriting the query plan so that the UDF can be evaluated + * alone in a batch. + * + * This has the limitation that the input to the Python UDF is not allowed include attributes from + * multiple child operators. + */ +private[spark] object ExtractPythonUdfs extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan) = plan transform { + // Skip EvaluatePython nodes. + case p: EvaluatePython => p + + case l: LogicalPlan => + // Extract any PythonUDFs from the current operator. + val udfs = l.expressions.flatMap(_.collect { case udf: PythonUDF => udf}) + if (udfs.isEmpty) { + // If there aren't any, we are done. + l + } else { + // Pick the UDF we are going to evaluate (TODO: Support evaluating multiple UDFs at a time) + // If there is more than one, we will add another evaluation operator in a subsequent pass. + val udf = udfs.head + + var evaluation: EvaluatePython = null + + // Rewrite the child that has the input required for the UDF + val newChildren = l.children.map { child => + // Check to make sure that the UDF can be evaluated with only the input of this child. + // Other cases are disallowed as they are ambiguous or would require a cartisian product. + if (udf.references.subsetOf(child.outputSet)) { + evaluation = EvaluatePython(udf, child) + evaluation + } else if (udf.references.intersect(child.outputSet).nonEmpty) { + sys.error(s"Invalid PythonUDF $udf, requires attributes from more than one child.") + } else { + child + } + } + + assert(evaluation != null, "Unable to evaluate PythonUDF. Missing input attributes.") + + // Trim away the new UDF value if it was only used for filtering or something. + logical.Project( + l.output, + l.transformExpressions { + case p: PythonUDF if p.id == udf.id => evaluation.resultAttribute + }.withNewChildren(newChildren)) + } + } +} + +/** + * :: DeveloperApi :: + * Evaluates a [[PythonUDF]], appending the result to the end of the input tuple. + */ +@DeveloperApi +case class EvaluatePython(udf: PythonUDF, child: LogicalPlan) extends logical.UnaryNode { + val resultAttribute = AttributeReference("pythonUDF", udf.dataType, nullable=true)() + + def references = Set.empty + def output = child.output :+ resultAttribute +} + +/** + * :: DeveloperApi :: + * Uses PythonRDD to evaluate a [[PythonUDF]], one partition of tuples at a time. The input + * data is cached and zipped with the result of the udf evaluation. + */ +@DeveloperApi +case class BatchPythonEvaluation(udf: PythonUDF, output: Seq[Attribute], child: SparkPlan) + extends SparkPlan { + def children = child :: Nil + + def execute() = { + // TODO: Clean up after ourselves? + val childResults = child.execute().map(_.copy()).cache() + + val parent = childResults.mapPartitions { iter => + val pickle = new Pickler + val currentRow = newMutableProjection(udf.children, child.output)() + iter.grouped(1000).map { inputRows => + val toBePickled = inputRows.map(currentRow(_).toArray).toArray + pickle.dumps(toBePickled) + } + } + + val pyRDD = new PythonRDD( + parent, + udf.command, + udf.envVars, + udf.pythonIncludes, + false, + udf.pythonExec, + Seq[Broadcast[Array[Byte]]](), + udf.accumulator + ).mapPartitions { iter => + val pickle = new Unpickler + iter.flatMap { pickedResult => + val unpickledBatch = pickle.loads(pickedResult) + unpickledBatch.asInstanceOf[java.util.ArrayList[Any]] + } + }.mapPartitions { iter => + val row = new GenericMutableRow(1) + iter.map { result => + row(0) = udf.dataType match { + case StringType => result.toString + case other => result + } + row: Row + } + } + + childResults.zip(pyRDD).mapPartitions { iter => + val joinedRow = new JoinedRow() + iter.map { + case (row, udfResult) => + joinedRow(row, udfResult) + } + } + } +} diff --git a/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaAPISuite.java b/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaAPISuite.java new file mode 100644 index 0000000000000..a9a11285def54 --- /dev/null +++ b/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaAPISuite.java @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.api.java; + +import java.io.Serializable; + +import org.apache.spark.sql.api.java.UDF1; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; +import org.junit.runners.Suite; +import org.junit.runner.RunWith; + +import org.apache.spark.api.java.JavaSparkContext; + +// The test suite itself is Serializable so that anonymous Function implementations can be +// serialized, as an alternative to converting these anonymous classes to static inner classes; +// see http://stackoverflow.com/questions/758570/. +public class JavaAPISuite implements Serializable { + private transient JavaSparkContext sc; + private transient JavaSQLContext sqlContext; + + @Before + public void setUp() { + sc = new JavaSparkContext("local", "JavaAPISuite"); + sqlContext = new JavaSQLContext(sc); + } + + @After + public void tearDown() { + sc.stop(); + sc = null; + } + + @SuppressWarnings("unchecked") + @Test + public void udf1Test() { + // With Java 8 lambdas: + // sqlContext.registerFunction( + // "stringLengthTest", (String str) -> str.length(), DataType.IntegerType); + + sqlContext.registerFunction("stringLengthTest", new UDF1() { + @Override + public Integer call(String str) throws Exception { + return str.length(); + } + }, DataType.IntegerType); + + // TODO: Why do we need this cast? + Row result = (Row) sqlContext.sql("SELECT stringLengthTest('test')").first(); + assert(result.getInt(0) == 4); + } + + @SuppressWarnings("unchecked") + @Test + public void udf2Test() { + // With Java 8 lambdas: + // sqlContext.registerFunction( + // "stringLengthTest", + // (String str1, String str2) -> str1.length() + str2.length, + // DataType.IntegerType); + + sqlContext.registerFunction("stringLengthTest", new UDF2() { + @Override + public Integer call(String str1, String str2) throws Exception { + return str1.length() + str2.length(); + } + }, DataType.IntegerType); + + // TODO: Why do we need this cast? + Row result = (Row) sqlContext.sql("SELECT stringLengthTest('test', 'test2')").first(); + assert(result.getInt(0) == 9); + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/InsertIntoSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/InsertIntoSuite.scala index 4f0b85f26254b..23a711d08c58b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/InsertIntoSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/InsertIntoSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql -import java.io.File +import _root_.java.io.File /* Implicits */ import org.apache.spark.sql.test.TestSQLContext._ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala new file mode 100644 index 0000000000000..76aa9b0081d7e --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.apache.spark.sql.test._ + +/* Implicits */ +import TestSQLContext._ + +class UDFSuite extends QueryTest { + + test("Simple UDF") { + registerFunction("strLenScala", (_: String).length) + assert(sql("SELECT strLenScala('test')").first().getInt(0) === 4) + } + + test("TwoArgument UDF") { + registerFunction("strLenScala", (_: String).length + (_:Int)) + assert(sql("SELECT strLenScala('test', 1)").first().getInt(0) === 5) + } +} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index 2c7270d9f83a9..3c70b3f0921a5 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -23,7 +23,7 @@ import java.util.{ArrayList => JArrayList} import scala.collection.JavaConversions._ import scala.language.implicitConversions -import scala.reflect.runtime.universe.TypeTag +import scala.reflect.runtime.universe.{TypeTag, typeTag} import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.ql.Driver @@ -35,8 +35,9 @@ import org.apache.spark.SparkContext import org.apache.spark.rdd.RDD import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.ScalaReflection -import org.apache.spark.sql.catalyst.analysis.{Analyzer, OverrideCatalog} +import org.apache.spark.sql.catalyst.analysis.{OverrideFunctionRegistry, Analyzer, OverrideCatalog} import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.execution.ExtractPythonUdfs import org.apache.spark.sql.execution.QueryExecutionException import org.apache.spark.sql.execution.{Command => PhysicalCommand} import org.apache.spark.sql.hive.execution.DescribeHiveTableCommand @@ -155,10 +156,14 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { } } + // Note that HiveUDFs will be overridden by functions registered in this context. + override protected[sql] lazy val functionRegistry = + new HiveFunctionRegistry with OverrideFunctionRegistry + /* An analyzer that uses the Hive metastore. */ @transient override protected[sql] lazy val analyzer = - new Analyzer(catalog, HiveFunctionRegistry, caseSensitive = false) + new Analyzer(catalog, functionRegistry, caseSensitive = false) /** * Runs the specified SQL query using Hive. @@ -250,7 +255,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { protected[sql] abstract class QueryExecution extends super.QueryExecution { // TODO: Create mixin for the analyzer instead of overriding things here. override lazy val optimizedPlan = - optimizer(catalog.PreInsertionCasts(catalog.CreateTables(analyzed))) + optimizer(ExtractPythonUdfs(catalog.PreInsertionCasts(catalog.CreateTables(analyzed)))) override lazy val toRdd: RDD[Row] = executedPlan.execute().map(_.copy()) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala index 728452a25a00e..c605e8adcfb0f 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala @@ -297,8 +297,8 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { def reset() { try { // HACK: Hive is too noisy by default. - org.apache.log4j.LogManager.getCurrentLoggers.foreach { logger => - logger.asInstanceOf[org.apache.log4j.Logger].setLevel(org.apache.log4j.Level.WARN) + org.apache.log4j.LogManager.getCurrentLoggers.foreach { log => + log.asInstanceOf[org.apache.log4j.Logger].setLevel(org.apache.log4j.Level.WARN) } // It is important that we RESET first as broken hooks that might have been set could break diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala index d181921269b56..179aac5cbd5cd 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala @@ -34,7 +34,8 @@ import org.apache.spark.util.Utils.getContextOrSparkClassLoader /* Implicit conversions */ import scala.collection.JavaConversions._ -private[hive] object HiveFunctionRegistry extends analysis.FunctionRegistry with HiveInspectors { +private[hive] abstract class HiveFunctionRegistry + extends analysis.FunctionRegistry with HiveInspectors { def getFunctionInfo(name: String) = FunctionRegistry.getFunctionInfo(name) @@ -92,9 +93,8 @@ private[hive] abstract class HiveUdf extends Expression with Logging with HiveFu } private[hive] case class HiveSimpleUdf(functionClassName: String, children: Seq[Expression]) - extends HiveUdf { + extends HiveUdf with HiveInspectors { - import org.apache.spark.sql.hive.HiveFunctionRegistry._ type UDFType = UDF @transient diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/QueryTest.scala index 11d8b1f0a3d96..95921c3d7ae09 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -51,9 +51,9 @@ class QueryTest extends FunSuite { fail( s""" |Exception thrown while executing query: - |${rdd.logicalPlan} + |${rdd.queryExecution} |== Exception == - |$e + |${stackTraceToString(e)} """.stripMargin) } From 198df11f1a9f419f820f47eba0e9f2ab371a824b Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Sat, 2 Aug 2014 16:48:07 -0700 Subject: [PATCH 122/170] [SPARK-2785][SQL] Remove assertions that throw when users try unsupported Hive commands. Author: Michael Armbrust Closes #1742 from marmbrus/asserts and squashes the following commits: 5182d54 [Michael Armbrust] Remove assertions that throw when users try unsupported Hive commands. --- .../scala/org/apache/spark/sql/hive/HiveQl.scala | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index 3d2eb1eefaeda..bc2fefafd58c8 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -297,8 +297,11 @@ private[hive] object HiveQl { matches.headOption } - assert(remainingNodes.isEmpty, - s"Unhandled clauses: ${remainingNodes.map(dumpTree(_)).mkString("\n")}") + if (remainingNodes.nonEmpty) { + sys.error( + s"""Unhandled clauses: ${remainingNodes.map(dumpTree(_)).mkString("\n")}. + |You are likely trying to use an unsupported Hive feature."""".stripMargin) + } clauses } @@ -748,7 +751,10 @@ private[hive] object HiveQl { case Token(allJoinTokens(joinToken), relation1 :: relation2 :: other) => - assert(other.size <= 1, s"Unhandled join child $other") + if (!(other.size <= 1)) { + sys.error(s"Unsupported join operation: $other") + } + val joinType = joinToken match { case "TOK_JOIN" => Inner case "TOK_RIGHTOUTERJOIN" => RightOuter @@ -756,7 +762,6 @@ private[hive] object HiveQl { case "TOK_FULLOUTERJOIN" => FullOuter case "TOK_LEFTSEMIJOIN" => LeftSemi } - assert(other.size <= 1, "Unhandled join clauses.") Join(nodeToRelation(relation1), nodeToRelation(relation2), joinType, From 866cf1f822cfda22294054be026ef2d96307eb75 Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Sat, 2 Aug 2014 17:12:49 -0700 Subject: [PATCH 123/170] [SPARK-2729][SQL] Added test case for SPARK-2729 This is a follow up of #1636. Author: Cheng Lian Closes #1738 from liancheng/test-for-spark-2729 and squashes the following commits: b13692a [Cheng Lian] Added test case for SPARK-2729 --- .../test/scala/org/apache/spark/sql/TestData.scala | 12 ++++++++++-- .../sql/columnar/InMemoryColumnarQuerySuite.scala | 12 ++++++++++++ 2 files changed, 22 insertions(+), 2 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala index 58cee21e8ad4c..088e6e3c843aa 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala @@ -17,11 +17,13 @@ package org.apache.spark.sql +import java.sql.Timestamp + import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.test._ /* Implicits */ -import TestSQLContext._ +import org.apache.spark.sql.test.TestSQLContext._ case class TestData(key: Int, value: String) @@ -40,7 +42,7 @@ object TestData { LargeAndSmallInts(2147483646, 1) :: LargeAndSmallInts(3, 2) :: Nil) largeAndSmallInts.registerAsTable("largeAndSmallInts") - + case class TestData2(a: Int, b: Int) val testData2: SchemaRDD = TestSQLContext.sparkContext.parallelize( @@ -143,4 +145,10 @@ object TestData { "2, B2, false, null" :: "3, C3, true, null" :: "4, D4, true, 2147483644" :: Nil) + + case class TimestampField(time: Timestamp) + val timestamps = TestSQLContext.sparkContext.parallelize((1 to 3).map { i => + TimestampField(new Timestamp(i)) + }) + timestamps.registerAsTable("timestamps") } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala index 86727b93f3659..b561b44ad7ee2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala @@ -73,4 +73,16 @@ class InMemoryColumnarQuerySuite extends QueryTest { sql("SELECT * FROM nullableRepeatedData"), nullableRepeatedData.collect().toSeq) } + + test("SPARK-2729 regression: timestamp data type") { + checkAnswer( + sql("SELECT time FROM timestamps"), + timestamps.collect().toSeq) + + TestSQLContext.cacheTable("timestamps") + + checkAnswer( + sql("SELECT time FROM timestamps"), + timestamps.collect().toSeq) + } } From d210022e96804e59e42ab902e53637e50884a9ab Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Sat, 2 Aug 2014 17:55:22 -0700 Subject: [PATCH 124/170] [SPARK-2797] [SQL] SchemaRDDs don't support unpersist() The cause is explained in https://issues.apache.org/jira/browse/SPARK-2797. Author: Yin Huai Closes #1745 from yhuai/SPARK-2797 and squashes the following commits: 7b1627d [Yin Huai] The unpersist method of the Scala RDD cannot be called without the input parameter (blocking) from PySpark. --- python/pyspark/sql.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py index e7c35ac1ffe02..36e50e49c9a9c 100644 --- a/python/pyspark/sql.py +++ b/python/pyspark/sql.py @@ -1589,9 +1589,9 @@ def persist(self, storageLevel): self._jschema_rdd.persist(javaStorageLevel) return self - def unpersist(self): + def unpersist(self, blocking=True): self.is_cached = False - self._jschema_rdd.unpersist() + self._jschema_rdd.unpersist(blocking) return self def checkpoint(self): From 1a8043739dc1d9435def6ea3c6341498ba52b708 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Sat, 2 Aug 2014 18:27:04 -0700 Subject: [PATCH 125/170] [SPARK-2739][SQL] Rename registerAsTable to registerTempTable There have been user complaints that the difference between `registerAsTable` and `saveAsTable` is too subtle. This PR addresses this by renaming `registerAsTable` to `registerTempTable`, which more clearly reflects what is happening. `registerAsTable` remains, but will cause a deprecation warning. Author: Michael Armbrust Closes #1743 from marmbrus/registerTempTable and squashes the following commits: d031348 [Michael Armbrust] Merge remote-tracking branch 'apache/master' into registerTempTable 4dff086 [Michael Armbrust] Fix .java files too 89a2f12 [Michael Armbrust] Merge remote-tracking branch 'apache/master' into registerTempTable 0b7b71e [Michael Armbrust] Rename registerAsTable to registerTempTable --- .../sbt_app_sql/src/main/scala/SqlApp.scala | 2 +- docs/sql-programming-guide.md | 18 ++++++------ .../spark/examples/sql/JavaSparkSQL.java | 8 +++--- .../spark/examples/sql/RDDRelation.scala | 4 +-- .../examples/sql/hive/HiveFromSpark.scala | 2 +- python/pyspark/sql.py | 12 +++++--- .../org/apache/spark/sql/SQLContext.scala | 4 +-- .../org/apache/spark/sql/SchemaRDD.scala | 2 +- .../org/apache/spark/sql/SchemaRDDLike.scala | 5 +++- .../spark/sql/api/java/JavaSQLContext.scala | 2 +- .../sql/api/java/JavaApplySchemaSuite.java | 6 ++-- .../apache/spark/sql/CachedTableSuite.scala | 2 +- .../apache/spark/sql/InsertIntoSuite.scala | 4 +-- .../org/apache/spark/sql/JoinSuite.scala | 4 +-- .../org/apache/spark/sql/SQLQuerySuite.scala | 6 ++-- .../sql/ScalaReflectionRelationSuite.scala | 8 +++--- .../scala/org/apache/spark/sql/TestData.scala | 28 +++++++++---------- .../spark/sql/api/java/JavaSQLSuite.scala | 10 +++---- .../org/apache/spark/sql/json/JsonSuite.scala | 22 +++++++-------- .../spark/sql/parquet/ParquetQuerySuite.scala | 26 ++++++++--------- .../sql/hive/InsertIntoHiveTableSuite.scala | 2 +- .../sql/hive/api/java/JavaHiveQLSuite.scala | 4 +-- .../sql/hive/execution/HiveQuerySuite.scala | 6 ++-- .../hive/execution/HiveResolutionSuite.scala | 4 +-- .../spark/sql/parquet/HiveParquetSuite.scala | 8 +++--- 25 files changed, 103 insertions(+), 96 deletions(-) diff --git a/dev/audit-release/sbt_app_sql/src/main/scala/SqlApp.scala b/dev/audit-release/sbt_app_sql/src/main/scala/SqlApp.scala index 50af90c213b5a..d888de929fdda 100644 --- a/dev/audit-release/sbt_app_sql/src/main/scala/SqlApp.scala +++ b/dev/audit-release/sbt_app_sql/src/main/scala/SqlApp.scala @@ -38,7 +38,7 @@ object SparkSqlExample { import sqlContext._ val people = sc.makeRDD(1 to 100, 10).map(x => Person(s"Name$x", x)) - people.registerAsTable("people") + people.registerTempTable("people") val teenagers = sql("SELECT name FROM people WHERE age >= 13 AND age <= 19") val teenagerNames = teenagers.map(t => "Name: " + t(0)).collect() teenagerNames.foreach(println) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 7261badd411a9..0465468084cee 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -142,7 +142,7 @@ case class Person(name: String, age: Int) // Create an RDD of Person objects and register it as a table. val people = sc.textFile("examples/src/main/resources/people.txt").map(_.split(",")).map(p => Person(p(0), p(1).trim.toInt)) -people.registerAsTable("people") +people.registerTempTable("people") // SQL statements can be run by using the sql methods provided by sqlContext. val teenagers = sqlContext.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19") @@ -210,7 +210,7 @@ JavaRDD people = sc.textFile("examples/src/main/resources/people.txt").m // Apply a schema to an RDD of JavaBeans and register it as a table. JavaSchemaRDD schemaPeople = sqlContext.applySchema(people, Person.class); -schemaPeople.registerAsTable("people"); +schemaPeople.registerTempTable("people"); // SQL can be run over RDDs that have been registered as tables. JavaSchemaRDD teenagers = sqlContext.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19") @@ -248,7 +248,7 @@ people = parts.map(lambda p: {"name": p[0], "age": int(p[1])}) # In future versions of PySpark we would like to add support for registering RDDs with other # datatypes as tables schemaPeople = sqlContext.inferSchema(people) -schemaPeople.registerAsTable("people") +schemaPeople.registerTempTable("people") # SQL can be run over SchemaRDDs that have been registered as a table. teenagers = sqlContext.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19") @@ -292,7 +292,7 @@ people.saveAsParquetFile("people.parquet") val parquetFile = sqlContext.parquetFile("people.parquet") //Parquet files can also be registered as tables and then used in SQL statements. -parquetFile.registerAsTable("parquetFile") +parquetFile.registerTempTable("parquetFile") val teenagers = sqlContext.sql("SELECT name FROM parquetFile WHERE age >= 13 AND age <= 19") teenagers.map(t => "Name: " + t(0)).collect().foreach(println) {% endhighlight %} @@ -314,7 +314,7 @@ schemaPeople.saveAsParquetFile("people.parquet"); JavaSchemaRDD parquetFile = sqlContext.parquetFile("people.parquet"); //Parquet files can also be registered as tables and then used in SQL statements. -parquetFile.registerAsTable("parquetFile"); +parquetFile.registerTempTable("parquetFile"); JavaSchemaRDD teenagers = sqlContext.sql("SELECT name FROM parquetFile WHERE age >= 13 AND age <= 19"); List teenagerNames = teenagers.map(new Function() { public String call(Row row) { @@ -340,7 +340,7 @@ schemaPeople.saveAsParquetFile("people.parquet") parquetFile = sqlContext.parquetFile("people.parquet") # Parquet files can also be registered as tables and then used in SQL statements. -parquetFile.registerAsTable("parquetFile"); +parquetFile.registerTempTable("parquetFile"); teenagers = sqlContext.sql("SELECT name FROM parquetFile WHERE age >= 13 AND age <= 19") teenNames = teenagers.map(lambda p: "Name: " + p.name) for teenName in teenNames.collect(): @@ -378,7 +378,7 @@ people.printSchema() // |-- name: StringType // Register this SchemaRDD as a table. -people.registerAsTable("people") +people.registerTempTable("people") // SQL statements can be run by using the sql methods provided by sqlContext. val teenagers = sqlContext.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19") @@ -416,7 +416,7 @@ people.printSchema(); // |-- name: StringType // Register this JavaSchemaRDD as a table. -people.registerAsTable("people"); +people.registerTempTable("people"); // SQL statements can be run by using the sql methods provided by sqlContext. JavaSchemaRDD teenagers = sqlContext.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19"); @@ -455,7 +455,7 @@ people.printSchema() # |-- name: StringType # Register this SchemaRDD as a table. -people.registerAsTable("people") +people.registerTempTable("people") # SQL statements can be run by using the sql methods provided by sqlContext. teenagers = sqlContext.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19") diff --git a/examples/src/main/java/org/apache/spark/examples/sql/JavaSparkSQL.java b/examples/src/main/java/org/apache/spark/examples/sql/JavaSparkSQL.java index 607df3eddd550..898297dc658ba 100644 --- a/examples/src/main/java/org/apache/spark/examples/sql/JavaSparkSQL.java +++ b/examples/src/main/java/org/apache/spark/examples/sql/JavaSparkSQL.java @@ -74,7 +74,7 @@ public Person call(String line) throws Exception { // Apply a schema to an RDD of Java Beans and register it as a table. JavaSchemaRDD schemaPeople = sqlCtx.applySchema(people, Person.class); - schemaPeople.registerAsTable("people"); + schemaPeople.registerTempTable("people"); // SQL can be run over RDDs that have been registered as tables. JavaSchemaRDD teenagers = sqlCtx.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19"); @@ -100,7 +100,7 @@ public String call(Row row) { JavaSchemaRDD parquetFile = sqlCtx.parquetFile("people.parquet"); //Parquet files can also be registered as tables and then used in SQL statements. - parquetFile.registerAsTable("parquetFile"); + parquetFile.registerTempTable("parquetFile"); JavaSchemaRDD teenagers2 = sqlCtx.sql("SELECT name FROM parquetFile WHERE age >= 13 AND age <= 19"); teenagerNames = teenagers2.map(new Function() { @@ -128,7 +128,7 @@ public String call(Row row) { // |-- name: StringType // Register this JavaSchemaRDD as a table. - peopleFromJsonFile.registerAsTable("people"); + peopleFromJsonFile.registerTempTable("people"); // SQL statements can be run by using the sql methods provided by sqlCtx. JavaSchemaRDD teenagers3 = sqlCtx.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19"); @@ -158,7 +158,7 @@ public String call(Row row) { // | |-- state: StringType // |-- name: StringType - peopleFromJsonRDD.registerAsTable("people2"); + peopleFromJsonRDD.registerTempTable("people2"); JavaSchemaRDD peopleWithCity = sqlCtx.sql("SELECT name, address.city FROM people2"); List nameAndCity = peopleWithCity.map(new Function() { diff --git a/examples/src/main/scala/org/apache/spark/examples/sql/RDDRelation.scala b/examples/src/main/scala/org/apache/spark/examples/sql/RDDRelation.scala index 63db688bfb8c0..d56d64c564200 100644 --- a/examples/src/main/scala/org/apache/spark/examples/sql/RDDRelation.scala +++ b/examples/src/main/scala/org/apache/spark/examples/sql/RDDRelation.scala @@ -36,7 +36,7 @@ object RDDRelation { val rdd = sc.parallelize((1 to 100).map(i => Record(i, s"val_$i"))) // Any RDD containing case classes can be registered as a table. The schema of the table is // automatically inferred using scala reflection. - rdd.registerAsTable("records") + rdd.registerTempTable("records") // Once tables have been registered, you can run SQL queries over them. println("Result of SELECT *:") @@ -66,7 +66,7 @@ object RDDRelation { parquetFile.where('key === 1).select('value as 'a).collect().foreach(println) // These files can also be registered as tables. - parquetFile.registerAsTable("parquetFile") + parquetFile.registerTempTable("parquetFile") sql("SELECT * FROM parquetFile").collect().foreach(println) } } diff --git a/examples/src/main/scala/org/apache/spark/examples/sql/hive/HiveFromSpark.scala b/examples/src/main/scala/org/apache/spark/examples/sql/hive/HiveFromSpark.scala index dc5290fb4f10e..12530c8490b09 100644 --- a/examples/src/main/scala/org/apache/spark/examples/sql/hive/HiveFromSpark.scala +++ b/examples/src/main/scala/org/apache/spark/examples/sql/hive/HiveFromSpark.scala @@ -56,7 +56,7 @@ object HiveFromSpark { // You can also register RDDs as temporary tables within a HiveContext. val rdd = sc.parallelize((1 to 100).map(i => Record(i, s"val_$i"))) - rdd.registerAsTable("records") + rdd.registerTempTable("records") // Queries can then join RDD data with data stored in Hive. println("Result of SELECT *:") diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py index 36e50e49c9a9c..42b738e112809 100644 --- a/python/pyspark/sql.py +++ b/python/pyspark/sql.py @@ -909,7 +909,7 @@ def __init__(self, sparkContext, sqlContext=None): ... b=True, list=[1, 2, 3], dict={"s": 0}, row=Row(a=1), ... time=datetime(2014, 8, 1, 14, 1, 5))]) >>> srdd = sqlCtx.inferSchema(allTypes) - >>> srdd.registerAsTable("allTypes") + >>> srdd.registerTempTable("allTypes") >>> sqlCtx.sql('select i+1, d+1, not b, list[1], dict["s"], time, row.a ' ... 'from allTypes where b and i > 0').collect() [Row(c0=2, c1=2.0, c2=False, c3=2, c4=0...8, 1, 14, 1, 5), a=1)] @@ -1486,19 +1486,23 @@ def saveAsParquetFile(self, path): """ self._jschema_rdd.saveAsParquetFile(path) - def registerAsTable(self, name): + def registerTempTable(self, name): """Registers this RDD as a temporary table using the given name. The lifetime of this temporary table is tied to the L{SQLContext} that was used to create this SchemaRDD. >>> srdd = sqlCtx.inferSchema(rdd) - >>> srdd.registerAsTable("test") + >>> srdd.registerTempTable("test") >>> srdd2 = sqlCtx.sql("select * from test") >>> sorted(srdd.collect()) == sorted(srdd2.collect()) True """ - self._jschema_rdd.registerAsTable(name) + self._jschema_rdd.registerTempTable(name) + + def registerAsTable(self, name): + warnings.warn("Use registerTempTable instead of registerAsTable.", DeprecationWarning) + self.registerTempTable(name) def insertInto(self, tableName, overwrite=False): """Inserts the contents of this SchemaRDD into the specified table. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 33931e5d996f5..567f4dca991b2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -116,7 +116,7 @@ class SQLContext(@transient val sparkContext: SparkContext) * // |-- name: string (nullable = false) * // |-- age: integer (nullable = true) * - * peopleSchemaRDD.registerAsTable("people") + * peopleSchemaRDD.registerTempTable("people") * sqlContext.sql("select name from people").collect.foreach(println) * }}} * @@ -212,7 +212,7 @@ class SQLContext(@transient val sparkContext: SparkContext) * import sqlContext._ * * case class Person(name: String, age: Int) - * createParquetFile[Person]("path/to/file.parquet").registerAsTable("people") + * createParquetFile[Person]("path/to/file.parquet").registerTempTable("people") * sql("INSERT INTO people SELECT 'michael', 29") * }}} * diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala index d34f62dc8865e..57df79321b35d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala @@ -67,7 +67,7 @@ import org.apache.spark.api.java.JavaRDD * val rdd = sc.parallelize((1 to 100).map(i => Record(i, s"val_$i"))) * // Any RDD containing case classes can be registered as a table. The schema of the table is * // automatically inferred using scala reflection. - * rdd.registerAsTable("records") + * rdd.registerTempTable("records") * * val results: SchemaRDD = sql("SELECT * FROM records") * }}} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDDLike.scala b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDDLike.scala index 6a20def475822..2f3033a5f94f0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDDLike.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDDLike.scala @@ -83,10 +83,13 @@ private[sql] trait SchemaRDDLike { * * @group schema */ - def registerAsTable(tableName: String): Unit = { + def registerTempTable(tableName: String): Unit = { sqlContext.registerRDDAsTable(baseSchemaRDD, tableName) } + @deprecated("Use registerTempTable instead of registerAsTable.", "1.1") + def registerAsTable(tableName: String): Unit = registerTempTable(tableName) + /** * :: Experimental :: * Adds the rows from this RDD to the specified table, optionally overwriting the existing data. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala index ae45193ed15d3..dbaa16e8b0c68 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala @@ -52,7 +52,7 @@ class JavaSQLContext(val sqlContext: SQLContext) extends UDFRegistration { * {{{ * JavaSQLContext sqlCtx = new JavaSQLContext(...) * - * sqlCtx.createParquetFile(Person.class, "path/to/file.parquet").registerAsTable("people") + * sqlCtx.createParquetFile(Person.class, "path/to/file.parquet").registerTempTable("people") * sqlCtx.sql("INSERT INTO people SELECT 'michael', 29") * }}} * diff --git a/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaApplySchemaSuite.java b/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaApplySchemaSuite.java index 3c92906d82864..33e5020bc636a 100644 --- a/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaApplySchemaSuite.java +++ b/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaApplySchemaSuite.java @@ -98,7 +98,7 @@ public Row call(Person person) throws Exception { StructType schema = DataType.createStructType(fields); JavaSchemaRDD schemaRDD = javaSqlCtx.applySchema(rowRDD, schema); - schemaRDD.registerAsTable("people"); + schemaRDD.registerTempTable("people"); List actual = javaSqlCtx.sql("SELECT * FROM people").collect(); List expected = new ArrayList(2); @@ -149,14 +149,14 @@ public void applySchemaToJSON() { JavaSchemaRDD schemaRDD1 = javaSqlCtx.jsonRDD(jsonRDD); StructType actualSchema1 = schemaRDD1.schema(); Assert.assertEquals(expectedSchema, actualSchema1); - schemaRDD1.registerAsTable("jsonTable1"); + schemaRDD1.registerTempTable("jsonTable1"); List actual1 = javaSqlCtx.sql("select * from jsonTable1").collect(); Assert.assertEquals(expectedResult, actual1); JavaSchemaRDD schemaRDD2 = javaSqlCtx.jsonRDD(jsonRDD, expectedSchema); StructType actualSchema2 = schemaRDD2.schema(); Assert.assertEquals(expectedSchema, actualSchema2); - schemaRDD1.registerAsTable("jsonTable2"); + schemaRDD1.registerTempTable("jsonTable2"); List actual2 = javaSqlCtx.sql("select * from jsonTable2").collect(); Assert.assertEquals(expectedResult, actual2); } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala index c3c0dcb1aa00b..fbf9bd9dbcdea 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala @@ -78,7 +78,7 @@ class CachedTableSuite extends QueryTest { } test("SELECT Star Cached Table") { - TestSQLContext.sql("SELECT * FROM testData").registerAsTable("selectStar") + TestSQLContext.sql("SELECT * FROM testData").registerTempTable("selectStar") TestSQLContext.cacheTable("selectStar") TestSQLContext.sql("SELECT * FROM selectStar WHERE key = 1").collect() TestSQLContext.uncacheTable("selectStar") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/InsertIntoSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/InsertIntoSuite.scala index 23a711d08c58b..c87d762751e6d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/InsertIntoSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/InsertIntoSuite.scala @@ -31,7 +31,7 @@ class InsertIntoSuite extends QueryTest { testFilePath.delete() testFilePath.deleteOnExit() val testFile = createParquetFile[TestData](testFilePath.getCanonicalPath) - testFile.registerAsTable("createAndInsertTest") + testFile.registerTempTable("createAndInsertTest") // Add some data. testData.insertInto("createAndInsertTest") @@ -86,7 +86,7 @@ class InsertIntoSuite extends QueryTest { testFilePath.delete() testFilePath.deleteOnExit() val testFile = createParquetFile[TestData](testFilePath.getCanonicalPath) - testFile.registerAsTable("createAndInsertSQLTest") + testFile.registerTempTable("createAndInsertSQLTest") sql("INSERT INTO createAndInsertSQLTest SELECT * FROM testData") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index 2fc80588182d9..6c7697ece8c56 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -285,8 +285,8 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { } test("full outer join") { - upperCaseData.where('N <= 4).registerAsTable("left") - upperCaseData.where('N >= 3).registerAsTable("right") + upperCaseData.where('N <= 4).registerTempTable("left") + upperCaseData.where('N >= 3).registerTempTable("right") val left = UnresolvedRelation(None, "left", None) val right = UnresolvedRelation(None, "right", None) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 5c571d35d1bb9..9b2a36d33fca7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -461,7 +461,7 @@ class SQLQuerySuite extends QueryTest { } val schemaRDD1 = applySchema(rowRDD1, schema1) - schemaRDD1.registerAsTable("applySchema1") + schemaRDD1.registerTempTable("applySchema1") checkAnswer( sql("SELECT * FROM applySchema1"), (1, "A1", true, null) :: @@ -491,7 +491,7 @@ class SQLQuerySuite extends QueryTest { } val schemaRDD2 = applySchema(rowRDD2, schema2) - schemaRDD2.registerAsTable("applySchema2") + schemaRDD2.registerTempTable("applySchema2") checkAnswer( sql("SELECT * FROM applySchema2"), (Seq(1, true), Map("A1" -> null)) :: @@ -516,7 +516,7 @@ class SQLQuerySuite extends QueryTest { } val schemaRDD3 = applySchema(rowRDD3, schema2) - schemaRDD3.registerAsTable("applySchema3") + schemaRDD3.registerTempTable("applySchema3") checkAnswer( sql("SELECT f1.f11, f2['D4'] FROM applySchema3"), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala index f2934da9a031d..5b84c658db942 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala @@ -61,7 +61,7 @@ class ScalaReflectionRelationSuite extends FunSuite { val data = ReflectData("a", 1, 1L, 1.toFloat, 1.toDouble, 1.toShort, 1.toByte, true, BigDecimal(1), new Timestamp(12345), Seq(1,2,3)) val rdd = sparkContext.parallelize(data :: Nil) - rdd.registerAsTable("reflectData") + rdd.registerTempTable("reflectData") assert(sql("SELECT * FROM reflectData").collect().head === data.productIterator.toSeq) } @@ -69,7 +69,7 @@ class ScalaReflectionRelationSuite extends FunSuite { test("query case class RDD with nulls") { val data = NullReflectData(null, null, null, null, null, null, null) val rdd = sparkContext.parallelize(data :: Nil) - rdd.registerAsTable("reflectNullData") + rdd.registerTempTable("reflectNullData") assert(sql("SELECT * FROM reflectNullData").collect().head === Seq.fill(7)(null)) } @@ -77,7 +77,7 @@ class ScalaReflectionRelationSuite extends FunSuite { test("query case class RDD with Nones") { val data = OptionalReflectData(None, None, None, None, None, None, None) val rdd = sparkContext.parallelize(data :: Nil) - rdd.registerAsTable("reflectOptionalData") + rdd.registerTempTable("reflectOptionalData") assert(sql("SELECT * FROM reflectOptionalData").collect().head === Seq.fill(7)(null)) } @@ -85,7 +85,7 @@ class ScalaReflectionRelationSuite extends FunSuite { // Equality is broken for Arrays, so we test that separately. test("query binary data") { val rdd = sparkContext.parallelize(ReflectBinary(Array[Byte](1)) :: Nil) - rdd.registerAsTable("reflectBinary") + rdd.registerTempTable("reflectBinary") val result = sql("SELECT data FROM reflectBinary").collect().head(0).asInstanceOf[Array[Byte]] assert(result.toSeq === Seq[Byte](1)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala index 088e6e3c843aa..c3ec82fb69778 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala @@ -30,7 +30,7 @@ case class TestData(key: Int, value: String) object TestData { val testData: SchemaRDD = TestSQLContext.sparkContext.parallelize( (1 to 100).map(i => TestData(i, i.toString))) - testData.registerAsTable("testData") + testData.registerTempTable("testData") case class LargeAndSmallInts(a: Int, b: Int) val largeAndSmallInts: SchemaRDD = @@ -41,7 +41,7 @@ object TestData { LargeAndSmallInts(2, 2) :: LargeAndSmallInts(2147483646, 1) :: LargeAndSmallInts(3, 2) :: Nil) - largeAndSmallInts.registerAsTable("largeAndSmallInts") + largeAndSmallInts.registerTempTable("largeAndSmallInts") case class TestData2(a: Int, b: Int) val testData2: SchemaRDD = @@ -52,7 +52,7 @@ object TestData { TestData2(2, 2) :: TestData2(3, 1) :: TestData2(3, 2) :: Nil) - testData2.registerAsTable("testData2") + testData2.registerTempTable("testData2") // TODO: There is no way to express null primitives as case classes currently... val testData3 = @@ -71,7 +71,7 @@ object TestData { UpperCaseData(4, "D") :: UpperCaseData(5, "E") :: UpperCaseData(6, "F") :: Nil) - upperCaseData.registerAsTable("upperCaseData") + upperCaseData.registerTempTable("upperCaseData") case class LowerCaseData(n: Int, l: String) val lowerCaseData = @@ -80,14 +80,14 @@ object TestData { LowerCaseData(2, "b") :: LowerCaseData(3, "c") :: LowerCaseData(4, "d") :: Nil) - lowerCaseData.registerAsTable("lowerCaseData") + lowerCaseData.registerTempTable("lowerCaseData") case class ArrayData(data: Seq[Int], nestedData: Seq[Seq[Int]]) val arrayData = TestSQLContext.sparkContext.parallelize( ArrayData(Seq(1,2,3), Seq(Seq(1,2,3))) :: ArrayData(Seq(2,3,4), Seq(Seq(2,3,4))) :: Nil) - arrayData.registerAsTable("arrayData") + arrayData.registerTempTable("arrayData") case class MapData(data: Map[Int, String]) val mapData = @@ -97,18 +97,18 @@ object TestData { MapData(Map(1 -> "a3", 2 -> "b3", 3 -> "c3")) :: MapData(Map(1 -> "a4", 2 -> "b4")) :: MapData(Map(1 -> "a5")) :: Nil) - mapData.registerAsTable("mapData") + mapData.registerTempTable("mapData") case class StringData(s: String) val repeatedData = TestSQLContext.sparkContext.parallelize(List.fill(2)(StringData("test"))) - repeatedData.registerAsTable("repeatedData") + repeatedData.registerTempTable("repeatedData") val nullableRepeatedData = TestSQLContext.sparkContext.parallelize( List.fill(2)(StringData(null)) ++ List.fill(2)(StringData("test"))) - nullableRepeatedData.registerAsTable("nullableRepeatedData") + nullableRepeatedData.registerTempTable("nullableRepeatedData") case class NullInts(a: Integer) val nullInts = @@ -118,7 +118,7 @@ object TestData { NullInts(3) :: NullInts(null) :: Nil ) - nullInts.registerAsTable("nullInts") + nullInts.registerTempTable("nullInts") val allNulls = TestSQLContext.sparkContext.parallelize( @@ -126,7 +126,7 @@ object TestData { NullInts(null) :: NullInts(null) :: NullInts(null) :: Nil) - allNulls.registerAsTable("allNulls") + allNulls.registerTempTable("allNulls") case class NullStrings(n: Int, s: String) val nullStrings = @@ -134,10 +134,10 @@ object TestData { NullStrings(1, "abc") :: NullStrings(2, "ABC") :: NullStrings(3, null) :: Nil) - nullStrings.registerAsTable("nullStrings") + nullStrings.registerTempTable("nullStrings") case class TableName(tableName: String) - TestSQLContext.sparkContext.parallelize(TableName("test") :: Nil).registerAsTable("tableName") + TestSQLContext.sparkContext.parallelize(TableName("test") :: Nil).registerTempTable("tableName") val unparsedStrings = TestSQLContext.sparkContext.parallelize( @@ -150,5 +150,5 @@ object TestData { val timestamps = TestSQLContext.sparkContext.parallelize((1 to 3).map { i => TimestampField(new Timestamp(i)) }) - timestamps.registerAsTable("timestamps") + timestamps.registerTempTable("timestamps") } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/api/java/JavaSQLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/api/java/JavaSQLSuite.scala index 020baf0c7ec6f..203ff847e94cc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/api/java/JavaSQLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/api/java/JavaSQLSuite.scala @@ -59,7 +59,7 @@ class JavaSQLSuite extends FunSuite { val rdd = javaCtx.parallelize(person :: Nil) val schemaRDD = javaSqlCtx.applySchema(rdd, classOf[PersonBean]) - schemaRDD.registerAsTable("people") + schemaRDD.registerTempTable("people") javaSqlCtx.sql("SELECT * FROM people").collect() } @@ -76,7 +76,7 @@ class JavaSQLSuite extends FunSuite { val rdd = javaCtx.parallelize(bean :: Nil) val schemaRDD = javaSqlCtx.applySchema(rdd, classOf[AllTypesBean]) - schemaRDD.registerAsTable("allTypes") + schemaRDD.registerTempTable("allTypes") assert( javaSqlCtx.sql( @@ -101,7 +101,7 @@ class JavaSQLSuite extends FunSuite { val rdd = javaCtx.parallelize(bean :: Nil) val schemaRDD = javaSqlCtx.applySchema(rdd, classOf[AllTypesBean]) - schemaRDD.registerAsTable("allTypes") + schemaRDD.registerTempTable("allTypes") assert( javaSqlCtx.sql( @@ -127,7 +127,7 @@ class JavaSQLSuite extends FunSuite { var schemaRDD = javaSqlCtx.jsonRDD(rdd) - schemaRDD.registerAsTable("jsonTable1") + schemaRDD.registerTempTable("jsonTable1") assert( javaSqlCtx.sql("select * from jsonTable1").collect.head.row === @@ -144,7 +144,7 @@ class JavaSQLSuite extends FunSuite { rdd.saveAsTextFile(path) schemaRDD = javaSqlCtx.jsonFile(path) - schemaRDD.registerAsTable("jsonTable2") + schemaRDD.registerTempTable("jsonTable2") assert( javaSqlCtx.sql("select * from jsonTable2").collect.head.row === diff --git a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala index 9d9cfdd7c92e3..75c0589eb208e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala @@ -183,7 +183,7 @@ class JsonSuite extends QueryTest { assert(expectedSchema === jsonSchemaRDD.schema) - jsonSchemaRDD.registerAsTable("jsonTable") + jsonSchemaRDD.registerTempTable("jsonTable") checkAnswer( sql("select * from jsonTable"), @@ -223,7 +223,7 @@ class JsonSuite extends QueryTest { assert(expectedSchema === jsonSchemaRDD.schema) - jsonSchemaRDD.registerAsTable("jsonTable") + jsonSchemaRDD.registerTempTable("jsonTable") // Access elements of a primitive array. checkAnswer( @@ -291,7 +291,7 @@ class JsonSuite extends QueryTest { ignore("Complex field and type inferring (Ignored)") { val jsonSchemaRDD = jsonRDD(complexFieldAndType) - jsonSchemaRDD.registerAsTable("jsonTable") + jsonSchemaRDD.registerTempTable("jsonTable") // Right now, "field1" and "field2" are treated as aliases. We should fix it. checkAnswer( @@ -320,7 +320,7 @@ class JsonSuite extends QueryTest { assert(expectedSchema === jsonSchemaRDD.schema) - jsonSchemaRDD.registerAsTable("jsonTable") + jsonSchemaRDD.registerTempTable("jsonTable") checkAnswer( sql("select * from jsonTable"), @@ -374,7 +374,7 @@ class JsonSuite extends QueryTest { ignore("Type conflict in primitive field values (Ignored)") { val jsonSchemaRDD = jsonRDD(primitiveFieldValueTypeConflict) - jsonSchemaRDD.registerAsTable("jsonTable") + jsonSchemaRDD.registerTempTable("jsonTable") // Right now, the analyzer does not promote strings in a boolean expreesion. // Number and Boolean conflict: resolve the type as boolean in this query. @@ -445,7 +445,7 @@ class JsonSuite extends QueryTest { assert(expectedSchema === jsonSchemaRDD.schema) - jsonSchemaRDD.registerAsTable("jsonTable") + jsonSchemaRDD.registerTempTable("jsonTable") checkAnswer( sql("select * from jsonTable"), @@ -466,7 +466,7 @@ class JsonSuite extends QueryTest { assert(expectedSchema === jsonSchemaRDD.schema) - jsonSchemaRDD.registerAsTable("jsonTable") + jsonSchemaRDD.registerTempTable("jsonTable") checkAnswer( sql("select * from jsonTable"), @@ -494,7 +494,7 @@ class JsonSuite extends QueryTest { assert(expectedSchema === jsonSchemaRDD.schema) - jsonSchemaRDD.registerAsTable("jsonTable") + jsonSchemaRDD.registerTempTable("jsonTable") } test("Loading a JSON dataset from a text file") { @@ -514,7 +514,7 @@ class JsonSuite extends QueryTest { assert(expectedSchema === jsonSchemaRDD.schema) - jsonSchemaRDD.registerAsTable("jsonTable") + jsonSchemaRDD.registerTempTable("jsonTable") checkAnswer( sql("select * from jsonTable"), @@ -546,7 +546,7 @@ class JsonSuite extends QueryTest { assert(schema === jsonSchemaRDD1.schema) - jsonSchemaRDD1.registerAsTable("jsonTable1") + jsonSchemaRDD1.registerTempTable("jsonTable1") checkAnswer( sql("select * from jsonTable1"), @@ -563,7 +563,7 @@ class JsonSuite extends QueryTest { assert(schema === jsonSchemaRDD2.schema) - jsonSchemaRDD2.registerAsTable("jsonTable2") + jsonSchemaRDD2.registerTempTable("jsonTable2") checkAnswer( sql("select * from jsonTable2"), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala index 8955455ec98c7..9933575038bd3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala @@ -101,9 +101,9 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA ParquetTestData.writeNestedFile3() ParquetTestData.writeNestedFile4() testRDD = parquetFile(ParquetTestData.testDir.toString) - testRDD.registerAsTable("testsource") + testRDD.registerTempTable("testsource") parquetFile(ParquetTestData.testFilterDir.toString) - .registerAsTable("testfiltersource") + .registerTempTable("testfiltersource") } override def afterAll() { @@ -247,7 +247,7 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA test("Creating case class RDD table") { TestSQLContext.sparkContext.parallelize((1 to 100)) .map(i => TestRDDEntry(i, s"val_$i")) - .registerAsTable("tmp") + .registerTempTable("tmp") val rdd = sql("SELECT * FROM tmp").collect().sortBy(_.getInt(0)) var counter = 1 rdd.foreach { @@ -266,7 +266,7 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA .map(i => TestRDDEntry(i, s"val_$i")) rdd.saveAsParquetFile(path) val readFile = parquetFile(path) - readFile.registerAsTable("tmpx") + readFile.registerTempTable("tmpx") val rdd_copy = sql("SELECT * FROM tmpx").collect() val rdd_orig = rdd.collect() for(i <- 0 to 99) { @@ -280,9 +280,9 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA val dirname = Utils.createTempDir() val source_rdd = TestSQLContext.sparkContext.parallelize((1 to 100)) .map(i => TestRDDEntry(i, s"val_$i")) - source_rdd.registerAsTable("source") + source_rdd.registerTempTable("source") val dest_rdd = createParquetFile[TestRDDEntry](dirname.toString) - dest_rdd.registerAsTable("dest") + dest_rdd.registerTempTable("dest") sql("INSERT OVERWRITE INTO dest SELECT * FROM source").collect() val rdd_copy1 = sql("SELECT * FROM dest").collect() assert(rdd_copy1.size === 100) @@ -547,7 +547,7 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA val data = nestedParserSqlContext .parquetFile(ParquetTestData.testNestedDir1.toString) .toSchemaRDD - data.registerAsTable("data") + data.registerTempTable("data") val query = nestedParserSqlContext.sql("SELECT owner, contacts[1].name FROM data") val tmp = query.collect() assert(tmp.size === 2) @@ -562,7 +562,7 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA val data = nestedParserSqlContext .parquetFile(ParquetTestData.testNestedDir2.toString) .toSchemaRDD - data.registerAsTable("data") + data.registerTempTable("data") val result1 = nestedParserSqlContext.sql("SELECT entries[0].value FROM data").collect() assert(result1.size === 1) assert(result1(0).size === 1) @@ -589,7 +589,7 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA val data = nestedParserSqlContext .parquetFile(ParquetTestData.testNestedDir3.toString) .toSchemaRDD - data.registerAsTable("data") + data.registerTempTable("data") val result1 = nestedParserSqlContext.sql("SELECT booleanNumberPairs[0].value[0].truth FROM data").collect() assert(result1.size === 1) assert(result1(0).size === 1) @@ -608,7 +608,7 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA val data = TestSQLContext .parquetFile(ParquetTestData.testNestedDir4.toString) .toSchemaRDD - data.registerAsTable("mapTable") + data.registerTempTable("mapTable") val result1 = sql("SELECT data1 FROM mapTable").collect() assert(result1.size === 1) assert(result1(0)(0) @@ -625,7 +625,7 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA val data = nestedParserSqlContext .parquetFile(ParquetTestData.testNestedDir4.toString) .toSchemaRDD - data.registerAsTable("mapTable") + data.registerTempTable("mapTable") val result1 = nestedParserSqlContext.sql("SELECT data2 FROM mapTable").collect() assert(result1.size === 1) val entry1 = result1(0)(0) @@ -658,7 +658,7 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA nestedParserSqlContext .parquetFile(tmpdir.toString) .toSchemaRDD - .registerAsTable("tmpcopy") + .registerTempTable("tmpcopy") val tmpdata = nestedParserSqlContext.sql("SELECT owner, contacts[1].name FROM tmpcopy").collect() assert(tmpdata.size === 2) assert(tmpdata(0).size === 2) @@ -679,7 +679,7 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA nestedParserSqlContext .parquetFile(tmpdir.toString) .toSchemaRDD - .registerAsTable("tmpmapcopy") + .registerTempTable("tmpmapcopy") val result1 = nestedParserSqlContext.sql("""SELECT data1["key2"] FROM tmpmapcopy""").collect() assert(result1.size === 1) assert(result1(0)(0) === 2) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala index 833f3502154f3..7e323146f9da2 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala @@ -28,7 +28,7 @@ case class TestData(key: Int, value: String) class InsertIntoHiveTableSuite extends QueryTest { val testData = TestHive.sparkContext.parallelize( (1 to 100).map(i => TestData(i, i.toString))) - testData.registerAsTable("testData") + testData.registerTempTable("testData") test("insertInto() HiveTable") { createTable[TestData]("createAndInsertTest") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/api/java/JavaHiveQLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/api/java/JavaHiveQLSuite.scala index 10c8069a624e6..578f27574ad2f 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/api/java/JavaHiveQLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/api/java/JavaHiveQLSuite.scala @@ -63,7 +63,7 @@ class JavaHiveQLSuite extends FunSuite { javaHiveCtx.hql(s"CREATE TABLE $tableName(key INT, value STRING)").count() } - javaHiveCtx.hql("SHOW TABLES").registerAsTable("show_tables") + javaHiveCtx.hql("SHOW TABLES").registerTempTable("show_tables") assert( javaHiveCtx @@ -73,7 +73,7 @@ class JavaHiveQLSuite extends FunSuite { .contains(tableName)) assertResult(Array(Array("key", "int", "None"), Array("value", "string", "None"))) { - javaHiveCtx.hql(s"DESCRIBE $tableName").registerAsTable("describe_table") + javaHiveCtx.hql(s"DESCRIBE $tableName").registerTempTable("describe_table") javaHiveCtx .hql("SELECT result FROM describe_table") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index 89cc589fb8001..4ed41550cf530 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -247,7 +247,7 @@ class HiveQuerySuite extends HiveComparisonTest { TestHive.sparkContext.parallelize( TestData(1, "str1") :: TestData(2, "str2") :: Nil) - testData.registerAsTable("REGisteredTABle") + testData.registerTempTable("REGisteredTABle") assertResult(Array(Array(2, "str2"))) { hql("SELECT tablealias.A, TABLEALIAS.b FROM reGisteredTABle TableAlias " + @@ -272,7 +272,7 @@ class HiveQuerySuite extends HiveComparisonTest { test("SPARK-2180: HAVING support in GROUP BY clauses (positive)") { val fixture = List(("foo", 2), ("bar", 1), ("foo", 4), ("bar", 3)) .zipWithIndex.map {case Pair(Pair(value, attr), key) => HavingRow(key, value, attr)} - TestHive.sparkContext.parallelize(fixture).registerAsTable("having_test") + TestHive.sparkContext.parallelize(fixture).registerTempTable("having_test") val results = hql("SELECT value, max(attr) AS attr FROM having_test GROUP BY value HAVING attr > 3") .collect() @@ -401,7 +401,7 @@ class HiveQuerySuite extends HiveComparisonTest { TestHive.sparkContext.parallelize( TestData(1, "str1") :: TestData(1, "str2") :: Nil) - testData.registerAsTable("test_describe_commands2") + testData.registerTempTable("test_describe_commands2") assertResult( Array( diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala index fb03db12a0b01..2455c18925dfa 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala @@ -54,14 +54,14 @@ class HiveResolutionSuite extends HiveComparisonTest { test("case insensitivity with scala reflection") { // Test resolution with Scala Reflection TestHive.sparkContext.parallelize(Data(1, 2, Nested(1,2), Seq(Nested(1,2))) :: Nil) - .registerAsTable("caseSensitivityTest") + .registerTempTable("caseSensitivityTest") hql("SELECT a, b, A, B, n.a, n.b, n.A, n.B FROM caseSensitivityTest") } test("nested repeated resolution") { TestHive.sparkContext.parallelize(Data(1, 2, Nested(1,2), Seq(Nested(1,2))) :: Nil) - .registerAsTable("nestedRepeatedTest") + .registerTempTable("nestedRepeatedTest") assert(hql("SELECT nestedArray[0].a FROM nestedRepeatedTest").collect().head(0) === 1) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/parquet/HiveParquetSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/parquet/HiveParquetSuite.scala index 47526e3596e44..6545e8d7dcb69 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/parquet/HiveParquetSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/parquet/HiveParquetSuite.scala @@ -41,7 +41,7 @@ class HiveParquetSuite extends FunSuite with BeforeAndAfterAll with BeforeAndAft // write test data ParquetTestData.writeFile() testRDD = parquetFile(ParquetTestData.testDir.toString) - testRDD.registerAsTable("testsource") + testRDD.registerTempTable("testsource") } override def afterAll() { @@ -67,7 +67,7 @@ class HiveParquetSuite extends FunSuite with BeforeAndAfterAll with BeforeAndAft .map(i => Cases(i, i)) .saveAsParquetFile(tempFile.getCanonicalPath) - parquetFile(tempFile.getCanonicalPath).registerAsTable("cases") + parquetFile(tempFile.getCanonicalPath).registerTempTable("cases") hql("SELECT upper FROM cases").collect().map(_.getString(0)) === (1 to 10).map(_.toString) hql("SELECT LOWER FROM cases").collect().map(_.getString(0)) === (1 to 10).map(_.toString) } @@ -86,7 +86,7 @@ class HiveParquetSuite extends FunSuite with BeforeAndAfterAll with BeforeAndAft test("Converting Hive to Parquet Table via saveAsParquetFile") { hql("SELECT * FROM src").saveAsParquetFile(dirname.getAbsolutePath) - parquetFile(dirname.getAbsolutePath).registerAsTable("ptable") + parquetFile(dirname.getAbsolutePath).registerTempTable("ptable") val rddOne = hql("SELECT * FROM src").collect().sortBy(_.getInt(0)) val rddTwo = hql("SELECT * from ptable").collect().sortBy(_.getInt(0)) compareRDDs(rddOne, rddTwo, "src (Hive)", Seq("key:Int", "value:String")) @@ -94,7 +94,7 @@ class HiveParquetSuite extends FunSuite with BeforeAndAfterAll with BeforeAndAft test("INSERT OVERWRITE TABLE Parquet table") { hql("SELECT * FROM testsource").saveAsParquetFile(dirname.getAbsolutePath) - parquetFile(dirname.getAbsolutePath).registerAsTable("ptable") + parquetFile(dirname.getAbsolutePath).registerTempTable("ptable") // let's do three overwrites for good measure hql("INSERT OVERWRITE TABLE ptable SELECT * FROM testsource").collect() hql("INSERT OVERWRITE TABLE ptable SELECT * FROM testsource").collect() From 33f167d762483b55d5d874dcc1e3075f661d4375 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Sat, 2 Aug 2014 21:44:19 -0700 Subject: [PATCH 126/170] SPARK-2602 [BUILD] Tests steal focus under Java 6 As per https://issues.apache.org/jira/browse/SPARK-2602 , this may be resolved for Java 6 with the java.awt.headless system property, which never hurt anyone running a command line app. I tested it and seemed to get rid of focus stealing. Author: Sean Owen Closes #1747 from srowen/SPARK-2602 and squashes the following commits: b141018 [Sean Owen] Set java.awt.headless during tests --- pom.xml | 1 + 1 file changed, 1 insertion(+) diff --git a/pom.xml b/pom.xml index a42759169149b..cc9377cec2a07 100644 --- a/pom.xml +++ b/pom.xml @@ -871,6 +871,7 @@ -Xmx3g -XX:MaxPermSize=${MaxPermGen} -XX:ReservedCodeCacheSize=512m + true ${session.executionRootDirectory} 1 From 9cf429aaf529e91f619910c33cfe46bf33a66982 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Sat, 2 Aug 2014 21:55:56 -0700 Subject: [PATCH 127/170] SPARK-2414 [BUILD] Add LICENSE entry for jquery The JIRA concerned removing jquery, and this does not remove jquery. While it is distributed by Spark it should have an accompanying line in LICENSE, very technically, as per http://www.apache.org/dev/licensing-howto.html Author: Sean Owen Closes #1748 from srowen/SPARK-2414 and squashes the following commits: 2fdb03c [Sean Owen] Add LICENSE entry for jquery --- LICENSE | 1 + 1 file changed, 1 insertion(+) diff --git a/LICENSE b/LICENSE index 76a3601c66918..e9a1153fdc5db 100644 --- a/LICENSE +++ b/LICENSE @@ -549,3 +549,4 @@ The following components are provided under the MIT License. See project link fo (MIT License) pyrolite (org.spark-project:pyrolite:2.0.1 - http://pythonhosted.org/Pyro4/) (MIT License) scopt (com.github.scopt:scopt_2.10:3.2.0 - https://github.com/scopt/scopt) (The MIT License) Mockito (org.mockito:mockito-all:1.8.5 - http://www.mockito.org) + (MIT License) jquery (https://jquery.org/license/) From 3dc55fdf450b4237f7c592fce56d1467fd206366 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Sat, 2 Aug 2014 22:00:46 -0700 Subject: [PATCH 128/170] [Minor] Fixes on top of #1679 Minor fixes on top of #1679. Author: Andrew Or Closes #1736 from andrewor14/amend-#1679 and squashes the following commits: 3b46f5e [Andrew Or] Minor fixes --- .../org/apache/spark/storage/BlockManagerSource.scala | 5 ++--- .../scala/org/apache/spark/storage/StorageUtils.scala | 11 ++++------- 2 files changed, 6 insertions(+), 10 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerSource.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerSource.scala index e939318a029dd..3f14c40ec61cb 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerSource.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerSource.scala @@ -46,9 +46,8 @@ private[spark] class BlockManagerSource(val blockManager: BlockManager, sc: Spar metricRegistry.register(MetricRegistry.name("memory", "memUsed_MB"), new Gauge[Long] { override def getValue: Long = { val storageStatusList = blockManager.master.getStorageStatus - val maxMem = storageStatusList.map(_.maxMem).sum - val remainingMem = storageStatusList.map(_.memRemaining).sum - (maxMem - remainingMem) / 1024 / 1024 + val memUsed = storageStatusList.map(_.memUsed).sum + memUsed / 1024 / 1024 } }) diff --git a/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala b/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala index 0a0a448baa2ef..2bd6b749be261 100644 --- a/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala +++ b/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala @@ -172,16 +172,13 @@ class StorageStatus(val blockManagerId: BlockManagerId, val maxMem: Long) { def memRemaining: Long = maxMem - memUsed /** Return the memory used by this block manager. */ - def memUsed: Long = - _nonRddStorageInfo._1 + _rddBlocks.keys.toSeq.map(memUsedByRdd).sum + def memUsed: Long = _nonRddStorageInfo._1 + _rddBlocks.keys.toSeq.map(memUsedByRdd).sum /** Return the disk space used by this block manager. */ - def diskUsed: Long = - _nonRddStorageInfo._2 + _rddBlocks.keys.toSeq.map(diskUsedByRdd).sum + def diskUsed: Long = _nonRddStorageInfo._2 + _rddBlocks.keys.toSeq.map(diskUsedByRdd).sum /** Return the off-heap space used by this block manager. */ - def offHeapUsed: Long = - _nonRddStorageInfo._3 + _rddBlocks.keys.toSeq.map(offHeapUsedByRdd).sum + def offHeapUsed: Long = _nonRddStorageInfo._3 + _rddBlocks.keys.toSeq.map(offHeapUsedByRdd).sum /** Return the memory used by the given RDD in this block manager in O(1) time. */ def memUsedByRdd(rddId: Int): Long = _rddStorageInfo.get(rddId).map(_._1).getOrElse(0L) @@ -246,7 +243,7 @@ private[spark] object StorageUtils { val rddId = rddInfo.id // Assume all blocks belonging to the same RDD have the same storage level val storageLevel = statuses - .map(_.rddStorageLevel(rddId)).flatMap(s => s).headOption.getOrElse(StorageLevel.NONE) + .flatMap(_.rddStorageLevel(rddId)).headOption.getOrElse(StorageLevel.NONE) val numCachedPartitions = statuses.map(_.numRddBlocksById(rddId)).sum val memSize = statuses.map(_.memUsedByRdd(rddId)).sum val diskSize = statuses.map(_.diskUsedByRdd(rddId)).sum From f8cd143b6b1b4d8aac87c229e5af263b0319b3ea Mon Sep 17 00:00:00 2001 From: Stephen Boesch Date: Sun, 3 Aug 2014 10:19:04 -0700 Subject: [PATCH 129/170] SPARK-2712 - Add a small note to maven doc that mvn package must happen ... Per request by Reynold adding small note about proper sequencing of build then test. Author: Stephen Boesch Closes #1615 from javadba/docs and squashes the following commits: 6c3183e [Stephen Boesch] Moved updated testing blurb per PWendell 5764757 [Stephen Boesch] SPARK-2712 - Add a small note to maven doc that mvn package must happen before test --- docs/building-with-maven.md | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/docs/building-with-maven.md b/docs/building-with-maven.md index 55a9e37dfed83..672d0ef114f6d 100644 --- a/docs/building-with-maven.md +++ b/docs/building-with-maven.md @@ -98,7 +98,12 @@ mvn -Pyarn-alpha -Phadoop-2.3 -Dhadoop.version=2.3.0 -Dyarn.version=0.23.7 -Dski # Spark Tests in Maven -Tests are run by default via the [ScalaTest Maven plugin](http://www.scalatest.org/user_guide/using_the_scalatest_maven_plugin). Some of the require Spark to be packaged first, so always run `mvn package` with `-DskipTests` the first time. You can then run the tests with `mvn -Dhadoop.version=... test`. +Tests are run by default via the [ScalaTest Maven plugin](http://www.scalatest.org/user_guide/using_the_scalatest_maven_plugin). + +Some of the tests require Spark to be packaged first, so always run `mvn package` with `-DskipTests` the first time. The following is an example of a correct (build, test) sequence: + + mvn -Pyarn -Phadoop-2.3 -DskipTests -Phive clean package + mvn -Pyarn -Phadoop-2.3 -Phive test The ScalaTest plugin also supports running only a specific test suite as follows: From a0bcbc159e89be868ccc96175dbf1439461557e1 Mon Sep 17 00:00:00 2001 From: "Allan Douglas R. de Oliveira" Date: Sun, 3 Aug 2014 10:25:59 -0700 Subject: [PATCH 130/170] SPARK-2246: Add user-data option to EC2 scripts Author: Allan Douglas R. de Oliveira Closes #1186 from douglaz/spark_ec2_user_data and squashes the following commits: 94a36f9 [Allan Douglas R. de Oliveira] Added user data option to EC2 script --- ec2/spark_ec2.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/ec2/spark_ec2.py b/ec2/spark_ec2.py index 02cfe4ec39c7d..0c2f85a3868f4 100755 --- a/ec2/spark_ec2.py +++ b/ec2/spark_ec2.py @@ -135,6 +135,10 @@ def parse_args(): "--master-opts", type="string", default="", help="Extra options to give to master through SPARK_MASTER_OPTS variable " + "(e.g -Dspark.worker.timeout=180)") + parser.add_option( + "--user-data", type="string", default="", + help="Path to a user-data file (most AMI's interpret this as an initialization script)") + (opts, args) = parser.parse_args() if len(args) != 2: @@ -274,6 +278,12 @@ def launch_cluster(conn, opts, cluster_name): if opts.key_pair is None: print >> stderr, "ERROR: Must provide a key pair name (-k) to use on instances." sys.exit(1) + + user_data_content = None + if opts.user_data: + with open(opts.user_data) as user_data_file: + user_data_content = user_data_file.read() + print "Setting up security groups..." master_group = get_or_make_group(conn, cluster_name + "-master") slave_group = get_or_make_group(conn, cluster_name + "-slaves") @@ -347,7 +357,8 @@ def launch_cluster(conn, opts, cluster_name): key_name=opts.key_pair, security_groups=[slave_group], instance_type=opts.instance_type, - block_device_map=block_map) + block_device_map=block_map, + user_data=user_data_content) my_req_ids += [req.id for req in slave_reqs] i += 1 @@ -398,7 +409,8 @@ def launch_cluster(conn, opts, cluster_name): placement=zone, min_count=num_slaves_this_zone, max_count=num_slaves_this_zone, - block_device_map=block_map) + block_device_map=block_map, + user_data=user_data_content) slave_nodes += slave_res.instances print "Launched %d slaves in %s, regid = %s" % (num_slaves_this_zone, zone, slave_res.id) From 2998e38a942351974da36cb619e863c6f0316e7a Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Sun, 3 Aug 2014 10:36:52 -0700 Subject: [PATCH 131/170] [SPARK-2197] [mllib] Java DecisionTree bug fix and easy-of-use Bug fix: Before, when an RDD was created in Java and passed to DecisionTree.train(), the fake class tag caused problems. * Fix: DecisionTree: Used new RDD.retag() method to allow passing RDDs from Java. Other improvements to Decision Trees for easy-of-use with Java: * impurity classes: Added instance() methods to help with Java interface. * Strategy: Added Java-friendly constructor --> Note: I removed quantileCalculationStrategy from the Java-friendly constructor since (a) it is a special class and (b) there is only 1 option currently. I suspect we will redo the API before the other options are included. CC: mengxr Author: Joseph K. Bradley Closes #1740 from jkbradley/dt-java-new and squashes the following commits: 0805dc6 [Joseph K. Bradley] Changed Strategy to use JavaConverters instead of JavaConversions 519b1b7 [Joseph K. Bradley] * Organized imports in JavaDecisionTreeSuite.java * Using JavaConverters instead of JavaConversions in DecisionTreeSuite.scala f7b5ca1 [Joseph K. Bradley] Improvements to make it easier to run DecisionTree from Java. * DecisionTree: Used new RDD.retag() method to allow passing RDDs from Java. * impurity classes: Added instance() methods to help with Java interface. * Strategy: Added Java-friendly constructor ** Note: I removed quantileCalculationStrategy from the Java-friendly constructor since (a) it is a special class and (b) there is only 1 option currently. I suspect we will redo the API before the other options are included. d78ada6 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dt-java 320853f [Joseph K. Bradley] Added JavaDecisionTreeSuite, partly written 13a585e [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dt-java f1a8283 [Joseph K. Bradley] Added old JavaDecisionTreeSuite, to be updated later 225822f [Joseph K. Bradley] Bug: In DecisionTree, the method sequentialBinSearchForOrderedCategoricalFeatureInClassification() indexed bins from 0 to (math.pow(2, featureCategories.toInt - 1) - 1). This upper bound is the bound for unordered categorical features, not ordered ones. The upper bound should be the arity (i.e., max value) of the feature. --- .../spark/mllib/tree/DecisionTree.scala | 8 +- .../mllib/tree/configuration/Strategy.scala | 29 +++++ .../spark/mllib/tree/impurity/Entropy.scala | 7 ++ .../spark/mllib/tree/impurity/Gini.scala | 7 ++ .../spark/mllib/tree/impurity/Variance.scala | 7 ++ .../mllib/tree/JavaDecisionTreeSuite.java | 102 ++++++++++++++++++ .../spark/mllib/tree/DecisionTreeSuite.scala | 6 ++ 7 files changed, 162 insertions(+), 4 deletions(-) create mode 100644 mllib/src/test/java/org/apache/spark/mllib/tree/JavaDecisionTreeSuite.java diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index 382e76a9b7cba..1d03e6e3b36cf 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -48,12 +48,12 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo def train(input: RDD[LabeledPoint]): DecisionTreeModel = { // Cache input RDD for speedup during multiple passes. - input.cache() + val retaggedInput = input.retag(classOf[LabeledPoint]).cache() logDebug("algo = " + strategy.algo) // Find the splits and the corresponding bins (interval between the splits) using a sample // of the input data. - val (splits, bins) = DecisionTree.findSplitsBins(input, strategy) + val (splits, bins) = DecisionTree.findSplitsBins(retaggedInput, strategy) val numBins = bins(0).length logDebug("numBins = " + numBins) @@ -70,7 +70,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo // dummy value for top node (updated during first split calculation) val nodes = new Array[Node](maxNumNodes) // num features - val numFeatures = input.take(1)(0).features.size + val numFeatures = retaggedInput.take(1)(0).features.size // Calculate level for single group construction @@ -107,7 +107,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo logDebug("#####################################") // Find best split for all nodes at a level. - val splitsStatsForLevel = DecisionTree.findBestSplits(input, parentImpurities, + val splitsStatsForLevel = DecisionTree.findBestSplits(retaggedInput, parentImpurities, strategy, level, filters, splits, bins, maxLevelForSingleGroup) for ((nodeSplitStats, index) <- splitsStatsForLevel.view.zipWithIndex) { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala index fdad4f029aa99..4ee4bcd0bcbc7 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala @@ -17,6 +17,8 @@ package org.apache.spark.mllib.tree.configuration +import scala.collection.JavaConverters._ + import org.apache.spark.annotation.Experimental import org.apache.spark.mllib.tree.impurity.Impurity import org.apache.spark.mllib.tree.configuration.Algo._ @@ -61,4 +63,31 @@ class Strategy ( val isMulticlassWithCategoricalFeatures = isMulticlassClassification && (categoricalFeaturesInfo.size > 0) + /** + * Java-friendly constructor. + * + * @param algo classification or regression + * @param impurity criterion used for information gain calculation + * @param maxDepth Maximum depth of the tree. + * E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes. + * @param numClassesForClassification number of classes for classification. Default value is 2 + * leads to binary classification + * @param maxBins maximum number of bins used for splitting features + * @param categoricalFeaturesInfo A map storing information about the categorical variables and + * the number of discrete values they take. For example, an entry + * (n -> k) implies the feature n is categorical with k categories + * 0, 1, 2, ... , k-1. It's important to note that features are + * zero-indexed. + */ + def this( + algo: Algo, + impurity: Impurity, + maxDepth: Int, + numClassesForClassification: Int, + maxBins: Int, + categoricalFeaturesInfo: java.util.Map[java.lang.Integer, java.lang.Integer]) { + this(algo, impurity, maxDepth, numClassesForClassification, maxBins, Sort, + categoricalFeaturesInfo.asInstanceOf[java.util.Map[Int, Int]].asScala.toMap) + } + } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala index 9297c20596527..96d2471e1f88c 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala @@ -66,4 +66,11 @@ object Entropy extends Impurity { @DeveloperApi override def calculate(count: Double, sum: Double, sumSquares: Double): Double = throw new UnsupportedOperationException("Entropy.calculate") + + /** + * Get this impurity instance. + * This is useful for passing impurity parameters to a Strategy in Java. + */ + def instance = this + } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala index 2874bcf496484..d586f449048bb 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala @@ -62,4 +62,11 @@ object Gini extends Impurity { @DeveloperApi override def calculate(count: Double, sum: Double, sumSquares: Double): Double = throw new UnsupportedOperationException("Gini.calculate") + + /** + * Get this impurity instance. + * This is useful for passing impurity parameters to a Strategy in Java. + */ + def instance = this + } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala index 698a1a2a8e899..f7d99a40eb380 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala @@ -53,4 +53,11 @@ object Variance extends Impurity { val squaredLoss = sumSquares - (sum * sum) / count squaredLoss / count } + + /** + * Get this impurity instance. + * This is useful for passing impurity parameters to a Strategy in Java. + */ + def instance = this + } diff --git a/mllib/src/test/java/org/apache/spark/mllib/tree/JavaDecisionTreeSuite.java b/mllib/src/test/java/org/apache/spark/mllib/tree/JavaDecisionTreeSuite.java new file mode 100644 index 0000000000000..2c281a1ee7157 --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/mllib/tree/JavaDecisionTreeSuite.java @@ -0,0 +1,102 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.tree; + +import java.io.Serializable; +import java.util.HashMap; +import java.util.List; + +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.mllib.tree.configuration.Algo; +import org.apache.spark.mllib.tree.configuration.Strategy; +import org.apache.spark.mllib.tree.impurity.Gini; +import org.apache.spark.mllib.tree.model.DecisionTreeModel; + + +public class JavaDecisionTreeSuite implements Serializable { + private transient JavaSparkContext sc; + + @Before + public void setUp() { + sc = new JavaSparkContext("local", "JavaDecisionTreeSuite"); + } + + @After + public void tearDown() { + sc.stop(); + sc = null; + } + + int validatePrediction(List validationData, DecisionTreeModel model) { + int numCorrect = 0; + for (LabeledPoint point: validationData) { + Double prediction = model.predict(point.features()); + if (prediction == point.label()) { + numCorrect++; + } + } + return numCorrect; + } + + @Test + public void runDTUsingConstructor() { + List arr = DecisionTreeSuite.generateCategoricalDataPointsAsJavaList(); + JavaRDD rdd = sc.parallelize(arr); + HashMap categoricalFeaturesInfo = new HashMap(); + categoricalFeaturesInfo.put(1, 2); // feature 1 has 2 categories + + int maxDepth = 4; + int numClasses = 2; + int maxBins = 100; + Strategy strategy = new Strategy(Algo.Classification(), Gini.instance(), maxDepth, numClasses, + maxBins, categoricalFeaturesInfo); + + DecisionTree learner = new DecisionTree(strategy); + DecisionTreeModel model = learner.train(rdd.rdd()); + + int numCorrect = validatePrediction(arr, model); + Assert.assertTrue(numCorrect == rdd.count()); + } + + @Test + public void runDTUsingStaticMethods() { + List arr = DecisionTreeSuite.generateCategoricalDataPointsAsJavaList(); + JavaRDD rdd = sc.parallelize(arr); + HashMap categoricalFeaturesInfo = new HashMap(); + categoricalFeaturesInfo.put(1, 2); // feature 1 has 2 categories + + int maxDepth = 4; + int numClasses = 2; + int maxBins = 100; + Strategy strategy = new Strategy(Algo.Classification(), Gini.instance(), maxDepth, numClasses, + maxBins, categoricalFeaturesInfo); + + DecisionTreeModel model = DecisionTree$.MODULE$.train(rdd.rdd(), strategy); + + int numCorrect = validatePrediction(arr, model); + Assert.assertTrue(numCorrect == rdd.count()); + } + +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala index 8665a00f3b356..70ca7c8a266f2 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.mllib.tree +import scala.collection.JavaConverters._ + import org.scalatest.FunSuite import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Variance} @@ -815,6 +817,10 @@ object DecisionTreeSuite { arr } + def generateCategoricalDataPointsAsJavaList(): java.util.List[LabeledPoint] = { + generateCategoricalDataPoints().toList.asJava + } + def generateCategoricalDataPointsForMulticlass(): Array[LabeledPoint] = { val arr = new Array[LabeledPoint](3000) for (i <- 0 until 3000) { From 236dfac6769016e433b2f6517cda2d308dea74bc Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Sun, 3 Aug 2014 12:28:29 -0700 Subject: [PATCH 132/170] [SPARK-2784][SQL] Deprecate hql() method in favor of a config option, 'spark.sql.dialect' Many users have reported being confused by the distinction between the `sql` and `hql` methods. Specifically, many users think that `sql(...)` cannot be used to read hive tables. In this PR I introduce a new configuration option `spark.sql.dialect` that picks which dialect with be used for parsing. For SQLContext this must be set to `sql`. In `HiveContext` it defaults to `hiveql` but can also be set to `sql`. The `hql` and `hiveql` methods continue to act the same but are now marked as deprecated. **This is a possibly breaking change for some users unless they set the dialect manually, though this is unlikely.** For example: `hiveContex.sql("SELECT 1")` will now throw a parsing exception by default. Author: Michael Armbrust Closes #1746 from marmbrus/sqlLanguageConf and squashes the following commits: ad375cc [Michael Armbrust] Merge remote-tracking branch 'apache/master' into sqlLanguageConf 20c43f8 [Michael Armbrust] override function instead of just setting the value 7e4ae93 [Michael Armbrust] Deprecate hql() method in favor of a config option, 'spark.sql.dialect' --- .../sbt_app_hive/src/main/scala/HiveApp.scala | 8 +- docs/sql-programming-guide.md | 18 ++-- .../examples/sql/hive/HiveFromSpark.scala | 12 +-- python/pyspark/sql.py | 20 ++-- .../scala/org/apache/spark/sql/SQLConf.scala | 17 +++- .../org/apache/spark/sql/SQLContext.scala | 11 ++- .../spark/sql/api/java/JavaSQLContext.scala | 14 ++- .../hive/thriftserver/SparkSQLDriver.scala | 2 +- .../server/SparkSQLOperationManager.scala | 2 +- .../apache/spark/sql/hive/HiveContext.scala | 26 ++++-- .../sql/hive/api/java/JavaHiveContext.scala | 15 ++- .../spark/sql/hive/CachedTableSuite.scala | 14 +-- .../spark/sql/hive/StatisticsSuite.scala | 10 +- .../sql/hive/api/java/JavaHiveQLSuite.scala | 19 ++-- .../hive/execution/HiveComparisonTest.scala | 4 +- .../sql/hive/execution/HiveQuerySuite.scala | 93 ++++++++++--------- .../hive/execution/HiveResolutionSuite.scala | 6 +- .../execution/HiveTypeCoercionSuite.scala | 2 +- .../sql/hive/execution/HiveUdfSuite.scala | 10 +- .../sql/hive/execution/PruningSuite.scala | 2 +- .../spark/sql/parquet/HiveParquetSuite.scala | 27 +++--- 21 files changed, 199 insertions(+), 133 deletions(-) diff --git a/dev/audit-release/sbt_app_hive/src/main/scala/HiveApp.scala b/dev/audit-release/sbt_app_hive/src/main/scala/HiveApp.scala index a21410f3b9813..5111bc0adb772 100644 --- a/dev/audit-release/sbt_app_hive/src/main/scala/HiveApp.scala +++ b/dev/audit-release/sbt_app_hive/src/main/scala/HiveApp.scala @@ -37,10 +37,10 @@ object SparkSqlExample { val hiveContext = new HiveContext(sc) import hiveContext._ - hql("DROP TABLE IF EXISTS src") - hql("CREATE TABLE IF NOT EXISTS src (key INT, value STRING)") - hql("LOAD DATA LOCAL INPATH 'data.txt' INTO TABLE src") - val results = hql("FROM src SELECT key, value WHERE key >= 0 AND KEY < 5").collect() + sql("DROP TABLE IF EXISTS src") + sql("CREATE TABLE IF NOT EXISTS src (key INT, value STRING)") + sql("LOAD DATA LOCAL INPATH 'data.txt' INTO TABLE src") + val results = sql("FROM src SELECT key, value WHERE key >= 0 AND KEY < 5").collect() results.foreach(println) def test(f: => Boolean, failureMsg: String) = { diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 0465468084cee..cd6543945c385 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -495,11 +495,11 @@ directory. // sc is an existing SparkContext. val hiveContext = new org.apache.spark.sql.hive.HiveContext(sc) -hiveContext.hql("CREATE TABLE IF NOT EXISTS src (key INT, value STRING)") -hiveContext.hql("LOAD DATA LOCAL INPATH 'examples/src/main/resources/kv1.txt' INTO TABLE src") +hiveContext.sql("CREATE TABLE IF NOT EXISTS src (key INT, value STRING)") +hiveContext.sql("LOAD DATA LOCAL INPATH 'examples/src/main/resources/kv1.txt' INTO TABLE src") // Queries are expressed in HiveQL -hiveContext.hql("FROM src SELECT key, value").collect().foreach(println) +hiveContext.sql("FROM src SELECT key, value").collect().foreach(println) {% endhighlight %}
    @@ -515,11 +515,11 @@ expressed in HiveQL. // sc is an existing JavaSparkContext. JavaHiveContext hiveContext = new org.apache.spark.sql.hive.api.java.HiveContext(sc); -hiveContext.hql("CREATE TABLE IF NOT EXISTS src (key INT, value STRING)"); -hiveContext.hql("LOAD DATA LOCAL INPATH 'examples/src/main/resources/kv1.txt' INTO TABLE src"); +hiveContext.sql("CREATE TABLE IF NOT EXISTS src (key INT, value STRING)"); +hiveContext.sql("LOAD DATA LOCAL INPATH 'examples/src/main/resources/kv1.txt' INTO TABLE src"); // Queries are expressed in HiveQL. -Row[] results = hiveContext.hql("FROM src SELECT key, value").collect(); +Row[] results = hiveContext.sql("FROM src SELECT key, value").collect(); {% endhighlight %} @@ -537,11 +537,11 @@ expressed in HiveQL. from pyspark.sql import HiveContext hiveContext = HiveContext(sc) -hiveContext.hql("CREATE TABLE IF NOT EXISTS src (key INT, value STRING)") -hiveContext.hql("LOAD DATA LOCAL INPATH 'examples/src/main/resources/kv1.txt' INTO TABLE src") +hiveContext.sql("CREATE TABLE IF NOT EXISTS src (key INT, value STRING)") +hiveContext.sql("LOAD DATA LOCAL INPATH 'examples/src/main/resources/kv1.txt' INTO TABLE src") # Queries can be expressed in HiveQL. -results = hiveContext.hql("FROM src SELECT key, value").collect() +results = hiveContext.sql("FROM src SELECT key, value").collect() {% endhighlight %} diff --git a/examples/src/main/scala/org/apache/spark/examples/sql/hive/HiveFromSpark.scala b/examples/src/main/scala/org/apache/spark/examples/sql/hive/HiveFromSpark.scala index 12530c8490b09..3423fac0ad303 100644 --- a/examples/src/main/scala/org/apache/spark/examples/sql/hive/HiveFromSpark.scala +++ b/examples/src/main/scala/org/apache/spark/examples/sql/hive/HiveFromSpark.scala @@ -34,20 +34,20 @@ object HiveFromSpark { val hiveContext = new HiveContext(sc) import hiveContext._ - hql("CREATE TABLE IF NOT EXISTS src (key INT, value STRING)") - hql("LOAD DATA LOCAL INPATH 'src/main/resources/kv1.txt' INTO TABLE src") + sql("CREATE TABLE IF NOT EXISTS src (key INT, value STRING)") + sql("LOAD DATA LOCAL INPATH 'src/main/resources/kv1.txt' INTO TABLE src") // Queries are expressed in HiveQL println("Result of 'SELECT *': ") - hql("SELECT * FROM src").collect.foreach(println) + sql("SELECT * FROM src").collect.foreach(println) // Aggregation queries are also supported. - val count = hql("SELECT COUNT(*) FROM src").collect().head.getLong(0) + val count = sql("SELECT COUNT(*) FROM src").collect().head.getLong(0) println(s"COUNT(*): $count") // The results of SQL queries are themselves RDDs and support all normal RDD functions. The // items in the RDD are of type Row, which allows you to access each column by ordinal. - val rddFromSql = hql("SELECT key, value FROM src WHERE key < 10 ORDER BY key") + val rddFromSql = sql("SELECT key, value FROM src WHERE key < 10 ORDER BY key") println("Result of RDD.map:") val rddAsStrings = rddFromSql.map { @@ -60,6 +60,6 @@ object HiveFromSpark { // Queries can then join RDD data with data stored in Hive. println("Result of SELECT *:") - hql("SELECT * FROM records r JOIN src s ON r.key = s.key").collect().foreach(println) + sql("SELECT * FROM records r JOIN src s ON r.key = s.key").collect().foreach(println) } } diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py index 42b738e112809..1a829c6fafe03 100644 --- a/python/pyspark/sql.py +++ b/python/pyspark/sql.py @@ -1291,16 +1291,20 @@ def _get_hive_ctx(self): def hiveql(self, hqlQuery): """ - Runs a query expressed in HiveQL, returning the result as - a L{SchemaRDD}. + DEPRECATED: Use sql() """ + warnings.warn("hiveql() is deprecated as the sql function now parses using HiveQL by" + + "default. The SQL dialect for parsing can be set using 'spark.sql.dialect'", + DeprecationWarning) return SchemaRDD(self._ssql_ctx.hiveql(hqlQuery), self) def hql(self, hqlQuery): """ - Runs a query expressed in HiveQL, returning the result as - a L{SchemaRDD}. + DEPRECATED: Use sql() """ + warnings.warn("hql() is deprecated as the sql function now parses using HiveQL by" + + "default. The SQL dialect for parsing can be set using 'spark.sql.dialect'", + DeprecationWarning) return self.hiveql(hqlQuery) @@ -1313,16 +1317,16 @@ class LocalHiveContext(HiveContext): >>> import os >>> hiveCtx = LocalHiveContext(sc) >>> try: - ... supress = hiveCtx.hql("DROP TABLE src") + ... supress = hiveCtx.sql("DROP TABLE src") ... except Exception: ... pass >>> kv1 = os.path.join(os.environ["SPARK_HOME"], ... 'examples/src/main/resources/kv1.txt') - >>> supress = hiveCtx.hql( + >>> supress = hiveCtx.sql( ... "CREATE TABLE IF NOT EXISTS src (key INT, value STRING)") - >>> supress = hiveCtx.hql("LOAD DATA LOCAL INPATH '%s' INTO TABLE src" + >>> supress = hiveCtx.sql("LOAD DATA LOCAL INPATH '%s' INTO TABLE src" ... % kv1) - >>> results = hiveCtx.hql("FROM src SELECT value" + >>> results = hiveCtx.sql("FROM src SELECT value" ... ).map(lambda r: int(r.value.split('_')[1])) >>> num = results.count() >>> reduce_sum = results.reduce(lambda x, y: x + y) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala index 2d407077be303..40bfd55e95a12 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala @@ -29,6 +29,7 @@ object SQLConf { val SHUFFLE_PARTITIONS = "spark.sql.shuffle.partitions" val JOIN_BROADCAST_TABLES = "spark.sql.join.broadcastTables" val CODEGEN_ENABLED = "spark.sql.codegen" + val DIALECT = "spark.sql.dialect" object Deprecated { val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks" @@ -39,7 +40,7 @@ object SQLConf { * A trait that enables the setting and getting of mutable config parameters/hints. * * In the presence of a SQLContext, these can be set and queried by passing SET commands - * into Spark SQL's query functions (sql(), hql(), etc.). Otherwise, users of this trait can + * into Spark SQL's query functions (i.e. sql()). Otherwise, users of this trait can * modify the hints by programmatically calling the setters and getters of this trait. * * SQLConf is thread-safe (internally synchronized, so safe to be used in multiple threads). @@ -53,6 +54,20 @@ trait SQLConf { /** ************************ Spark SQL Params/Hints ******************* */ // TODO: refactor so that these hints accessors don't pollute the name space of SQLContext? + /** + * The SQL dialect that is used when parsing queries. This defaults to 'sql' which uses + * a simple SQL parser provided by Spark SQL. This is currently the only option for users of + * SQLContext. + * + * When using a HiveContext, this value defaults to 'hiveql', which uses the Hive 0.12.0 HiveQL + * parser. Users can change this to 'sql' if they want to run queries that aren't supported by + * HiveQL (e.g., SELECT 1). + * + * Note that the choice of dialect does not affect things like what tables are available or + * how query execution is performed. + */ + private[spark] def dialect: String = get(DIALECT, "sql") + /** When true tables cached using the in-memory columnar caching will be compressed. */ private[spark] def useCompression: Boolean = get(COMPRESS_CACHED, "false").toBoolean diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 567f4dca991b2..ecd5fbaa0b094 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -248,11 +248,18 @@ class SQLContext(@transient val sparkContext: SparkContext) } /** - * Executes a SQL query using Spark, returning the result as a SchemaRDD. + * Executes a SQL query using Spark, returning the result as a SchemaRDD. The dialect that is + * used for SQL parsing can be configured with 'spark.sql.dialect'. * * @group userf */ - def sql(sqlText: String): SchemaRDD = new SchemaRDD(this, parseSql(sqlText)) + def sql(sqlText: String): SchemaRDD = { + if (dialect == "sql") { + new SchemaRDD(this, parseSql(sqlText)) + } else { + sys.error(s"Unsupported SQL dialect: $dialect") + } + } /** Returns the specified table as a SchemaRDD */ def table(tableName: String): SchemaRDD = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala index dbaa16e8b0c68..150ff8a42063d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala @@ -39,10 +39,18 @@ class JavaSQLContext(val sqlContext: SQLContext) extends UDFRegistration { def this(sparkContext: JavaSparkContext) = this(new SQLContext(sparkContext.sc)) /** - * Executes a query expressed in SQL, returning the result as a JavaSchemaRDD + * Executes a SQL query using Spark, returning the result as a SchemaRDD. The dialect that is + * used for SQL parsing can be configured with 'spark.sql.dialect'. + * + * @group userf */ - def sql(sqlQuery: String): JavaSchemaRDD = - new JavaSchemaRDD(sqlContext, sqlContext.parseSql(sqlQuery)) + def sql(sqlText: String): JavaSchemaRDD = { + if (sqlContext.dialect == "sql") { + new JavaSchemaRDD(sqlContext, sqlContext.parseSql(sqlText)) + } else { + sys.error(s"Unsupported SQL dialect: $sqlContext.dialect") + } + } /** * :: Experimental :: diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala index d362d599d08ca..7463df1f47d43 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala @@ -55,7 +55,7 @@ private[hive] class SparkSQLDriver(val context: HiveContext = SparkSQLEnv.hiveCo override def run(command: String): CommandProcessorResponse = { // TODO unify the error code try { - val execution = context.executePlan(context.hql(command).logicalPlan) + val execution = context.executePlan(context.sql(command).logicalPlan) hiveResponse = execution.stringResult() tableSchema = getResultSetSchema(execution) new CommandProcessorResponse(0) diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala index d4dadfd21d13f..dee092159dd4c 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala @@ -128,7 +128,7 @@ class SparkSQLOperationManager(hiveContext: HiveContext) extends OperationManage logInfo(s"Running query '$statement'") setState(OperationState.RUNNING) try { - result = hiveContext.hql(statement) + result = hiveContext.sql(statement) logDebug(result.queryExecution.toString()) val groupId = round(random * 1000000).toString hiveContext.sparkContext.setJobGroup(groupId, statement) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index 3c70b3f0921a5..7db0159512610 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -71,15 +71,29 @@ class LocalHiveContext(sc: SparkContext) extends HiveContext(sc) { class HiveContext(sc: SparkContext) extends SQLContext(sc) { self => + // Change the default SQL dialect to HiveQL + override private[spark] def dialect: String = get(SQLConf.DIALECT, "hiveql") + override protected[sql] def executePlan(plan: LogicalPlan): this.QueryExecution = new this.QueryExecution { val logical = plan } - /** - * Executes a query expressed in HiveQL using Spark, returning the result as a SchemaRDD. - */ + override def sql(sqlText: String): SchemaRDD = { + // TODO: Create a framework for registering parsers instead of just hardcoding if statements. + if (dialect == "sql") { + super.sql(sqlText) + } else if (dialect == "hiveql") { + new SchemaRDD(this, HiveQl.parseSql(sqlText)) + } else { + sys.error(s"Unsupported SQL dialect: $dialect. Try 'sql' or 'hiveql'") + } + } + + @deprecated("hiveql() is deprecated as the sql function now parses using HiveQL by default. " + + s"The SQL dialect for parsing can be set using ${SQLConf.DIALECT}", "1.1") def hiveql(hqlQuery: String): SchemaRDD = new SchemaRDD(this, HiveQl.parseSql(hqlQuery)) - /** An alias for `hiveql`. */ + @deprecated("hql() is deprecated as the sql function now parses using HiveQL by default. " + + s"The SQL dialect for parsing can be set using ${SQLConf.DIALECT}", "1.1") def hql(hqlQuery: String): SchemaRDD = hiveql(hqlQuery) /** @@ -95,7 +109,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { // Circular buffer to hold what hive prints to STDOUT and ERR. Only printed when failures occur. @transient - protected val outputBuffer = new java.io.OutputStream { + protected lazy val outputBuffer = new java.io.OutputStream { var pos: Int = 0 var buffer = new Array[Int](10240) def write(i: Int): Unit = { @@ -125,7 +139,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { /** * SQLConf and HiveConf contracts: when the hive session is first initialized, params in * HiveConf will get picked up by the SQLConf. Additionally, any properties set by - * set() or a SET command inside hql() or sql() will be set in the SQLConf *as well as* + * set() or a SET command inside sql() will be set in the SQLConf *as well as* * in the HiveConf. */ @transient protected[hive] lazy val hiveconf = new HiveConf(classOf[SessionState]) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/api/java/JavaHiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/api/java/JavaHiveContext.scala index c9ee162191c96..a201d2349a2ef 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/api/java/JavaHiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/api/java/JavaHiveContext.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.hive.api.java import org.apache.spark.api.java.JavaSparkContext import org.apache.spark.sql.api.java.{JavaSQLContext, JavaSchemaRDD} +import org.apache.spark.sql.SQLConf import org.apache.spark.sql.hive.{HiveContext, HiveQl} /** @@ -28,9 +29,21 @@ class JavaHiveContext(sparkContext: JavaSparkContext) extends JavaSQLContext(spa override val sqlContext = new HiveContext(sparkContext) + override def sql(sqlText: String): JavaSchemaRDD = { + // TODO: Create a framework for registering parsers instead of just hardcoding if statements. + if (sqlContext.dialect == "sql") { + super.sql(sqlText) + } else if (sqlContext.dialect == "hiveql") { + new JavaSchemaRDD(sqlContext, HiveQl.parseSql(sqlText)) + } else { + sys.error(s"Unsupported SQL dialect: ${sqlContext.dialect}. Try 'sql' or 'hiveql'") + } + } + /** - * Executes a query expressed in HiveQL, returning the result as a JavaSchemaRDD. + * DEPRECATED: Use sql(...) Instead */ + @Deprecated def hql(hqlQuery: String): JavaSchemaRDD = new JavaSchemaRDD(sqlContext, HiveQl.parseSql(hqlQuery)) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala index 08da6405a17c6..188579edd7bdd 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala @@ -35,17 +35,17 @@ class CachedTableSuite extends HiveComparisonTest { "SELECT * FROM src LIMIT 1", reset = false) test("Drop cached table") { - hql("CREATE TABLE test(a INT)") + sql("CREATE TABLE test(a INT)") cacheTable("test") - hql("SELECT * FROM test").collect() - hql("DROP TABLE test") + sql("SELECT * FROM test").collect() + sql("DROP TABLE test") intercept[org.apache.hadoop.hive.ql.metadata.InvalidTableException] { - hql("SELECT * FROM test").collect() + sql("SELECT * FROM test").collect() } } test("DROP nonexistant table") { - hql("DROP TABLE IF EXISTS nonexistantTable") + sql("DROP TABLE IF EXISTS nonexistantTable") } test("check that table is cached and uncache") { @@ -74,14 +74,14 @@ class CachedTableSuite extends HiveComparisonTest { } test("'CACHE TABLE' and 'UNCACHE TABLE' HiveQL statement") { - TestHive.hql("CACHE TABLE src") + TestHive.sql("CACHE TABLE src") TestHive.table("src").queryExecution.executedPlan match { case _: InMemoryColumnarTableScan => // Found evidence of caching case _ => fail(s"Table 'src' should be cached") } assert(TestHive.isCached("src"), "Table 'src' should be cached") - TestHive.hql("UNCACHE TABLE src") + TestHive.sql("UNCACHE TABLE src") TestHive.table("src").queryExecution.executedPlan match { case _: InMemoryColumnarTableScan => fail(s"Table 'src' should not be cached") case _ => // Found evidence of uncaching diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala index a61fd9df95c94..d8c77d6021d63 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala @@ -27,7 +27,7 @@ import org.apache.spark.sql.hive.test.TestHive._ class StatisticsSuite extends QueryTest { test("estimates the size of a test MetastoreRelation") { - val rdd = hql("""SELECT * FROM src""") + val rdd = sql("""SELECT * FROM src""") val sizes = rdd.queryExecution.analyzed.collect { case mr: MetastoreRelation => mr.statistics.sizeInBytes } @@ -45,7 +45,7 @@ class StatisticsSuite extends QueryTest { ct: ClassTag[_]) = { before() - var rdd = hql(query) + var rdd = sql(query) // Assert src has a size smaller than the threshold. val sizes = rdd.queryExecution.analyzed.collect { @@ -65,8 +65,8 @@ class StatisticsSuite extends QueryTest { TestHive.settings.synchronized { val tmp = autoBroadcastJoinThreshold - hql(s"""SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD}=-1""") - rdd = hql(query) + sql(s"""SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD}=-1""") + rdd = sql(query) bhj = rdd.queryExecution.sparkPlan.collect { case j: BroadcastHashJoin => j } assert(bhj.isEmpty, "BroadcastHashJoin still planned even though it is switched off") @@ -74,7 +74,7 @@ class StatisticsSuite extends QueryTest { assert(shj.size === 1, "ShuffledHashJoin should be planned when BroadcastHashJoin is turned off") - hql(s"""SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD}=$tmp""") + sql(s"""SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD}=$tmp""") } after() diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/api/java/JavaHiveQLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/api/java/JavaHiveQLSuite.scala index 578f27574ad2f..9644b707eb1a0 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/api/java/JavaHiveQLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/api/java/JavaHiveQLSuite.scala @@ -40,7 +40,7 @@ class JavaHiveQLSuite extends FunSuite { ignore("SELECT * FROM src") { assert( - javaHiveCtx.hql("SELECT * FROM src").collect().map(_.getInt(0)) === + javaHiveCtx.sql("SELECT * FROM src").collect().map(_.getInt(0)) === TestHive.sql("SELECT * FROM src").collect().map(_.getInt(0)).toSeq) } @@ -56,33 +56,34 @@ class JavaHiveQLSuite extends FunSuite { val tableName = "test_native_commands" assertResult(0) { - javaHiveCtx.hql(s"DROP TABLE IF EXISTS $tableName").count() + javaHiveCtx.sql(s"DROP TABLE IF EXISTS $tableName").count() } assertResult(0) { - javaHiveCtx.hql(s"CREATE TABLE $tableName(key INT, value STRING)").count() + javaHiveCtx.sql(s"CREATE TABLE $tableName(key INT, value STRING)").count() } - javaHiveCtx.hql("SHOW TABLES").registerTempTable("show_tables") + javaHiveCtx.sql("SHOW TABLES").registerTempTable("show_tables") assert( javaHiveCtx - .hql("SELECT result FROM show_tables") + .sql("SELECT result FROM show_tables") .collect() .map(_.getString(0)) .contains(tableName)) assertResult(Array(Array("key", "int", "None"), Array("value", "string", "None"))) { - javaHiveCtx.hql(s"DESCRIBE $tableName").registerTempTable("describe_table") + javaHiveCtx.sql(s"DESCRIBE $tableName").registerTempTable("describe_table") + javaHiveCtx - .hql("SELECT result FROM describe_table") + .sql("SELECT result FROM describe_table") .collect() .map(_.getString(0).split("\t").map(_.trim)) .toArray } - assert(isExplanation(javaHiveCtx.hql( + assert(isExplanation(javaHiveCtx.sql( s"EXPLAIN SELECT key, COUNT(*) FROM $tableName GROUP BY key"))) TestHive.reset() @@ -90,7 +91,7 @@ class JavaHiveQLSuite extends FunSuite { ignore("Exactly once semantics for DDL and command statements") { val tableName = "test_exactly_once" - val q0 = javaHiveCtx.hql(s"CREATE TABLE $tableName(key INT, value STRING)") + val q0 = javaHiveCtx.sql(s"CREATE TABLE $tableName(key INT, value STRING)") // If the table was not created, the following assertion would fail assert(Try(TestHive.table(tableName)).isSuccess) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala index 83cfbc6b4a002..0ebaf6ffd5458 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala @@ -241,13 +241,13 @@ abstract class HiveComparisonTest val quotes = "\"\"\"" queryList.zipWithIndex.map { case (query, i) => - s"""val q$i = hql($quotes$query$quotes); q$i.collect()""" + s"""val q$i = sql($quotes$query$quotes); q$i.collect()""" }.mkString("\n== Console version of this test ==\n", "\n", "\n") } try { // MINOR HACK: You must run a query before calling reset the first time. - TestHive.hql("SHOW TABLES") + TestHive.sql("SHOW TABLES") if (reset) { TestHive.reset() } val hiveCacheFiles = queryList.zipWithIndex.map { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index 4ed41550cf530..aa810a291231a 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -57,8 +57,8 @@ class HiveQuerySuite extends HiveComparisonTest { """.stripMargin) test("CREATE TABLE AS runs once") { - hql("CREATE TABLE foo AS SELECT 1 FROM src LIMIT 1").collect() - assert(hql("SELECT COUNT(*) FROM foo").collect().head.getLong(0) === 1, + sql("CREATE TABLE foo AS SELECT 1 FROM src LIMIT 1").collect() + assert(sql("SELECT COUNT(*) FROM foo").collect().head.getLong(0) === 1, "Incorrect number of rows in created table") } @@ -72,12 +72,14 @@ class HiveQuerySuite extends HiveComparisonTest { "SELECT 2 / 1, 1 / 2, 1 / 3, 1 / COUNT(*) FROM src LIMIT 1") test("Query expressed in SQL") { + set("spark.sql.dialect", "sql") assert(sql("SELECT 1").collect() === Array(Seq(1))) + set("spark.sql.dialect", "hiveql") + } test("Query expressed in HiveQL") { - hql("FROM src SELECT key").collect() - hiveql("FROM src SELECT key").collect() + sql("FROM src SELECT key").collect() } createQueryTest("Constant Folding Optimization for AVG_SUM_COUNT", @@ -193,12 +195,12 @@ class HiveQuerySuite extends HiveComparisonTest { "SELECT * FROM src LATERAL VIEW explode(map(key+3,key+4)) D as k, v") test("sampling") { - hql("SELECT * FROM src TABLESAMPLE(0.1 PERCENT) s") + sql("SELECT * FROM src TABLESAMPLE(0.1 PERCENT) s") } test("SchemaRDD toString") { - hql("SHOW TABLES").toString - hql("SELECT * FROM src").toString + sql("SHOW TABLES").toString + sql("SELECT * FROM src").toString } createQueryTest("case statements with key #1", @@ -226,8 +228,8 @@ class HiveQuerySuite extends HiveComparisonTest { "SELECT (CASE WHEN key > 2 THEN 3 WHEN 2 > key THEN 2 ELSE 0 END) FROM src WHERE key < 15") test("implement identity function using case statement") { - val actual = hql("SELECT (CASE key WHEN key THEN key END) FROM src").collect().toSet - val expected = hql("SELECT key FROM src").collect().toSet + val actual = sql("SELECT (CASE key WHEN key THEN key END) FROM src").collect().toSet + val expected = sql("SELECT key FROM src").collect().toSet assert(actual === expected) } @@ -235,7 +237,7 @@ class HiveQuerySuite extends HiveComparisonTest { // See https://github.com/apache/spark/pull/1055#issuecomment-45820167 for a discussion. ignore("non-boolean conditions in a CaseWhen are illegal") { intercept[Exception] { - hql("SELECT (CASE WHEN key > 2 THEN 3 WHEN 1 THEN 2 ELSE 0 END) FROM src").collect() + sql("SELECT (CASE WHEN key > 2 THEN 3 WHEN 1 THEN 2 ELSE 0 END) FROM src").collect() } } @@ -250,7 +252,7 @@ class HiveQuerySuite extends HiveComparisonTest { testData.registerTempTable("REGisteredTABle") assertResult(Array(Array(2, "str2"))) { - hql("SELECT tablealias.A, TABLEALIAS.b FROM reGisteredTABle TableAlias " + + sql("SELECT tablealias.A, TABLEALIAS.b FROM reGisteredTABle TableAlias " + "WHERE TableAliaS.a > 1").collect() } } @@ -261,9 +263,9 @@ class HiveQuerySuite extends HiveComparisonTest { } test("SPARK-1704: Explain commands as a SchemaRDD") { - hql("CREATE TABLE IF NOT EXISTS src (key INT, value STRING)") + sql("CREATE TABLE IF NOT EXISTS src (key INT, value STRING)") - val rdd = hql("explain select key, count(value) from src group by key") + val rdd = sql("explain select key, count(value) from src group by key") assert(isExplanation(rdd)) TestHive.reset() @@ -274,7 +276,7 @@ class HiveQuerySuite extends HiveComparisonTest { .zipWithIndex.map {case Pair(Pair(value, attr), key) => HavingRow(key, value, attr)} TestHive.sparkContext.parallelize(fixture).registerTempTable("having_test") val results = - hql("SELECT value, max(attr) AS attr FROM having_test GROUP BY value HAVING attr > 3") + sql("SELECT value, max(attr) AS attr FROM having_test GROUP BY value HAVING attr > 3") .collect() .map(x => Pair(x.getString(0), x.getInt(1))) @@ -283,39 +285,39 @@ class HiveQuerySuite extends HiveComparisonTest { } test("SPARK-2180: HAVING with non-boolean clause raises no exceptions") { - hql("select key, count(*) c from src group by key having c").collect() + sql("select key, count(*) c from src group by key having c").collect() } test("SPARK-2225: turn HAVING without GROUP BY into a simple filter") { - assert(hql("select key from src having key > 490").collect().size < 100) + assert(sql("select key from src having key > 490").collect().size < 100) } test("Query Hive native command execution result") { val tableName = "test_native_commands" assertResult(0) { - hql(s"DROP TABLE IF EXISTS $tableName").count() + sql(s"DROP TABLE IF EXISTS $tableName").count() } assertResult(0) { - hql(s"CREATE TABLE $tableName(key INT, value STRING)").count() + sql(s"CREATE TABLE $tableName(key INT, value STRING)").count() } assert( - hql("SHOW TABLES") + sql("SHOW TABLES") .select('result) .collect() .map(_.getString(0)) .contains(tableName)) - assert(isExplanation(hql(s"EXPLAIN SELECT key, COUNT(*) FROM $tableName GROUP BY key"))) + assert(isExplanation(sql(s"EXPLAIN SELECT key, COUNT(*) FROM $tableName GROUP BY key"))) TestHive.reset() } test("Exactly once semantics for DDL and command statements") { val tableName = "test_exactly_once" - val q0 = hql(s"CREATE TABLE $tableName(key INT, value STRING)") + val q0 = sql(s"CREATE TABLE $tableName(key INT, value STRING)") // If the table was not created, the following assertion would fail assert(Try(table(tableName)).isSuccess) @@ -325,9 +327,9 @@ class HiveQuerySuite extends HiveComparisonTest { } test("DESCRIBE commands") { - hql(s"CREATE TABLE test_describe_commands1 (key INT, value STRING) PARTITIONED BY (dt STRING)") + sql(s"CREATE TABLE test_describe_commands1 (key INT, value STRING) PARTITIONED BY (dt STRING)") - hql( + sql( """FROM src INSERT OVERWRITE TABLE test_describe_commands1 PARTITION (dt='2008-06-08') |SELECT key, value """.stripMargin) @@ -342,7 +344,7 @@ class HiveQuerySuite extends HiveComparisonTest { Array("# col_name", "data_type", "comment"), Array("dt", "string", null)) ) { - hql("DESCRIBE test_describe_commands1") + sql("DESCRIBE test_describe_commands1") .select('col_name, 'data_type, 'comment) .collect() } @@ -357,14 +359,14 @@ class HiveQuerySuite extends HiveComparisonTest { Array("# col_name", "data_type", "comment"), Array("dt", "string", null)) ) { - hql("DESCRIBE default.test_describe_commands1") + sql("DESCRIBE default.test_describe_commands1") .select('col_name, 'data_type, 'comment) .collect() } // Describe a column is a native command assertResult(Array(Array("value", "string", "from deserializer"))) { - hql("DESCRIBE test_describe_commands1 value") + sql("DESCRIBE test_describe_commands1 value") .select('result) .collect() .map(_.getString(0).split("\t").map(_.trim)) @@ -372,7 +374,7 @@ class HiveQuerySuite extends HiveComparisonTest { // Describe a column is a native command assertResult(Array(Array("value", "string", "from deserializer"))) { - hql("DESCRIBE default.test_describe_commands1 value") + sql("DESCRIBE default.test_describe_commands1 value") .select('result) .collect() .map(_.getString(0).split("\t").map(_.trim)) @@ -390,7 +392,7 @@ class HiveQuerySuite extends HiveComparisonTest { Array("", "", ""), Array("dt", "string", "None")) ) { - hql("DESCRIBE test_describe_commands1 PARTITION (dt='2008-06-08')") + sql("DESCRIBE test_describe_commands1 PARTITION (dt='2008-06-08')") .select('result) .collect() .map(_.getString(0).split("\t").map(_.trim)) @@ -409,16 +411,16 @@ class HiveQuerySuite extends HiveComparisonTest { Array("a", "IntegerType", null), Array("b", "StringType", null)) ) { - hql("DESCRIBE test_describe_commands2") + sql("DESCRIBE test_describe_commands2") .select('col_name, 'data_type, 'comment) .collect() } } test("SPARK-2263: Insert Map values") { - hql("CREATE TABLE m(value MAP)") - hql("INSERT OVERWRITE TABLE m SELECT MAP(key, value) FROM src LIMIT 10") - hql("SELECT * FROM m").collect().zip(hql("SELECT * FROM src LIMIT 10").collect()).map { + sql("CREATE TABLE m(value MAP)") + sql("INSERT OVERWRITE TABLE m SELECT MAP(key, value) FROM src LIMIT 10") + sql("SELECT * FROM m").collect().zip(sql("SELECT * FROM src LIMIT 10").collect()).map { case (Row(map: Map[_, _]), Row(key: Int, value: String)) => assert(map.size === 1) assert(map.head === (key, value)) @@ -430,18 +432,18 @@ class HiveQuerySuite extends HiveComparisonTest { val testKey = "spark.sql.key.usedfortestonly" val testVal = "val0,val_1,val2.3,my_table" - hql(s"set $testKey=$testVal") + sql(s"set $testKey=$testVal") assert(get(testKey, testVal + "_") == testVal) - hql("set some.property=20") + sql("set some.property=20") assert(get("some.property", "0") == "20") - hql("set some.property = 40") + sql("set some.property = 40") assert(get("some.property", "0") == "40") - hql(s"set $testKey=$testVal") + sql(s"set $testKey=$testVal") assert(get(testKey, "0") == testVal) - hql(s"set $testKey=") + sql(s"set $testKey=") assert(get(testKey, "0") == "") } @@ -454,33 +456,34 @@ class HiveQuerySuite extends HiveComparisonTest { clear() // "set" itself returns all config variables currently specified in SQLConf. - assert(hql("SET").collect().size == 0) + // TODO: Should we be listing the default here always? probably... + assert(sql("SET").collect().size == 0) assertResult(Array(s"$testKey=$testVal")) { - hql(s"SET $testKey=$testVal").collect().map(_.getString(0)) + sql(s"SET $testKey=$testVal").collect().map(_.getString(0)) } assert(hiveconf.get(testKey, "") == testVal) assertResult(Array(s"$testKey=$testVal")) { - hql(s"SET $testKey=$testVal").collect().map(_.getString(0)) + sql(s"SET $testKey=$testVal").collect().map(_.getString(0)) } - hql(s"SET ${testKey + testKey}=${testVal + testVal}") + sql(s"SET ${testKey + testKey}=${testVal + testVal}") assert(hiveconf.get(testKey + testKey, "") == testVal + testVal) assertResult(Array(s"$testKey=$testVal", s"${testKey + testKey}=${testVal + testVal}")) { - hql(s"SET").collect().map(_.getString(0)) + sql(s"SET").collect().map(_.getString(0)) } // "set key" assertResult(Array(s"$testKey=$testVal")) { - hql(s"SET $testKey").collect().map(_.getString(0)) + sql(s"SET $testKey").collect().map(_.getString(0)) } assertResult(Array(s"$nonexistentKey=")) { - hql(s"SET $nonexistentKey").collect().map(_.getString(0)) + sql(s"SET $nonexistentKey").collect().map(_.getString(0)) } - // Assert that sql() should have the same effects as hql() by repeating the above using sql(). + // Assert that sql() should have the same effects as sql() by repeating the above using sql(). clear() assert(sql("SET").collect().size == 0) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala index 2455c18925dfa..6b3ffd1c0ffe2 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala @@ -56,13 +56,13 @@ class HiveResolutionSuite extends HiveComparisonTest { TestHive.sparkContext.parallelize(Data(1, 2, Nested(1,2), Seq(Nested(1,2))) :: Nil) .registerTempTable("caseSensitivityTest") - hql("SELECT a, b, A, B, n.a, n.b, n.A, n.B FROM caseSensitivityTest") + sql("SELECT a, b, A, B, n.a, n.b, n.A, n.B FROM caseSensitivityTest") } test("nested repeated resolution") { TestHive.sparkContext.parallelize(Data(1, 2, Nested(1,2), Seq(Nested(1,2))) :: Nil) - .registerTempTable("nestedRepeatedTest") - assert(hql("SELECT nestedArray[0].a FROM nestedRepeatedTest").collect().head(0) === 1) + .registerTempTable("nestedRepeatedTest") + assert(sql("SELECT nestedArray[0].a FROM nestedRepeatedTest").collect().head(0) === 1) } /** diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTypeCoercionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTypeCoercionSuite.scala index 7436de264a1e1..c3c18cf8ccac3 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTypeCoercionSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTypeCoercionSuite.scala @@ -35,7 +35,7 @@ class HiveTypeCoercionSuite extends HiveComparisonTest { test("[SPARK-2210] boolean cast on boolean value should be removed") { val q = "select cast(cast(key=0 as boolean) as boolean) from src" - val project = TestHive.hql(q).queryExecution.executedPlan.collect { case e: Project => e }.head + val project = TestHive.sql(q).queryExecution.executedPlan.collect { case e: Project => e }.head // No cast expression introduced project.transformAllExpressions { case c: Cast => diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala index f944d010660eb..b6b8592344ef5 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala @@ -37,7 +37,7 @@ import org.apache.hadoop.hive.ql.udf.generic.GenericUDF.DeferredObject */ class HiveUdfSuite extends HiveComparisonTest { - TestHive.hql( + TestHive.sql( """ |CREATE EXTERNAL TABLE hiveUdfTestTable ( | pair STRUCT @@ -48,16 +48,16 @@ class HiveUdfSuite extends HiveComparisonTest { """.stripMargin.format(classOf[PairSerDe].getName) ) - TestHive.hql( + TestHive.sql( "ALTER TABLE hiveUdfTestTable ADD IF NOT EXISTS PARTITION(partition='testUdf') LOCATION '%s'" .format(this.getClass.getClassLoader.getResource("data/files/testUdf").getFile) ) - TestHive.hql("CREATE TEMPORARY FUNCTION testUdf AS '%s'".format(classOf[PairUdf].getName)) + TestHive.sql("CREATE TEMPORARY FUNCTION testUdf AS '%s'".format(classOf[PairUdf].getName)) - TestHive.hql("SELECT testUdf(pair) FROM hiveUdfTestTable") + TestHive.sql("SELECT testUdf(pair) FROM hiveUdfTestTable") - TestHive.hql("DROP TEMPORARY FUNCTION IF EXISTS testUdf") + TestHive.sql("DROP TEMPORARY FUNCTION IF EXISTS testUdf") } class TestPair(x: Int, y: Int) extends Writable with Serializable { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala index 34d8a061ccc83..1a6dbc0ce0c0d 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala @@ -27,7 +27,7 @@ import scala.collection.JavaConversions._ */ class PruningSuite extends HiveComparisonTest { // MINOR HACK: You must run a query before calling reset the first time. - TestHive.hql("SHOW TABLES") + TestHive.sql("SHOW TABLES") // Column/partition pruning is not implemented for `InMemoryColumnarTableScan` yet, need to reset // the environment to ensure all referenced tables in this suites are not cached in-memory. diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/parquet/HiveParquetSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/parquet/HiveParquetSuite.scala index 6545e8d7dcb69..6f57fe8958387 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/parquet/HiveParquetSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/parquet/HiveParquetSuite.scala @@ -68,39 +68,40 @@ class HiveParquetSuite extends FunSuite with BeforeAndAfterAll with BeforeAndAft .saveAsParquetFile(tempFile.getCanonicalPath) parquetFile(tempFile.getCanonicalPath).registerTempTable("cases") - hql("SELECT upper FROM cases").collect().map(_.getString(0)) === (1 to 10).map(_.toString) - hql("SELECT LOWER FROM cases").collect().map(_.getString(0)) === (1 to 10).map(_.toString) + sql("SELECT upper FROM cases").collect().map(_.getString(0)) === (1 to 10).map(_.toString) + sql("SELECT LOWER FROM cases").collect().map(_.getString(0)) === (1 to 10).map(_.toString) } test("SELECT on Parquet table") { - val rdd = hql("SELECT * FROM testsource").collect() + val rdd = sql("SELECT * FROM testsource").collect() assert(rdd != null) assert(rdd.forall(_.size == 6)) } test("Simple column projection + filter on Parquet table") { - val rdd = hql("SELECT myboolean, mylong FROM testsource WHERE myboolean=true").collect() + val rdd = sql("SELECT myboolean, mylong FROM testsource WHERE myboolean=true").collect() assert(rdd.size === 5, "Filter returned incorrect number of rows") assert(rdd.forall(_.getBoolean(0)), "Filter returned incorrect Boolean field value") } test("Converting Hive to Parquet Table via saveAsParquetFile") { - hql("SELECT * FROM src").saveAsParquetFile(dirname.getAbsolutePath) + sql("SELECT * FROM src").saveAsParquetFile(dirname.getAbsolutePath) parquetFile(dirname.getAbsolutePath).registerTempTable("ptable") - val rddOne = hql("SELECT * FROM src").collect().sortBy(_.getInt(0)) - val rddTwo = hql("SELECT * from ptable").collect().sortBy(_.getInt(0)) + val rddOne = sql("SELECT * FROM src").collect().sortBy(_.getInt(0)) + val rddTwo = sql("SELECT * from ptable").collect().sortBy(_.getInt(0)) + compareRDDs(rddOne, rddTwo, "src (Hive)", Seq("key:Int", "value:String")) } test("INSERT OVERWRITE TABLE Parquet table") { - hql("SELECT * FROM testsource").saveAsParquetFile(dirname.getAbsolutePath) + sql("SELECT * FROM testsource").saveAsParquetFile(dirname.getAbsolutePath) parquetFile(dirname.getAbsolutePath).registerTempTable("ptable") // let's do three overwrites for good measure - hql("INSERT OVERWRITE TABLE ptable SELECT * FROM testsource").collect() - hql("INSERT OVERWRITE TABLE ptable SELECT * FROM testsource").collect() - hql("INSERT OVERWRITE TABLE ptable SELECT * FROM testsource").collect() - val rddCopy = hql("SELECT * FROM ptable").collect() - val rddOrig = hql("SELECT * FROM testsource").collect() + sql("INSERT OVERWRITE TABLE ptable SELECT * FROM testsource").collect() + sql("INSERT OVERWRITE TABLE ptable SELECT * FROM testsource").collect() + sql("INSERT OVERWRITE TABLE ptable SELECT * FROM testsource").collect() + val rddCopy = sql("SELECT * FROM ptable").collect() + val rddOrig = sql("SELECT * FROM testsource").collect() assert(rddCopy.size === rddOrig.size, "INSERT OVERWRITE changed size of table??") compareRDDs(rddOrig, rddCopy, "testsource", ParquetTestData.testSchemaFieldNames) } From ac33cbbf33bd1ab29bc8165c9be02fb8934b1fdf Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Sun, 3 Aug 2014 12:34:46 -0700 Subject: [PATCH 133/170] [SPARK-2814][SQL] HiveThriftServer2 throws NPE when executing native commands JIRA issue: [SPARK-2814](https://issues.apache.org/jira/browse/SPARK-2814) Author: Cheng Lian Closes #1753 from liancheng/spark-2814 and squashes the following commits: c74a3b2 [Cheng Lian] Fixed SPARK-2814 --- .../main/scala/org/apache/spark/sql/hive/HiveContext.scala | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index 7db0159512610..acad681f68b14 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -146,13 +146,12 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { @transient protected[hive] lazy val sessionState = { val ss = new SessionState(hiveconf) set(hiveconf.getAllProperties) // Have SQLConf pick up the initial set of HiveConf. - - ss.err = new PrintStream(outputBuffer, true, "UTF-8") - ss.out = new PrintStream(outputBuffer, true, "UTF-8") - ss } + sessionState.err = new PrintStream(outputBuffer, true, "UTF-8") + sessionState.out = new PrintStream(outputBuffer, true, "UTF-8") + override def set(key: String, value: String): Unit = { super.set(key, value) runSqlHive(s"SET $key=$value") From e139e2be60ef23281327744e1b3e74904dfdf63f Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Sun, 3 Aug 2014 14:54:41 -0700 Subject: [PATCH 134/170] [SPARK-2783][SQL] Basic support for analyze in HiveContext JIRA: https://issues.apache.org/jira/browse/SPARK-2783 Author: Yin Huai Closes #1741 from yhuai/analyzeTable and squashes the following commits: 7bb5f02 [Yin Huai] Use sql instead of hql. 4d09325 [Yin Huai] Merge remote-tracking branch 'upstream/master' into analyzeTable e3ebcd4 [Yin Huai] Renaming. c170f4e [Yin Huai] Do not use getContentSummary. 62393b6 [Yin Huai] Merge remote-tracking branch 'upstream/master' into analyzeTable db233a6 [Yin Huai] Trying to debug jenkins... fee84f0 [Yin Huai] Merge remote-tracking branch 'upstream/master' into analyzeTable f0501f3 [Yin Huai] Fix compilation error. 24ad391 [Yin Huai] Merge remote-tracking branch 'upstream/master' into analyzeTable 8918140 [Yin Huai] Wording. 23df227 [Yin Huai] Add a simple analyze method to get the size of a table and update the "totalSize" property of this table in the Hive metastore. --- .../apache/spark/sql/hive/HiveContext.scala | 79 +++++++++++++++++++ .../spark/sql/hive/HiveMetastoreCatalog.scala | 5 +- .../spark/sql/hive/StatisticsSuite.scala | 54 +++++++++++++ 3 files changed, 136 insertions(+), 2 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index acad681f68b14..d8e7a5943daa5 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -25,10 +25,14 @@ import scala.collection.JavaConversions._ import scala.language.implicitConversions import scala.reflect.runtime.universe.{TypeTag, typeTag} +import org.apache.hadoop.fs.FileSystem +import org.apache.hadoop.fs.Path import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.ql.Driver +import org.apache.hadoop.hive.ql.metadata.Table import org.apache.hadoop.hive.ql.processors._ import org.apache.hadoop.hive.ql.session.SessionState +import org.apache.hadoop.hive.ql.stats.StatsSetupConst import org.apache.hadoop.hive.serde2.io.TimestampWritable import org.apache.spark.SparkContext @@ -107,6 +111,81 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { catalog.createTable("default", tableName, ScalaReflection.attributesFor[A], allowExisting) } + /** + * Analyzes the given table in the current database to generate statistics, which will be + * used in query optimizations. + * + * Right now, it only supports Hive tables and it only updates the size of a Hive table + * in the Hive metastore. + */ + def analyze(tableName: String) { + val relation = catalog.lookupRelation(None, tableName) match { + case LowerCaseSchema(r) => r + case o => o + } + + relation match { + case relation: MetastoreRelation => { + // This method is mainly based on + // org.apache.hadoop.hive.ql.stats.StatsUtils.getFileSizeForTable(HiveConf, Table) + // in Hive 0.13 (except that we do not use fs.getContentSummary). + // TODO: Generalize statistics collection. + // TODO: Why fs.getContentSummary returns wrong size on Jenkins? + // Can we use fs.getContentSummary in future? + // Seems fs.getContentSummary returns wrong table size on Jenkins. So we use + // countFileSize to count the table size. + def calculateTableSize(fs: FileSystem, path: Path): Long = { + val fileStatus = fs.getFileStatus(path) + val size = if (fileStatus.isDir) { + fs.listStatus(path).map(status => calculateTableSize(fs, status.getPath)).sum + } else { + fileStatus.getLen + } + + size + } + + def getFileSizeForTable(conf: HiveConf, table: Table): Long = { + val path = table.getPath() + var size: Long = 0L + try { + val fs = path.getFileSystem(conf) + size = calculateTableSize(fs, path) + } catch { + case e: Exception => + logWarning( + s"Failed to get the size of table ${table.getTableName} in the " + + s"database ${table.getDbName} because of ${e.toString}", e) + size = 0L + } + + size + } + + val tableParameters = relation.hiveQlTable.getParameters + val oldTotalSize = + Option(tableParameters.get(StatsSetupConst.TOTAL_SIZE)).map(_.toLong).getOrElse(0L) + val newTotalSize = getFileSizeForTable(hiveconf, relation.hiveQlTable) + // Update the Hive metastore if the total size of the table is different than the size + // recorded in the Hive metastore. + // This logic is based on org.apache.hadoop.hive.ql.exec.StatsTask.aggregateStats(). + if (newTotalSize > 0 && newTotalSize != oldTotalSize) { + tableParameters.put(StatsSetupConst.TOTAL_SIZE, newTotalSize.toString) + val hiveTTable = relation.hiveQlTable.getTTable + hiveTTable.setParameters(tableParameters) + val tableFullName = + relation.hiveQlTable.getDbName() + "." + relation.hiveQlTable.getTableName() + + catalog.client.alterTable(tableFullName, new Table(hiveTTable)) + } + } + case otherRelation => + throw new NotImplementedError( + s"Analyze has only implemented for Hive tables, " + + s"but ${tableName} is a ${otherRelation.nodeName}") + } + } + // Circular buffer to hold what hive prints to STDOUT and ERR. Only printed when failures occur. @transient protected lazy val outputBuffer = new java.io.OutputStream { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index df3604439e483..301cf51c00e2b 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -25,6 +25,7 @@ import org.apache.hadoop.hive.metastore.api.{FieldSchema, StorageDescriptor, Ser import org.apache.hadoop.hive.metastore.api.{Table => TTable, Partition => TPartition} import org.apache.hadoop.hive.ql.metadata.{Hive, Partition, Table} import org.apache.hadoop.hive.ql.plan.TableDesc +import org.apache.hadoop.hive.ql.stats.StatsSetupConst import org.apache.hadoop.hive.serde2.Deserializer import org.apache.spark.annotation.DeveloperApi @@ -278,9 +279,9 @@ private[hive] case class MetastoreRelation // relatively cheap if parameters for the table are populated into the metastore. An // alternative would be going through Hadoop's FileSystem API, which can be expensive if a lot // of RPCs are involved. Besides `totalSize`, there are also `numFiles`, `numRows`, - // `rawDataSize` keys that we can look at in the future. + // `rawDataSize` keys (see StatsSetupConst in Hive) that we can look at in the future. BigInt( - Option(hiveQlTable.getParameters.get("totalSize")) + Option(hiveQlTable.getParameters.get(StatsSetupConst.TOTAL_SIZE)) .map(_.toLong) .getOrElse(sqlContext.defaultSizeInBytes)) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala index d8c77d6021d63..bf5931bbf97ee 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala @@ -26,6 +26,60 @@ import org.apache.spark.sql.hive.test.TestHive._ class StatisticsSuite extends QueryTest { + test("analyze MetastoreRelations") { + def queryTotalSize(tableName: String): BigInt = + catalog.lookupRelation(None, tableName).statistics.sizeInBytes + + // Non-partitioned table + sql("CREATE TABLE analyzeTable (key STRING, value STRING)").collect() + sql("INSERT INTO TABLE analyzeTable SELECT * FROM src").collect() + sql("INSERT INTO TABLE analyzeTable SELECT * FROM src").collect() + + assert(queryTotalSize("analyzeTable") === defaultSizeInBytes) + + analyze("analyzeTable") + + assert(queryTotalSize("analyzeTable") === BigInt(11624)) + + sql("DROP TABLE analyzeTable").collect() + + // Partitioned table + sql( + """ + |CREATE TABLE analyzeTable_part (key STRING, value STRING) PARTITIONED BY (ds STRING) + """.stripMargin).collect() + sql( + """ + |INSERT INTO TABLE analyzeTable_part PARTITION (ds='2010-01-01') + |SELECT * FROM src + """.stripMargin).collect() + sql( + """ + |INSERT INTO TABLE analyzeTable_part PARTITION (ds='2010-01-02') + |SELECT * FROM src + """.stripMargin).collect() + sql( + """ + |INSERT INTO TABLE analyzeTable_part PARTITION (ds='2010-01-03') + |SELECT * FROM src + """.stripMargin).collect() + + assert(queryTotalSize("analyzeTable_part") === defaultSizeInBytes) + + analyze("analyzeTable_part") + + assert(queryTotalSize("analyzeTable_part") === BigInt(17436)) + + sql("DROP TABLE analyzeTable_part").collect() + + // Try to analyze a temp table + sql("""SELECT * FROM src""").registerTempTable("tempTable") + intercept[NotImplementedError] { + analyze("tempTable") + } + catalog.unregisterTable(None, "tempTable") + } + test("estimates the size of a test MetastoreRelation") { val rdd = sql("""SELECT * FROM src""") val sizes = rdd.queryExecution.analyzed.collect { case mr: MetastoreRelation => From 55349f9fe81ba5af5e4a5e4908ebf174e63c6cc9 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Sun, 3 Aug 2014 15:52:00 -0700 Subject: [PATCH 135/170] [SPARK-1740] [PySpark] kill the python worker Kill only the python worker related to cancelled tasks. The daemon will start a background thread to monitor all the opened sockets for all workers. If the socket is closed by JVM, this thread will kill the worker. When an task is cancelled, the socket to worker will be closed, then the worker will be killed by deamon. Author: Davies Liu Closes #1643 from davies/kill and squashes the following commits: 8ffe9f3 [Davies Liu] kill worker by deamon, because runtime.exec() is too heavy 46ca150 [Davies Liu] address comment acd751c [Davies Liu] kill the worker when task is canceled --- .../scala/org/apache/spark/SparkEnv.scala | 5 +- .../apache/spark/api/python/PythonRDD.scala | 9 ++- .../api/python/PythonWorkerFactory.scala | 64 ++++++++++++++----- python/pyspark/daemon.py | 24 +++++-- python/pyspark/tests.py | 51 +++++++++++++++ 5 files changed, 125 insertions(+), 28 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index 92c809d854167..0bce531aaba3e 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -18,6 +18,7 @@ package org.apache.spark import java.io.File +import java.net.Socket import scala.collection.JavaConversions._ import scala.collection.mutable @@ -102,10 +103,10 @@ class SparkEnv ( } private[spark] - def destroyPythonWorker(pythonExec: String, envVars: Map[String, String]) { + def destroyPythonWorker(pythonExec: String, envVars: Map[String, String], worker: Socket) { synchronized { val key = (pythonExec, envVars) - pythonWorkers(key).stop() + pythonWorkers.get(key).foreach(_.stopWorker(worker)) } } } diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index fe9a9e50ef21d..0b5322c6fb965 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -62,8 +62,8 @@ private[spark] class PythonRDD( val env = SparkEnv.get val localdir = env.blockManager.diskBlockManager.localDirs.map( f => f.getPath()).mkString(",") - val worker: Socket = env.createPythonWorker(pythonExec, - envVars.toMap + ("SPARK_LOCAL_DIR" -> localdir)) + envVars += ("SPARK_LOCAL_DIR" -> localdir) // it's also used in monitor thread + val worker: Socket = env.createPythonWorker(pythonExec, envVars.toMap) // Start a thread to feed the process input from our parent's iterator val writerThread = new WriterThread(env, worker, split, context) @@ -241,7 +241,7 @@ private[spark] class PythonRDD( if (!context.completed) { try { logWarning("Incomplete task interrupted: Attempting to kill Python Worker") - env.destroyPythonWorker(pythonExec, envVars.toMap) + env.destroyPythonWorker(pythonExec, envVars.toMap, worker) } catch { case e: Exception => logError("Exception when trying to kill worker", e) @@ -685,9 +685,8 @@ private[spark] object PythonRDD extends Logging { /** * Convert an RDD of serialized Python dictionaries to Scala Maps (no recursive conversions). - * This function is outdated, PySpark does not use it anymore */ - @deprecated + @deprecated("PySpark does not use it anymore", "1.1") def pythonToJavaMap(pyRDD: JavaRDD[Array[Byte]]): JavaRDD[Map[String, _]] = { pyRDD.rdd.mapPartitions { iter => val unpickle = new Unpickler diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala index 15fe8a9be6bfe..7af260d0b7f26 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala @@ -17,9 +17,11 @@ package org.apache.spark.api.python -import java.io.{DataInputStream, InputStream, OutputStreamWriter} +import java.lang.Runtime +import java.io.{DataOutputStream, DataInputStream, InputStream, OutputStreamWriter} import java.net.{InetAddress, ServerSocket, Socket, SocketException} +import scala.collection.mutable import scala.collection.JavaConversions._ import org.apache.spark._ @@ -39,6 +41,9 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String var daemon: Process = null val daemonHost = InetAddress.getByAddress(Array(127, 0, 0, 1)) var daemonPort: Int = 0 + var daemonWorkers = new mutable.WeakHashMap[Socket, Int]() + + var simpleWorkers = new mutable.WeakHashMap[Socket, Process]() val pythonPath = PythonUtils.mergePythonPaths( PythonUtils.sparkPythonPath, @@ -58,25 +63,31 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String * to avoid the high cost of forking from Java. This currently only works on UNIX-based systems. */ private def createThroughDaemon(): Socket = { + + def createSocket(): Socket = { + val socket = new Socket(daemonHost, daemonPort) + val pid = new DataInputStream(socket.getInputStream).readInt() + if (pid < 0) { + throw new IllegalStateException("Python daemon failed to launch worker") + } + daemonWorkers.put(socket, pid) + socket + } + synchronized { // Start the daemon if it hasn't been started startDaemon() // Attempt to connect, restart and retry once if it fails try { - val socket = new Socket(daemonHost, daemonPort) - val launchStatus = new DataInputStream(socket.getInputStream).readInt() - if (launchStatus != 0) { - throw new IllegalStateException("Python daemon failed to launch worker") - } - socket + createSocket() } catch { case exc: SocketException => logWarning("Failed to open socket to Python daemon:", exc) logWarning("Assuming that daemon unexpectedly quit, attempting to restart") stopDaemon() startDaemon() - new Socket(daemonHost, daemonPort) + createSocket() } } } @@ -107,7 +118,9 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String // Wait for it to connect to our socket serverSocket.setSoTimeout(10000) try { - return serverSocket.accept() + val socket = serverSocket.accept() + simpleWorkers.put(socket, worker) + return socket } catch { case e: Exception => throw new SparkException("Python worker did not connect back in time", e) @@ -189,19 +202,40 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String private def stopDaemon() { synchronized { - // Request shutdown of existing daemon by sending SIGTERM - if (daemon != null) { - daemon.destroy() - } + if (useDaemon) { + // Request shutdown of existing daemon by sending SIGTERM + if (daemon != null) { + daemon.destroy() + } - daemon = null - daemonPort = 0 + daemon = null + daemonPort = 0 + } else { + simpleWorkers.mapValues(_.destroy()) + } } } def stop() { stopDaemon() } + + def stopWorker(worker: Socket) { + if (useDaemon) { + if (daemon != null) { + daemonWorkers.get(worker).foreach { pid => + // tell daemon to kill worker by pid + val output = new DataOutputStream(daemon.getOutputStream) + output.writeInt(pid) + output.flush() + daemon.getOutputStream.flush() + } + } + } else { + simpleWorkers.get(worker).foreach(_.destroy()) + } + worker.close() + } } private object PythonWorkerFactory { diff --git a/python/pyspark/daemon.py b/python/pyspark/daemon.py index 9fde0dde0f4b4..b00da833d06f1 100644 --- a/python/pyspark/daemon.py +++ b/python/pyspark/daemon.py @@ -26,7 +26,7 @@ from socket import AF_INET, SOCK_STREAM, SOMAXCONN from signal import SIGHUP, SIGTERM, SIGCHLD, SIG_DFL, SIG_IGN from pyspark.worker import main as worker_main -from pyspark.serializers import write_int +from pyspark.serializers import read_int, write_int def compute_real_exit_code(exit_code): @@ -67,7 +67,8 @@ def waitSocketClose(sock): outfile = os.fdopen(os.dup(sock.fileno()), "a+", 65536) exit_code = 0 try: - write_int(0, outfile) # Acknowledge that the fork was successful + # Acknowledge that the fork was successful + write_int(os.getpid(), outfile) outfile.flush() worker_main(infile, outfile) except SystemExit as exc: @@ -125,14 +126,23 @@ def handle_sigchld(*args): else: raise if 0 in ready_fds: - # Spark told us to exit by closing stdin - shutdown(0) + try: + worker_pid = read_int(sys.stdin) + except EOFError: + # Spark told us to exit by closing stdin + shutdown(0) + try: + os.kill(worker_pid, signal.SIGKILL) + except OSError: + pass # process already died + + if listen_sock in ready_fds: sock, addr = listen_sock.accept() # Launch a worker process try: - fork_return_code = os.fork() - if fork_return_code == 0: + pid = os.fork() + if pid == 0: listen_sock.close() try: worker(sock) @@ -143,11 +153,13 @@ def handle_sigchld(*args): os._exit(0) else: sock.close() + except OSError as e: print >> sys.stderr, "Daemon failed to fork PySpark worker: %s" % e outfile = os.fdopen(os.dup(sock.fileno()), "a+", 65536) write_int(-1, outfile) # Signal that the fork failed outfile.flush() + outfile.close() sock.close() finally: shutdown(1) diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 16fb5a9256220..acc3c30371621 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -790,6 +790,57 @@ def test_termination_sigterm(self): self.do_termination_test(lambda daemon: os.kill(daemon.pid, SIGTERM)) +class TestWorker(PySparkTestCase): + def test_cancel_task(self): + temp = tempfile.NamedTemporaryFile(delete=True) + temp.close() + path = temp.name + def sleep(x): + import os, time + with open(path, 'w') as f: + f.write("%d %d" % (os.getppid(), os.getpid())) + time.sleep(100) + + # start job in background thread + def run(): + self.sc.parallelize(range(1)).foreach(sleep) + import threading + t = threading.Thread(target=run) + t.daemon = True + t.start() + + daemon_pid, worker_pid = 0, 0 + while True: + if os.path.exists(path): + data = open(path).read().split(' ') + daemon_pid, worker_pid = map(int, data) + break + time.sleep(0.1) + + # cancel jobs + self.sc.cancelAllJobs() + t.join() + + for i in range(50): + try: + os.kill(worker_pid, 0) + time.sleep(0.1) + except OSError: + break # worker was killed + else: + self.fail("worker has not been killed after 5 seconds") + + try: + os.kill(daemon_pid, 0) + except OSError: + self.fail("daemon had been killed") + + def test_fd_leak(self): + N = 1100 # fd limit is 1024 by default + rdd = self.sc.parallelize(range(N), N) + self.assertEquals(N, rdd.count()) + + class TestSparkSubmit(unittest.TestCase): def setUp(self): self.programDir = tempfile.mkdtemp() From 6ba6c3ebfe9a47351a50e45271e241140b09bf10 Mon Sep 17 00:00:00 2001 From: Anand Avati Date: Sun, 3 Aug 2014 17:47:49 -0700 Subject: [PATCH 136/170] [SPARK-2810] upgrade to scala-maven-plugin 3.2.0 Needed for Scala 2.11 compiler-interface Signed-off-by: Anand Avati Author: Anand Avati Closes #1711 from avati/SPARK-1812-scala-maven-plugin and squashes the following commits: 9a22fc8 [Anand Avati] SPARK-1812: upgrade to scala-maven-plugin 3.2.0 --- pom.xml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pom.xml b/pom.xml index cc9377cec2a07..4ab027bad55c0 100644 --- a/pom.xml +++ b/pom.xml @@ -782,7 +782,7 @@ net.alchim31.maven scala-maven-plugin - 3.1.6 + 3.2.0 scala-compile-first From 5507dd8e18fbb52d5e0c64a767103b2418cb09c6 Mon Sep 17 00:00:00 2001 From: Sarah Gerweck Date: Sun, 3 Aug 2014 19:47:05 -0700 Subject: [PATCH 137/170] Fix some bugs with spaces in directory name. Any time you use the directory name (`FWDIR`) it needs to be surrounded in quotes. If you're also using wildcards, you can safely put the quotes around just `$FWDIR`. Author: Sarah Gerweck Closes #1756 from sarahgerweck/folderSpaces and squashes the following commits: 732629d [Sarah Gerweck] Fix some bugs with spaces in directory name. --- make-distribution.sh | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/make-distribution.sh b/make-distribution.sh index 1441497b3995a..f7a6a9d838bb6 100755 --- a/make-distribution.sh +++ b/make-distribution.sh @@ -168,22 +168,22 @@ mkdir -p "$DISTDIR/lib" echo "Spark $VERSION$GITREVSTRING built for Hadoop $SPARK_HADOOP_VERSION" > "$DISTDIR/RELEASE" # Copy jars -cp $FWDIR/assembly/target/scala*/*assembly*hadoop*.jar "$DISTDIR/lib/" -cp $FWDIR/examples/target/scala*/spark-examples*.jar "$DISTDIR/lib/" +cp "$FWDIR"/assembly/target/scala*/*assembly*hadoop*.jar "$DISTDIR/lib/" +cp "$FWDIR"/examples/target/scala*/spark-examples*.jar "$DISTDIR/lib/" # Copy example sources (needed for python and SQL) mkdir -p "$DISTDIR/examples/src/main" -cp -r $FWDIR/examples/src/main "$DISTDIR/examples/src/" +cp -r "$FWDIR"/examples/src/main "$DISTDIR/examples/src/" if [ "$SPARK_HIVE" == "true" ]; then - cp $FWDIR/lib_managed/jars/datanucleus*.jar "$DISTDIR/lib/" + cp "$FWDIR"/lib_managed/jars/datanucleus*.jar "$DISTDIR/lib/" fi # Copy license and ASF files cp "$FWDIR/LICENSE" "$DISTDIR" cp "$FWDIR/NOTICE" "$DISTDIR" -if [ -e $FWDIR/CHANGES.txt ]; then +if [ -e "$FWDIR"/CHANGES.txt ]; then cp "$FWDIR/CHANGES.txt" "$DISTDIR" fi From ae58aea2d1435b5bb011e68127e1bcddc2edf5b2 Mon Sep 17 00:00:00 2001 From: DB Tsai Date: Sun, 3 Aug 2014 21:39:21 -0700 Subject: [PATCH 138/170] SPARK-2272 [MLlib] Feature scaling which standardizes the range of independent variables or features of data Feature scaling is a method used to standardize the range of independent variables or features of data. In data processing, it is generally performed during the data preprocessing step. In this work, a trait called `VectorTransformer` is defined for generic transformation on a vector. It contains one method to be implemented, `transform` which applies transformation on a vector. There are two implementations of `VectorTransformer` now, and they all can be easily extended with PMML transformation support. 1) `StandardScaler` - Standardizes features by removing the mean and scaling to unit variance using column summary statistics on the samples in the training set. 2) `Normalizer` - Normalizes samples individually to unit L^n norm Author: DB Tsai Closes #1207 from dbtsai/dbtsai-feature-scaling and squashes the following commits: 78c15d3 [DB Tsai] Alpine Data Labs --- .../spark/mllib/feature/Normalizer.scala | 76 +++++++ .../spark/mllib/feature/StandardScaler.scala | 119 +++++++++++ .../mllib/feature/VectorTransformer.scala | 51 +++++ .../mllib/linalg/distributed/RowMatrix.scala | 2 +- .../spark/mllib/feature/NormalizerSuite.scala | 120 +++++++++++ .../mllib/feature/StandardScalerSuite.scala | 200 ++++++++++++++++++ 6 files changed, 567 insertions(+), 1 deletion(-) create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/feature/Normalizer.scala create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/feature/VectorTransformer.scala create mode 100644 mllib/src/test/scala/org/apache/spark/mllib/feature/NormalizerSuite.scala create mode 100644 mllib/src/test/scala/org/apache/spark/mllib/feature/StandardScalerSuite.scala diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Normalizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Normalizer.scala new file mode 100644 index 0000000000000..ea9fd0a80d8e0 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Normalizer.scala @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.feature + +import breeze.linalg.{DenseVector => BDV, SparseVector => BSV} + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.mllib.linalg.{Vector, Vectors} + +/** + * :: DeveloperApi :: + * Normalizes samples individually to unit L^p^ norm + * + * For any 1 <= p < Double.PositiveInfinity, normalizes samples using + * sum(abs(vector).^p^)^(1/p)^ as norm. + * + * For p = Double.PositiveInfinity, max(abs(vector)) will be used as norm for normalization. + * + * @param p Normalization in L^p^ space, p = 2 by default. + */ +@DeveloperApi +class Normalizer(p: Double) extends VectorTransformer { + + def this() = this(2) + + require(p >= 1.0) + + /** + * Applies unit length normalization on a vector. + * + * @param vector vector to be normalized. + * @return normalized vector. If the norm of the input is zero, it will return the input vector. + */ + override def transform(vector: Vector): Vector = { + var norm = vector.toBreeze.norm(p) + + if (norm != 0.0) { + // For dense vector, we've to allocate new memory for new output vector. + // However, for sparse vector, the `index` array will not be changed, + // so we can re-use it to save memory. + vector.toBreeze match { + case dv: BDV[Double] => Vectors.fromBreeze(dv :/ norm) + case sv: BSV[Double] => + val output = new BSV[Double](sv.index, sv.data.clone(), sv.length) + var i = 0 + while (i < output.data.length) { + output.data(i) /= norm + i += 1 + } + Vectors.fromBreeze(output) + case v => throw new IllegalArgumentException("Do not support vector type " + v.getClass) + } + } else { + // Since the norm is zero, return the input vector object itself. + // Note that it's safe since we always assume that the data in RDD + // should be immutable. + vector + } + } + +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala new file mode 100644 index 0000000000000..cc2d7579c2901 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala @@ -0,0 +1,119 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.feature + +import breeze.linalg.{DenseVector => BDV, SparseVector => BSV, Vector => BV} + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.mllib.linalg.{Vector, Vectors} +import org.apache.spark.mllib.rdd.RDDFunctions._ +import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer +import org.apache.spark.rdd.RDD + +/** + * :: DeveloperApi :: + * Standardizes features by removing the mean and scaling to unit variance using column summary + * statistics on the samples in the training set. + * + * @param withMean False by default. Centers the data with mean before scaling. It will build a + * dense output, so this does not work on sparse input and will raise an exception. + * @param withStd True by default. Scales the data to unit standard deviation. + */ +@DeveloperApi +class StandardScaler(withMean: Boolean, withStd: Boolean) extends VectorTransformer { + + def this() = this(false, true) + + require(withMean || withStd, s"withMean and withStd both equal to false. Doing nothing.") + + private var mean: BV[Double] = _ + private var factor: BV[Double] = _ + + /** + * Computes the mean and variance and stores as a model to be used for later scaling. + * + * @param data The data used to compute the mean and variance to build the transformation model. + * @return This StandardScalar object. + */ + def fit(data: RDD[Vector]): this.type = { + val summary = data.treeAggregate(new MultivariateOnlineSummarizer)( + (aggregator, data) => aggregator.add(data), + (aggregator1, aggregator2) => aggregator1.merge(aggregator2)) + + mean = summary.mean.toBreeze + factor = summary.variance.toBreeze + require(mean.length == factor.length) + + var i = 0 + while (i < factor.length) { + factor(i) = if (factor(i) != 0.0) 1.0 / math.sqrt(factor(i)) else 0.0 + i += 1 + } + + this + } + + /** + * Applies standardization transformation on a vector. + * + * @param vector Vector to be standardized. + * @return Standardized vector. If the variance of a column is zero, it will return default `0.0` + * for the column with zero variance. + */ + override def transform(vector: Vector): Vector = { + if (mean == null || factor == null) { + throw new IllegalStateException( + "Haven't learned column summary statistics yet. Call fit first.") + } + + require(vector.size == mean.length) + + if (withMean) { + vector.toBreeze match { + case dv: BDV[Double] => + val output = vector.toBreeze.copy + var i = 0 + while (i < output.length) { + output(i) = (output(i) - mean(i)) * (if (withStd) factor(i) else 1.0) + i += 1 + } + Vectors.fromBreeze(output) + case v => throw new IllegalArgumentException("Do not support vector type " + v.getClass) + } + } else if (withStd) { + vector.toBreeze match { + case dv: BDV[Double] => Vectors.fromBreeze(dv :* factor) + case sv: BSV[Double] => + // For sparse vector, the `index` array inside sparse vector object will not be changed, + // so we can re-use it to save memory. + val output = new BSV[Double](sv.index, sv.data.clone(), sv.length) + var i = 0 + while (i < output.data.length) { + output.data(i) *= factor(output.index(i)) + i += 1 + } + Vectors.fromBreeze(output) + case v => throw new IllegalArgumentException("Do not support vector type " + v.getClass) + } + } else { + // Note that it's safe since we always assume that the data in RDD should be immutable. + vector + } + } + +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/VectorTransformer.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/VectorTransformer.scala new file mode 100644 index 0000000000000..415a845332d45 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/VectorTransformer.scala @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.feature + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.rdd.RDD + +/** + * :: DeveloperApi :: + * Trait for transformation of a vector + */ +@DeveloperApi +trait VectorTransformer extends Serializable { + + /** + * Applies transformation on a vector. + * + * @param vector vector to be transformed. + * @return transformed vector. + */ + def transform(vector: Vector): Vector + + /** + * Applies transformation on an RDD[Vector]. + * + * @param data RDD[Vector] to be transformed. + * @return transformed RDD[Vector]. + */ + def transform(data: RDD[Vector]): RDD[Vector] = { + // Later in #1498 , all RDD objects are sent via broadcasting instead of akka. + // So it should be no longer necessary to explicitly broadcast `this` object. + data.map(x => this.transform(x)) + } + +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala index 58c1322757a43..45486b2c7d82d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala @@ -19,7 +19,7 @@ package org.apache.spark.mllib.linalg.distributed import java.util.Arrays -import breeze.linalg.{Vector => BV, DenseMatrix => BDM, DenseVector => BDV, SparseVector => BSV} +import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, SparseVector => BSV} import breeze.linalg.{svd => brzSvd, axpy => brzAxpy} import breeze.numerics.{sqrt => brzSqrt} import com.github.fommil.netlib.BLAS.{getInstance => blas} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/NormalizerSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/NormalizerSuite.scala new file mode 100644 index 0000000000000..fb76dccfdf79e --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/NormalizerSuite.scala @@ -0,0 +1,120 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.feature + +import org.scalatest.FunSuite + +import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vectors} +import org.apache.spark.mllib.util.LocalSparkContext +import org.apache.spark.mllib.util.TestingUtils._ + +class NormalizerSuite extends FunSuite with LocalSparkContext { + + val data = Array( + Vectors.sparse(3, Seq((0, -2.0), (1, 2.3))), + Vectors.dense(0.0, 0.0, 0.0), + Vectors.dense(0.6, -1.1, -3.0), + Vectors.sparse(3, Seq((1, 0.91), (2, 3.2))), + Vectors.sparse(3, Seq((0, 5.7), (1, 0.72), (2, 2.7))), + Vectors.sparse(3, Seq()) + ) + + lazy val dataRDD = sc.parallelize(data, 3) + + test("Normalization using L1 distance") { + val l1Normalizer = new Normalizer(1) + + val data1 = data.map(l1Normalizer.transform) + val data1RDD = l1Normalizer.transform(dataRDD) + + assert((data, data1, data1RDD.collect()).zipped.forall { + case (v1: DenseVector, v2: DenseVector, v3: DenseVector) => true + case (v1: SparseVector, v2: SparseVector, v3: SparseVector) => true + case _ => false + }, "The vector type should be preserved after normalization.") + + assert((data1, data1RDD.collect()).zipped.forall((v1, v2) => v1 ~== v2 absTol 1E-5)) + + assert(data1(0).toBreeze.norm(1) ~== 1.0 absTol 1E-5) + assert(data1(2).toBreeze.norm(1) ~== 1.0 absTol 1E-5) + assert(data1(3).toBreeze.norm(1) ~== 1.0 absTol 1E-5) + assert(data1(4).toBreeze.norm(1) ~== 1.0 absTol 1E-5) + + assert(data1(0) ~== Vectors.sparse(3, Seq((0, -0.465116279), (1, 0.53488372))) absTol 1E-5) + assert(data1(1) ~== Vectors.dense(0.0, 0.0, 0.0) absTol 1E-5) + assert(data1(2) ~== Vectors.dense(0.12765957, -0.23404255, -0.63829787) absTol 1E-5) + assert(data1(3) ~== Vectors.sparse(3, Seq((1, 0.22141119), (2, 0.7785888))) absTol 1E-5) + assert(data1(4) ~== Vectors.dense(0.625, 0.07894737, 0.29605263) absTol 1E-5) + assert(data1(5) ~== Vectors.sparse(3, Seq()) absTol 1E-5) + } + + test("Normalization using L2 distance") { + val l2Normalizer = new Normalizer() + + val data2 = data.map(l2Normalizer.transform) + val data2RDD = l2Normalizer.transform(dataRDD) + + assert((data, data2, data2RDD.collect()).zipped.forall { + case (v1: DenseVector, v2: DenseVector, v3: DenseVector) => true + case (v1: SparseVector, v2: SparseVector, v3: SparseVector) => true + case _ => false + }, "The vector type should be preserved after normalization.") + + assert((data2, data2RDD.collect()).zipped.forall((v1, v2) => v1 ~== v2 absTol 1E-5)) + + assert(data2(0).toBreeze.norm(2) ~== 1.0 absTol 1E-5) + assert(data2(2).toBreeze.norm(2) ~== 1.0 absTol 1E-5) + assert(data2(3).toBreeze.norm(2) ~== 1.0 absTol 1E-5) + assert(data2(4).toBreeze.norm(2) ~== 1.0 absTol 1E-5) + + assert(data2(0) ~== Vectors.sparse(3, Seq((0, -0.65617871), (1, 0.75460552))) absTol 1E-5) + assert(data2(1) ~== Vectors.dense(0.0, 0.0, 0.0) absTol 1E-5) + assert(data2(2) ~== Vectors.dense(0.184549876, -0.3383414, -0.922749378) absTol 1E-5) + assert(data2(3) ~== Vectors.sparse(3, Seq((1, 0.27352993), (2, 0.96186349))) absTol 1E-5) + assert(data2(4) ~== Vectors.dense(0.897906166, 0.113419726, 0.42532397) absTol 1E-5) + assert(data2(5) ~== Vectors.sparse(3, Seq()) absTol 1E-5) + } + + test("Normalization using L^Inf distance.") { + val lInfNormalizer = new Normalizer(Double.PositiveInfinity) + + val dataInf = data.map(lInfNormalizer.transform) + val dataInfRDD = lInfNormalizer.transform(dataRDD) + + assert((data, dataInf, dataInfRDD.collect()).zipped.forall { + case (v1: DenseVector, v2: DenseVector, v3: DenseVector) => true + case (v1: SparseVector, v2: SparseVector, v3: SparseVector) => true + case _ => false + }, "The vector type should be preserved after normalization.") + + assert((dataInf, dataInfRDD.collect()).zipped.forall((v1, v2) => v1 ~== v2 absTol 1E-5)) + + assert(dataInf(0).toArray.map(Math.abs).max ~== 1.0 absTol 1E-5) + assert(dataInf(2).toArray.map(Math.abs).max ~== 1.0 absTol 1E-5) + assert(dataInf(3).toArray.map(Math.abs).max ~== 1.0 absTol 1E-5) + assert(dataInf(4).toArray.map(Math.abs).max ~== 1.0 absTol 1E-5) + + assert(dataInf(0) ~== Vectors.sparse(3, Seq((0, -0.86956522), (1, 1.0))) absTol 1E-5) + assert(dataInf(1) ~== Vectors.dense(0.0, 0.0, 0.0) absTol 1E-5) + assert(dataInf(2) ~== Vectors.dense(0.2, -0.36666667, -1.0) absTol 1E-5) + assert(dataInf(3) ~== Vectors.sparse(3, Seq((1, 0.284375), (2, 1.0))) absTol 1E-5) + assert(dataInf(4) ~== Vectors.dense(1.0, 0.12631579, 0.473684211) absTol 1E-5) + assert(dataInf(5) ~== Vectors.sparse(3, Seq()) absTol 1E-5) + } + +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/StandardScalerSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/StandardScalerSuite.scala new file mode 100644 index 0000000000000..5a9be923a8625 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/StandardScalerSuite.scala @@ -0,0 +1,200 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.feature + +import org.scalatest.FunSuite + +import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors} +import org.apache.spark.mllib.util.LocalSparkContext +import org.apache.spark.mllib.util.TestingUtils._ +import org.apache.spark.mllib.rdd.RDDFunctions._ +import org.apache.spark.mllib.stat.{MultivariateStatisticalSummary, MultivariateOnlineSummarizer} +import org.apache.spark.rdd.RDD + +class StandardScalerSuite extends FunSuite with LocalSparkContext { + + private def computeSummary(data: RDD[Vector]): MultivariateStatisticalSummary = { + data.treeAggregate(new MultivariateOnlineSummarizer)( + (aggregator, data) => aggregator.add(data), + (aggregator1, aggregator2) => aggregator1.merge(aggregator2)) + } + + test("Standardization with dense input") { + val data = Array( + Vectors.dense(-2.0, 2.3, 0), + Vectors.dense(0.0, -1.0, -3.0), + Vectors.dense(0.0, -5.1, 0.0), + Vectors.dense(3.8, 0.0, 1.9), + Vectors.dense(1.7, -0.6, 0.0), + Vectors.dense(0.0, 1.9, 0.0) + ) + + val dataRDD = sc.parallelize(data, 3) + + val standardizer1 = new StandardScaler(withMean = true, withStd = true) + val standardizer2 = new StandardScaler() + val standardizer3 = new StandardScaler(withMean = true, withStd = false) + + withClue("Using a standardizer before fitting the model should throw exception.") { + intercept[IllegalStateException] { + data.map(standardizer1.transform) + } + } + + standardizer1.fit(dataRDD) + standardizer2.fit(dataRDD) + standardizer3.fit(dataRDD) + + val data1 = data.map(standardizer1.transform) + val data2 = data.map(standardizer2.transform) + val data3 = data.map(standardizer3.transform) + + val data1RDD = standardizer1.transform(dataRDD) + val data2RDD = standardizer2.transform(dataRDD) + val data3RDD = standardizer3.transform(dataRDD) + + val summary = computeSummary(dataRDD) + val summary1 = computeSummary(data1RDD) + val summary2 = computeSummary(data2RDD) + val summary3 = computeSummary(data3RDD) + + assert((data, data1, data1RDD.collect()).zipped.forall { + case (v1: DenseVector, v2: DenseVector, v3: DenseVector) => true + case (v1: SparseVector, v2: SparseVector, v3: SparseVector) => true + case _ => false + }, "The vector type should be preserved after standardization.") + + assert((data, data2, data2RDD.collect()).zipped.forall { + case (v1: DenseVector, v2: DenseVector, v3: DenseVector) => true + case (v1: SparseVector, v2: SparseVector, v3: SparseVector) => true + case _ => false + }, "The vector type should be preserved after standardization.") + + assert((data, data3, data3RDD.collect()).zipped.forall { + case (v1: DenseVector, v2: DenseVector, v3: DenseVector) => true + case (v1: SparseVector, v2: SparseVector, v3: SparseVector) => true + case _ => false + }, "The vector type should be preserved after standardization.") + + assert((data1, data1RDD.collect()).zipped.forall((v1, v2) => v1 ~== v2 absTol 1E-5)) + assert((data2, data2RDD.collect()).zipped.forall((v1, v2) => v1 ~== v2 absTol 1E-5)) + assert((data3, data3RDD.collect()).zipped.forall((v1, v2) => v1 ~== v2 absTol 1E-5)) + + assert(summary1.mean ~== Vectors.dense(0.0, 0.0, 0.0) absTol 1E-5) + assert(summary1.variance ~== Vectors.dense(1.0, 1.0, 1.0) absTol 1E-5) + + assert(summary2.mean !~== Vectors.dense(0.0, 0.0, 0.0) absTol 1E-5) + assert(summary2.variance ~== Vectors.dense(1.0, 1.0, 1.0) absTol 1E-5) + + assert(summary3.mean ~== Vectors.dense(0.0, 0.0, 0.0) absTol 1E-5) + assert(summary3.variance ~== summary.variance absTol 1E-5) + + assert(data1(0) ~== Vectors.dense(-1.31527964, 1.023470449, 0.11637768424) absTol 1E-5) + assert(data1(3) ~== Vectors.dense(1.637735298, 0.156973995, 1.32247368462) absTol 1E-5) + assert(data2(4) ~== Vectors.dense(0.865538862, -0.22604255, 0.0) absTol 1E-5) + assert(data2(5) ~== Vectors.dense(0.0, 0.71580142, 0.0) absTol 1E-5) + assert(data3(1) ~== Vectors.dense(-0.58333333, -0.58333333, -2.8166666666) absTol 1E-5) + assert(data3(5) ~== Vectors.dense(-0.58333333, 2.316666666, 0.18333333333) absTol 1E-5) + } + + + test("Standardization with sparse input") { + val data = Array( + Vectors.sparse(3, Seq((0, -2.0), (1, 2.3))), + Vectors.sparse(3, Seq((1, -1.0), (2, -3.0))), + Vectors.sparse(3, Seq((1, -5.1))), + Vectors.sparse(3, Seq((0, 3.8), (2, 1.9))), + Vectors.sparse(3, Seq((0, 1.7), (1, -0.6))), + Vectors.sparse(3, Seq((1, 1.9))) + ) + + val dataRDD = sc.parallelize(data, 3) + + val standardizer1 = new StandardScaler(withMean = true, withStd = true) + val standardizer2 = new StandardScaler() + val standardizer3 = new StandardScaler(withMean = true, withStd = false) + + standardizer1.fit(dataRDD) + standardizer2.fit(dataRDD) + standardizer3.fit(dataRDD) + + val data2 = data.map(standardizer2.transform) + + withClue("Standardization with mean can not be applied on sparse input.") { + intercept[IllegalArgumentException] { + data.map(standardizer1.transform) + } + } + + withClue("Standardization with mean can not be applied on sparse input.") { + intercept[IllegalArgumentException] { + data.map(standardizer3.transform) + } + } + + val data2RDD = standardizer2.transform(dataRDD) + + val summary2 = computeSummary(data2RDD) + + assert((data, data2, data2RDD.collect()).zipped.forall { + case (v1: DenseVector, v2: DenseVector, v3: DenseVector) => true + case (v1: SparseVector, v2: SparseVector, v3: SparseVector) => true + case _ => false + }, "The vector type should be preserved after standardization.") + + assert((data2, data2RDD.collect()).zipped.forall((v1, v2) => v1 ~== v2 absTol 1E-5)) + + assert(summary2.mean !~== Vectors.dense(0.0, 0.0, 0.0) absTol 1E-5) + assert(summary2.variance ~== Vectors.dense(1.0, 1.0, 1.0) absTol 1E-5) + + assert(data2(4) ~== Vectors.sparse(3, Seq((0, 0.865538862), (1, -0.22604255))) absTol 1E-5) + assert(data2(5) ~== Vectors.sparse(3, Seq((1, 0.71580142))) absTol 1E-5) + } + + test("Standardization with constant input") { + // When the input data is all constant, the variance is zero. The standardization against + // zero variance is not well-defined, but we decide to just set it into zero here. + val data = Array( + Vectors.dense(2.0), + Vectors.dense(2.0), + Vectors.dense(2.0) + ) + + val dataRDD = sc.parallelize(data, 2) + + val standardizer1 = new StandardScaler(withMean = true, withStd = true) + val standardizer2 = new StandardScaler(withMean = true, withStd = false) + val standardizer3 = new StandardScaler(withMean = false, withStd = true) + + standardizer1.fit(dataRDD) + standardizer2.fit(dataRDD) + standardizer3.fit(dataRDD) + + val data1 = data.map(standardizer1.transform) + val data2 = data.map(standardizer2.transform) + val data3 = data.map(standardizer3.transform) + + assert(data1.forall(_.toArray.forall(_ == 0.0)), + "The variance is zero, so the transformed result should be 0.0") + assert(data2.forall(_.toArray.forall(_ == 0.0)), + "The variance is zero, so the transformed result should be 0.0") + assert(data3.forall(_.toArray.forall(_ == 0.0)), + "The variance is zero, so the transformed result should be 0.0") + } + +} From e053c55819363fab7068bb9165e3379f0c2f570c Mon Sep 17 00:00:00 2001 From: Liquan Pei Date: Sun, 3 Aug 2014 23:55:58 -0700 Subject: [PATCH 139/170] [MLlib] [SPARK-2510]Word2Vec: Distributed Representation of Words This is a pull request regarding SPARK-2510 at https://issues.apache.org/jira/browse/SPARK-2510. Word2Vec creates vector representation of words in a text corpus. The algorithm first constructs a vocabulary from the corpus and then learns vector representation of words in the vocabulary. The vector representation can be used as features in natural language processing and machine learning algorithms. To make our implementation more scalable, we train each partition separately and merge the model of each partition after each iteration. To make the model more accurate, multiple iterations may be needed. To investigate the vector representations is to find the closest words for a query word. For example, the top 20 closest words to "china" are for 1 partition and 1 iteration : taiwan 0.8077646146334014 korea 0.740913304563621 japan 0.7240667798885471 republic 0.7107151279078352 thailand 0.6953217332072862 tibet 0.6916782118129544 mongolia 0.6800858715972612 macau 0.6794925677480378 singapore 0.6594048695593799 manchuria 0.658989931844148 laos 0.6512978726001666 nepal 0.6380792327845325 mainland 0.6365469459587788 myanmar 0.6358614338840394 macedonia 0.6322366180313249 xinjiang 0.6285291551708028 russia 0.6279951236068411 india 0.6272874944023487 shanghai 0.6234544135576999 macao 0.6220588462925876 The result with 10 partitions and 5 iterations is: taiwan 0.8310495079388313 india 0.7737171315919039 japan 0.756777901233668 korea 0.7429767187102452 indonesia 0.7407557427278356 pakistan 0.712883426985585 mainland 0.7053379963140822 thailand 0.696298191073948 mongolia 0.693690656871415 laos 0.6913069680735292 macau 0.6903427690029617 republic 0.6766381604813666 malaysia 0.676460699141784 singapore 0.6728790997360923 malaya 0.672345232966194 manchuria 0.6703732292753156 macedonia 0.6637955686322028 myanmar 0.6589462882439646 kazakhstan 0.657017801081494 cambodia 0.6542383836451932 Author: Liquan Pei Author: Xiangrui Meng Author: Liquan Pei Closes #1719 from Ishiihara/master and squashes the following commits: 2ba9483 [Liquan Pei] minor fix for Word2Vec test e248441 [Liquan Pei] minor style change 26a948d [Liquan Pei] Merge pull request #1 from mengxr/Ishiihara-master c14da41 [Xiangrui Meng] fix styles 384c771 [Xiangrui Meng] remove minCount and window from constructor change model to use float instead of double e93e726 [Liquan Pei] use treeAggregate instead of aggregate 1a8fb41 [Liquan Pei] use weighted sum in combOp 7efbb6f [Liquan Pei] use broadcast version of vocab in aggregate 6bcc8be [Liquan Pei] add multiple iteration support 720b5a3 [Liquan Pei] Add test for Word2Vec algorithm, minor fixes 2e92b59 [Liquan Pei] modify according to feedback 57dc50d [Liquan Pei] code formatting e4a04d3 [Liquan Pei] minor fix 0aafb1b [Liquan Pei] Add comments, minor fixes 8d6befe [Liquan Pei] initial commit --- .../apache/spark/mllib/feature/Word2Vec.scala | 424 ++++++++++++++++++ .../spark/mllib/feature/Word2VecSuite.scala | 61 +++ 2 files changed, 485 insertions(+) create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala create mode 100644 mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala new file mode 100644 index 0000000000000..87c81e7b0bd2f --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala @@ -0,0 +1,424 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.feature + +import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer +import scala.util.Random + +import com.github.fommil.netlib.BLAS.{getInstance => blas} +import org.apache.spark.{HashPartitioner, Logging} +import org.apache.spark.SparkContext._ +import org.apache.spark.annotation.Experimental +import org.apache.spark.mllib.linalg.{Vector, Vectors} +import org.apache.spark.mllib.rdd.RDDFunctions._ +import org.apache.spark.rdd._ +import org.apache.spark.storage.StorageLevel + +/** + * Entry in vocabulary + */ +private case class VocabWord( + var word: String, + var cn: Int, + var point: Array[Int], + var code: Array[Int], + var codeLen:Int +) + +/** + * :: Experimental :: + * Word2Vec creates vector representation of words in a text corpus. + * The algorithm first constructs a vocabulary from the corpus + * and then learns vector representation of words in the vocabulary. + * The vector representation can be used as features in + * natural language processing and machine learning algorithms. + * + * We used skip-gram model in our implementation and hierarchical softmax + * method to train the model. The variable names in the implementation + * matches the original C implementation. + * + * For original C implementation, see https://code.google.com/p/word2vec/ + * For research papers, see + * Efficient Estimation of Word Representations in Vector Space + * and + * Distributed Representations of Words and Phrases and their Compositionality. + * @param size vector dimension + * @param startingAlpha initial learning rate + * @param parallelism number of partitions to run Word2Vec (using a small number for accuracy) + * @param numIterations number of iterations to run, should be smaller than or equal to parallelism + */ +@Experimental +class Word2Vec( + val size: Int, + val startingAlpha: Double, + val parallelism: Int, + val numIterations: Int) extends Serializable with Logging { + + /** + * Word2Vec with a single thread. + */ + def this(size: Int, startingAlpha: Int) = this(size, startingAlpha, 1, 1) + + private val EXP_TABLE_SIZE = 1000 + private val MAX_EXP = 6 + private val MAX_CODE_LENGTH = 40 + private val MAX_SENTENCE_LENGTH = 1000 + private val layer1Size = size + private val modelPartitionNum = 100 + + /** context words from [-window, window] */ + private val window = 5 + + /** minimum frequency to consider a vocabulary word */ + private val minCount = 5 + + private var trainWordsCount = 0 + private var vocabSize = 0 + private var vocab: Array[VocabWord] = null + private var vocabHash = mutable.HashMap.empty[String, Int] + private var alpha = startingAlpha + + private def learnVocab(words:RDD[String]): Unit = { + vocab = words.map(w => (w, 1)) + .reduceByKey(_ + _) + .map(x => VocabWord( + x._1, + x._2, + new Array[Int](MAX_CODE_LENGTH), + new Array[Int](MAX_CODE_LENGTH), + 0)) + .filter(_.cn >= minCount) + .collect() + .sortWith((a, b) => a.cn > b.cn) + + vocabSize = vocab.length + var a = 0 + while (a < vocabSize) { + vocabHash += vocab(a).word -> a + trainWordsCount += vocab(a).cn + a += 1 + } + logInfo("trainWordsCount = " + trainWordsCount) + } + + private def createExpTable(): Array[Float] = { + val expTable = new Array[Float](EXP_TABLE_SIZE) + var i = 0 + while (i < EXP_TABLE_SIZE) { + val tmp = math.exp((2.0 * i / EXP_TABLE_SIZE - 1.0) * MAX_EXP) + expTable(i) = (tmp / (tmp + 1.0)).toFloat + i += 1 + } + expTable + } + + private def createBinaryTree(): Unit = { + val count = new Array[Long](vocabSize * 2 + 1) + val binary = new Array[Int](vocabSize * 2 + 1) + val parentNode = new Array[Int](vocabSize * 2 + 1) + val code = new Array[Int](MAX_CODE_LENGTH) + val point = new Array[Int](MAX_CODE_LENGTH) + var a = 0 + while (a < vocabSize) { + count(a) = vocab(a).cn + a += 1 + } + while (a < 2 * vocabSize) { + count(a) = 1e9.toInt + a += 1 + } + var pos1 = vocabSize - 1 + var pos2 = vocabSize + + var min1i = 0 + var min2i = 0 + + a = 0 + while (a < vocabSize - 1) { + if (pos1 >= 0) { + if (count(pos1) < count(pos2)) { + min1i = pos1 + pos1 -= 1 + } else { + min1i = pos2 + pos2 += 1 + } + } else { + min1i = pos2 + pos2 += 1 + } + if (pos1 >= 0) { + if (count(pos1) < count(pos2)) { + min2i = pos1 + pos1 -= 1 + } else { + min2i = pos2 + pos2 += 1 + } + } else { + min2i = pos2 + pos2 += 1 + } + count(vocabSize + a) = count(min1i) + count(min2i) + parentNode(min1i) = vocabSize + a + parentNode(min2i) = vocabSize + a + binary(min2i) = 1 + a += 1 + } + // Now assign binary code to each vocabulary word + var i = 0 + a = 0 + while (a < vocabSize) { + var b = a + i = 0 + while (b != vocabSize * 2 - 2) { + code(i) = binary(b) + point(i) = b + i += 1 + b = parentNode(b) + } + vocab(a).codeLen = i + vocab(a).point(0) = vocabSize - 2 + b = 0 + while (b < i) { + vocab(a).code(i - b - 1) = code(b) + vocab(a).point(i - b) = point(b) - vocabSize + b += 1 + } + a += 1 + } + } + + /** + * Computes the vector representation of each word in vocabulary. + * @param dataset an RDD of words + * @return a Word2VecModel + */ + def fit[S <: Iterable[String]](dataset: RDD[S]): Word2VecModel = { + + val words = dataset.flatMap(x => x) + + learnVocab(words) + + createBinaryTree() + + val sc = dataset.context + + val expTable = sc.broadcast(createExpTable()) + val bcVocab = sc.broadcast(vocab) + val bcVocabHash = sc.broadcast(vocabHash) + + val sentences: RDD[Array[Int]] = words.mapPartitions { iter => + new Iterator[Array[Int]] { + def hasNext: Boolean = iter.hasNext + + def next(): Array[Int] = { + var sentence = new ArrayBuffer[Int] + var sentenceLength = 0 + while (iter.hasNext && sentenceLength < MAX_SENTENCE_LENGTH) { + val word = bcVocabHash.value.get(iter.next()) + word match { + case Some(w) => + sentence += w + sentenceLength += 1 + case None => + } + } + sentence.toArray + } + } + } + + val newSentences = sentences.repartition(parallelism).cache() + var syn0Global = + Array.fill[Float](vocabSize * layer1Size)((Random.nextFloat() - 0.5f) / layer1Size) + var syn1Global = new Array[Float](vocabSize * layer1Size) + + for(iter <- 1 to numIterations) { + val (aggSyn0, aggSyn1, _, _) = + // TODO: broadcast temp instead of serializing it directly + // or initialize the model in each executor + newSentences.treeAggregate((syn0Global, syn1Global, 0, 0))( + seqOp = (c, v) => (c, v) match { + case ((syn0, syn1, lastWordCount, wordCount), sentence) => + var lwc = lastWordCount + var wc = wordCount + if (wordCount - lastWordCount > 10000) { + lwc = wordCount + alpha = startingAlpha * (1 - parallelism * wordCount.toDouble / (trainWordsCount + 1)) + if (alpha < startingAlpha * 0.0001) alpha = startingAlpha * 0.0001 + logInfo("wordCount = " + wordCount + ", alpha = " + alpha) + } + wc += sentence.size + var pos = 0 + while (pos < sentence.size) { + val word = sentence(pos) + // TODO: fix random seed + val b = Random.nextInt(window) + // Train Skip-gram + var a = b + while (a < window * 2 + 1 - b) { + if (a != window) { + val c = pos - window + a + if (c >= 0 && c < sentence.size) { + val lastWord = sentence(c) + val l1 = lastWord * layer1Size + val neu1e = new Array[Float](layer1Size) + // Hierarchical softmax + var d = 0 + while (d < bcVocab.value(word).codeLen) { + val l2 = bcVocab.value(word).point(d) * layer1Size + // Propagate hidden -> output + var f = blas.sdot(layer1Size, syn0, l1, 1, syn1, l2, 1) + if (f > -MAX_EXP && f < MAX_EXP) { + val ind = ((f + MAX_EXP) * (EXP_TABLE_SIZE / MAX_EXP / 2.0)).toInt + f = expTable.value(ind) + val g = ((1 - bcVocab.value(word).code(d) - f) * alpha).toFloat + blas.saxpy(layer1Size, g, syn1, l2, 1, neu1e, 0, 1) + blas.saxpy(layer1Size, g, syn0, l1, 1, syn1, l2, 1) + } + d += 1 + } + blas.saxpy(layer1Size, 1.0f, neu1e, 0, 1, syn0, l1, 1) + } + } + a += 1 + } + pos += 1 + } + (syn0, syn1, lwc, wc) + }, + combOp = (c1, c2) => (c1, c2) match { + case ((syn0_1, syn1_1, lwc_1, wc_1), (syn0_2, syn1_2, lwc_2, wc_2)) => + val n = syn0_1.length + val weight1 = 1.0f * wc_1 / (wc_1 + wc_2) + val weight2 = 1.0f * wc_2 / (wc_1 + wc_2) + blas.sscal(n, weight1, syn0_1, 1) + blas.sscal(n, weight1, syn1_1, 1) + blas.saxpy(n, weight2, syn0_2, 1, syn0_1, 1) + blas.saxpy(n, weight2, syn1_2, 1, syn1_1, 1) + (syn0_1, syn1_1, lwc_1 + lwc_2, wc_1 + wc_2) + }) + syn0Global = aggSyn0 + syn1Global = aggSyn1 + } + newSentences.unpersist() + + val wordMap = new Array[(String, Array[Float])](vocabSize) + var i = 0 + while (i < vocabSize) { + val word = bcVocab.value(i).word + val vector = new Array[Float](layer1Size) + Array.copy(syn0Global, i * layer1Size, vector, 0, layer1Size) + wordMap(i) = (word, vector) + i += 1 + } + val modelRDD = sc.parallelize(wordMap, modelPartitionNum) + .partitionBy(new HashPartitioner(modelPartitionNum)) + .persist(StorageLevel.MEMORY_AND_DISK) + + new Word2VecModel(modelRDD) + } +} + +/** +* Word2Vec model +*/ +class Word2VecModel(private val model: RDD[(String, Array[Float])]) extends Serializable { + + private def cosineSimilarity(v1: Array[Float], v2: Array[Float]): Double = { + require(v1.length == v2.length, "Vectors should have the same length") + val n = v1.length + val norm1 = blas.snrm2(n, v1, 1) + val norm2 = blas.snrm2(n, v2, 1) + if (norm1 == 0 || norm2 == 0) return 0.0 + blas.sdot(n, v1, 1, v2,1) / norm1 / norm2 + } + + /** + * Transforms a word to its vector representation + * @param word a word + * @return vector representation of word + */ + def transform(word: String): Vector = { + val result = model.lookup(word) + if (result.isEmpty) { + throw new IllegalStateException(s"$word not in vocabulary") + } + else Vectors.dense(result(0).map(_.toDouble)) + } + + /** + * Transforms an RDD to its vector representation + * @param dataset a an RDD of words + * @return RDD of vector representation + */ + def transform(dataset: RDD[String]): RDD[Vector] = { + dataset.map(word => transform(word)) + } + + /** + * Find synonyms of a word + * @param word a word + * @param num number of synonyms to find + * @return array of (word, similarity) + */ + def findSynonyms(word: String, num: Int): Array[(String, Double)] = { + val vector = transform(word) + findSynonyms(vector,num) + } + + /** + * Find synonyms of the vector representation of a word + * @param vector vector representation of a word + * @param num number of synonyms to find + * @return array of (word, cosineSimilarity) + */ + def findSynonyms(vector: Vector, num: Int): Array[(String, Double)] = { + require(num > 0, "Number of similar words should > 0") + val topK = model.map { case(w, vec) => + (cosineSimilarity(vector.toArray.map(_.toFloat), vec), w) } + .sortByKey(ascending = false) + .take(num + 1) + .map(_.swap) + .tail + + topK + } +} + +object Word2Vec{ + /** + * Train Word2Vec model + * @param input RDD of words + * @param size vector dimension + * @param startingAlpha initial learning rate + * @param parallelism number of partitions to run Word2Vec (using a small number for accuracy) + * @param numIterations number of iterations, should be smaller than or equal to parallelism + * @return Word2Vec model + */ + def train[S <: Iterable[String]]( + input: RDD[S], + size: Int, + startingAlpha: Double, + parallelism: Int = 1, + numIterations:Int = 1): Word2VecModel = { + new Word2Vec(size,startingAlpha, parallelism, numIterations).fit[S](input) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala new file mode 100644 index 0000000000000..b5db39b68a223 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.feature + +import org.scalatest.FunSuite + +import org.apache.spark.mllib.util.LocalSparkContext + +class Word2VecSuite extends FunSuite with LocalSparkContext { + + // TODO: add more tests + + test("Word2Vec") { + val sentence = "a b " * 100 + "a c " * 10 + val localDoc = Seq(sentence, sentence) + val doc = sc.parallelize(localDoc) + .map(line => line.split(" ").toSeq) + val size = 10 + val startingAlpha = 0.025 + val window = 2 + val minCount = 2 + val num = 2 + + val model = Word2Vec.train(doc, size, startingAlpha) + val syms = model.findSynonyms("a", 2) + assert(syms.length == num) + assert(syms(0)._1 == "b") + assert(syms(1)._1 == "c") + } + + + test("Word2VecModel") { + val num = 2 + val localModel = Seq( + ("china", Array(0.50f, 0.50f, 0.50f, 0.50f)), + ("japan", Array(0.40f, 0.50f, 0.50f, 0.50f)), + ("taiwan", Array(0.60f, 0.50f, 0.50f, 0.50f)), + ("korea", Array(0.45f, 0.60f, 0.60f, 0.60f)) + ) + val model = new Word2VecModel(sc.parallelize(localModel, 2)) + val syms = model.findSynonyms("china", num) + assert(syms.length == num) + assert(syms(0)._1 == "taiwan") + assert(syms(1)._1 == "japan") + } +} From 59f84a9531f7974a053fd4963ce9afd88273ea4c Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Mon, 4 Aug 2014 12:13:41 -0700 Subject: [PATCH 140/170] [SPARK-1687] [PySpark] pickable namedtuple Add an hook to replace original namedtuple with an pickable one, then namedtuple could be used in RDDs. PS: pyspark should be import BEFORE "from collections import namedtuple" Author: Davies Liu Closes #1623 from davies/namedtuple and squashes the following commits: 045dad8 [Davies Liu] remove unrelated code changes 4132f32 [Davies Liu] address comment 55b1c1a [Davies Liu] fix tests 61f86eb [Davies Liu] replace all the reference of namedtuple to new hacked one 98df6c6 [Davies Liu] Merge branch 'master' of github.com:apache/spark into namedtuple f7b1bde [Davies Liu] add hack for CloudPickleSerializer 0c5c849 [Davies Liu] Merge branch 'master' of github.com:apache/spark into namedtuple 21991e6 [Davies Liu] hack namedtuple in __main__ module, make it picklable. 93b03b8 [Davies Liu] pickable namedtuple --- python/pyspark/serializers.py | 60 +++++++++++++++++++++++++++++++++++ python/pyspark/tests.py | 19 +++++++++++ 2 files changed, 79 insertions(+) diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index 03b31ae9624c2..1b52c144df087 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -65,6 +65,9 @@ import marshal import struct import sys +import types +import collections + from pyspark import cloudpickle @@ -267,6 +270,63 @@ def dumps(self, obj): return obj +# Hook namedtuple, make it picklable + +__cls = {} + + +def _restore(name, fields, value): + """ Restore an object of namedtuple""" + k = (name, fields) + cls = __cls.get(k) + if cls is None: + cls = collections.namedtuple(name, fields) + __cls[k] = cls + return cls(*value) + + +def _hack_namedtuple(cls): + """ Make class generated by namedtuple picklable """ + name = cls.__name__ + fields = cls._fields + def __reduce__(self): + return (_restore, (name, fields, tuple(self))) + cls.__reduce__ = __reduce__ + return cls + + +def _hijack_namedtuple(): + """ Hack namedtuple() to make it picklable """ + global _old_namedtuple # or it will put in closure + + def _copy_func(f): + return types.FunctionType(f.func_code, f.func_globals, f.func_name, + f.func_defaults, f.func_closure) + + _old_namedtuple = _copy_func(collections.namedtuple) + + def namedtuple(name, fields, verbose=False, rename=False): + cls = _old_namedtuple(name, fields, verbose, rename) + return _hack_namedtuple(cls) + + # replace namedtuple with new one + collections.namedtuple.func_globals["_old_namedtuple"] = _old_namedtuple + collections.namedtuple.func_globals["_hack_namedtuple"] = _hack_namedtuple + collections.namedtuple.func_code = namedtuple.func_code + + # hack the cls already generated by namedtuple + # those created in other module can be pickled as normal, + # so only hack those in __main__ module + for n, o in sys.modules["__main__"].__dict__.iteritems(): + if (type(o) is type and o.__base__ is tuple + and hasattr(o, "_fields") + and "__reduce__" not in o.__dict__): + _hack_namedtuple(o) # hack inplace + + +_hijack_namedtuple() + + class PickleSerializer(FramedSerializer): """ Serializes objects using Python's cPickle serializer: diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index acc3c30371621..4ac94ba729d35 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -112,6 +112,17 @@ def test_huge_dataset(self): m._cleanup() +class SerializationTestCase(unittest.TestCase): + + def test_namedtuple(self): + from collections import namedtuple + from cPickle import dumps, loads + P = namedtuple("P", "x y") + p1 = P(1, 3) + p2 = loads(dumps(p1, 2)) + self.assertEquals(p1, p2) + + class PySparkTestCase(unittest.TestCase): def setUp(self): @@ -298,6 +309,14 @@ def test_itemgetter(self): self.assertEqual([1], rdd.map(itemgetter(1)).collect()) self.assertEqual([(2, 3)], rdd.map(itemgetter(2, 3)).collect()) + def test_namedtuple_in_rdd(self): + from collections import namedtuple + Person = namedtuple("Person", "id firstName lastName") + jon = Person(1, "Jon", "Doe") + jane = Person(2, "Jane", "Doe") + theDoes = self.sc.parallelize([jon, jane]) + self.assertEquals([jon, jane], theDoes.collect()) + class TestIO(PySparkTestCase): From 8e7d5ba1a20a8a1f409e9d6472ae3e6c4bc948b4 Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Mon, 4 Aug 2014 12:59:18 -0700 Subject: [PATCH 141/170] SPARK-2792. Fix reading too much or too little data from each stream in ExternalMap / Sorter All these changes are from mridulm's work in #1609, but extracted here to fix this specific issue and make it easier to merge not 1.1. This particular set of changes is to make sure that we read exactly the right range of bytes from each spill file in EAOM: some serializers can write bytes after the last object (e.g. the TC_RESET flag in Java serialization) and that would confuse the previous code into reading it as part of the next batch. There are also improvements to cleanup to make sure files are closed. In addition to bringing in the changes to ExternalAppendOnlyMap, I also copied them to the corresponding code in ExternalSorter and updated its test suite to test for the same issues. Author: Matei Zaharia Closes #1722 from mateiz/spark-2792 and squashes the following commits: 5d4bfb5 [Matei Zaharia] Make objectStreamReset counter count the last object written too 18fe865 [Matei Zaharia] Update docs on objectStreamReset 576ee83 [Matei Zaharia] Allow objectStreamReset to be 0 0374217 [Matei Zaharia] Remove super paranoid code to close file handles bda37bb [Matei Zaharia] Implement Mridul's ExternalAppendOnlyMap fixes in ExternalSorter too 0d6dad7 [Matei Zaharia] Added Mridul's test changes for ExternalAppendOnlyMap 9a78e4b [Matei Zaharia] Add @mridulm's fixes to ExternalAppendOnlyMap for batch sizes --- .../spark/serializer/JavaSerializer.scala | 5 +- .../collection/ExternalAppendOnlyMap.scala | 86 +++++++++++---- .../util/collection/ExternalSorter.scala | 104 +++++++++++++----- .../ExternalAppendOnlyMapSuite.scala | 33 ++++-- .../util/collection/ExternalSorterSuite.scala | 47 +++++--- docs/configuration.md | 2 +- 6 files changed, 194 insertions(+), 83 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala index a7fa057ee05f7..34bc3124097bb 100644 --- a/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala @@ -35,16 +35,15 @@ private[spark] class JavaSerializationStream(out: OutputStream, counterReset: In /** * Calling reset to avoid memory leak: * http://stackoverflow.com/questions/1281549/memory-leak-traps-in-the-java-standard-api - * But only call it every 10,000th time to avoid bloated serialization streams (when + * But only call it every 100th time to avoid bloated serialization streams (when * the stream 'resets' object class descriptions have to be re-written) */ def writeObject[T: ClassTag](t: T): SerializationStream = { objOut.writeObject(t) + counter += 1 if (counterReset > 0 && counter >= counterReset) { objOut.reset() counter = 0 - } else { - counter += 1 } this } diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala index cb67a1c039f20..5d10a1f84493c 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala @@ -17,7 +17,7 @@ package org.apache.spark.util.collection -import java.io.{InputStream, BufferedInputStream, FileInputStream, File, Serializable, EOFException} +import java.io._ import java.util.Comparator import scala.collection.BufferedIterator @@ -28,7 +28,7 @@ import com.google.common.io.ByteStreams import org.apache.spark.{Logging, SparkEnv} import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.serializer.Serializer +import org.apache.spark.serializer.{DeserializationStream, Serializer} import org.apache.spark.storage.{BlockId, BlockManager} import org.apache.spark.util.collection.ExternalAppendOnlyMap.HashComparator @@ -199,13 +199,16 @@ class ExternalAppendOnlyMap[K, V, C]( // Flush the disk writer's contents to disk, and update relevant variables def flush() = { - writer.commitAndClose() - val bytesWritten = writer.bytesWritten + val w = writer + writer = null + w.commitAndClose() + val bytesWritten = w.bytesWritten batchSizes.append(bytesWritten) _diskBytesSpilled += bytesWritten objectsWritten = 0 } + var success = false try { val it = currentMap.destructiveSortedIterator(keyComparator) while (it.hasNext) { @@ -215,16 +218,28 @@ class ExternalAppendOnlyMap[K, V, C]( if (objectsWritten == serializerBatchSize) { flush() - writer.close() writer = blockManager.getDiskWriter(blockId, file, serializer, fileBufferSize) } } if (objectsWritten > 0) { flush() + } else if (writer != null) { + val w = writer + writer = null + w.revertPartialWritesAndClose() } + success = true } finally { - // Partial failures cannot be tolerated; do not revert partial writes - writer.close() + if (!success) { + // This code path only happens if an exception was thrown above before we set success; + // close our stuff and let the exception be thrown further + if (writer != null) { + writer.revertPartialWritesAndClose() + } + if (file.exists()) { + file.delete() + } + } } currentMap = new SizeTrackingAppendOnlyMap[K, C] @@ -389,27 +404,51 @@ class ExternalAppendOnlyMap[K, V, C]( * An iterator that returns (K, C) pairs in sorted order from an on-disk map */ private class DiskMapIterator(file: File, blockId: BlockId, batchSizes: ArrayBuffer[Long]) - extends Iterator[(K, C)] { - private val fileStream = new FileInputStream(file) - private val bufferedStream = new BufferedInputStream(fileStream, fileBufferSize) + extends Iterator[(K, C)] + { + private val batchOffsets = batchSizes.scanLeft(0L)(_ + _) // Size will be batchSize.length + 1 + assert(file.length() == batchOffsets(batchOffsets.length - 1)) + + private var batchIndex = 0 // Which batch we're in + private var fileStream: FileInputStream = null // An intermediate stream that reads from exactly one batch // This guards against pre-fetching and other arbitrary behavior of higher level streams - private var batchStream = nextBatchStream() - private var compressedStream = blockManager.wrapForCompression(blockId, batchStream) - private var deserializeStream = ser.deserializeStream(compressedStream) + private var deserializeStream = nextBatchStream() private var nextItem: (K, C) = null private var objectsRead = 0 /** * Construct a stream that reads only from the next batch. */ - private def nextBatchStream(): InputStream = { - if (batchSizes.length > 0) { - ByteStreams.limit(bufferedStream, batchSizes.remove(0)) + private def nextBatchStream(): DeserializationStream = { + // Note that batchOffsets.length = numBatches + 1 since we did a scan above; check whether + // we're still in a valid batch. + if (batchIndex < batchOffsets.length - 1) { + if (deserializeStream != null) { + deserializeStream.close() + fileStream.close() + deserializeStream = null + fileStream = null + } + + val start = batchOffsets(batchIndex) + fileStream = new FileInputStream(file) + fileStream.getChannel.position(start) + batchIndex += 1 + + val end = batchOffsets(batchIndex) + + assert(end >= start, "start = " + start + ", end = " + end + + ", batchOffsets = " + batchOffsets.mkString("[", ", ", "]")) + + val bufferedStream = new BufferedInputStream(ByteStreams.limit(fileStream, end - start)) + val compressedStream = blockManager.wrapForCompression(blockId, bufferedStream) + ser.deserializeStream(compressedStream) } else { // No more batches left - bufferedStream + cleanup() + null } } @@ -424,10 +463,8 @@ class ExternalAppendOnlyMap[K, V, C]( val item = deserializeStream.readObject().asInstanceOf[(K, C)] objectsRead += 1 if (objectsRead == serializerBatchSize) { - batchStream = nextBatchStream() - compressedStream = blockManager.wrapForCompression(blockId, batchStream) - deserializeStream = ser.deserializeStream(compressedStream) objectsRead = 0 + deserializeStream = nextBatchStream() } item } catch { @@ -439,6 +476,9 @@ class ExternalAppendOnlyMap[K, V, C]( override def hasNext: Boolean = { if (nextItem == null) { + if (deserializeStream == null) { + return false + } nextItem = readNextItem() } nextItem != null @@ -455,7 +495,11 @@ class ExternalAppendOnlyMap[K, V, C]( // TODO: Ensure this gets called even if the iterator isn't drained. private def cleanup() { - deserializeStream.close() + batchIndex = batchOffsets.length // Prevent reading any other batch + val ds = deserializeStream + deserializeStream = null + fileStream = null + ds.close() file.delete() } } diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala index 6e415a2bd8ce2..b04c50bd3e196 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala @@ -26,7 +26,7 @@ import scala.collection.mutable import com.google.common.io.ByteStreams import org.apache.spark.{Aggregator, SparkEnv, Logging, Partitioner} -import org.apache.spark.serializer.Serializer +import org.apache.spark.serializer.{DeserializationStream, Serializer} import org.apache.spark.storage.BlockId /** @@ -273,13 +273,16 @@ private[spark] class ExternalSorter[K, V, C]( // Flush the disk writer's contents to disk, and update relevant variables. // The writer is closed at the end of this process, and cannot be reused. def flush() = { - writer.commitAndClose() - val bytesWritten = writer.bytesWritten + val w = writer + writer = null + w.commitAndClose() + val bytesWritten = w.bytesWritten batchSizes.append(bytesWritten) _diskBytesSpilled += bytesWritten objectsWritten = 0 } + var success = false try { val it = collection.destructiveSortedIterator(partitionKeyComparator) while (it.hasNext) { @@ -299,13 +302,23 @@ private[spark] class ExternalSorter[K, V, C]( } if (objectsWritten > 0) { flush() + } else if (writer != null) { + val w = writer + writer = null + w.revertPartialWritesAndClose() + } + success = true + } finally { + if (!success) { + // This code path only happens if an exception was thrown above before we set success; + // close our stuff and let the exception be thrown further + if (writer != null) { + writer.revertPartialWritesAndClose() + } + if (file.exists()) { + file.delete() + } } - writer.close() - } catch { - case e: Exception => - writer.close() - file.delete() - throw e } if (usingMap) { @@ -472,36 +485,58 @@ private[spark] class ExternalSorter[K, V, C]( * partitions to be requested in order. */ private[this] class SpillReader(spill: SpilledFile) { - val fileStream = new FileInputStream(spill.file) - val bufferedStream = new BufferedInputStream(fileStream, fileBufferSize) + // Serializer batch offsets; size will be batchSize.length + 1 + val batchOffsets = spill.serializerBatchSizes.scanLeft(0L)(_ + _) // Track which partition and which batch stream we're in. These will be the indices of // the next element we will read. We'll also store the last partition read so that // readNextPartition() can figure out what partition that was from. var partitionId = 0 var indexInPartition = 0L - var batchStreamsRead = 0 + var batchId = 0 var indexInBatch = 0 var lastPartitionId = 0 skipToNextPartition() - // An intermediate stream that reads from exactly one batch + + // Intermediate file and deserializer streams that read from exactly one batch // This guards against pre-fetching and other arbitrary behavior of higher level streams - var batchStream = nextBatchStream() - var compressedStream = blockManager.wrapForCompression(spill.blockId, batchStream) - var deserStream = serInstance.deserializeStream(compressedStream) + var fileStream: FileInputStream = null + var deserializeStream = nextBatchStream() // Also sets fileStream + var nextItem: (K, C) = null var finished = false /** Construct a stream that only reads from the next batch */ - def nextBatchStream(): InputStream = { - if (batchStreamsRead < spill.serializerBatchSizes.length) { - batchStreamsRead += 1 - ByteStreams.limit(bufferedStream, spill.serializerBatchSizes(batchStreamsRead - 1)) + def nextBatchStream(): DeserializationStream = { + // Note that batchOffsets.length = numBatches + 1 since we did a scan above; check whether + // we're still in a valid batch. + if (batchId < batchOffsets.length - 1) { + if (deserializeStream != null) { + deserializeStream.close() + fileStream.close() + deserializeStream = null + fileStream = null + } + + val start = batchOffsets(batchId) + fileStream = new FileInputStream(spill.file) + fileStream.getChannel.position(start) + batchId += 1 + + val end = batchOffsets(batchId) + + assert(end >= start, "start = " + start + ", end = " + end + + ", batchOffsets = " + batchOffsets.mkString("[", ", ", "]")) + + val bufferedStream = new BufferedInputStream(ByteStreams.limit(fileStream, end - start)) + val compressedStream = blockManager.wrapForCompression(spill.blockId, bufferedStream) + serInstance.deserializeStream(compressedStream) } else { - // No more batches left; give an empty stream - bufferedStream + // No more batches left + cleanup() + null } } @@ -525,19 +560,17 @@ private[spark] class ExternalSorter[K, V, C]( * If no more pairs are left, return null. */ private def readNextItem(): (K, C) = { - if (finished) { + if (finished || deserializeStream == null) { return null } - val k = deserStream.readObject().asInstanceOf[K] - val c = deserStream.readObject().asInstanceOf[C] + val k = deserializeStream.readObject().asInstanceOf[K] + val c = deserializeStream.readObject().asInstanceOf[C] lastPartitionId = partitionId // Start reading the next batch if we're done with this one indexInBatch += 1 if (indexInBatch == serializerBatchSize) { - batchStream = nextBatchStream() - compressedStream = blockManager.wrapForCompression(spill.blockId, batchStream) - deserStream = serInstance.deserializeStream(compressedStream) indexInBatch = 0 + deserializeStream = nextBatchStream() } // Update the partition location of the element we're reading indexInPartition += 1 @@ -545,7 +578,9 @@ private[spark] class ExternalSorter[K, V, C]( // If we've finished reading the last partition, remember that we're done if (partitionId == numPartitions) { finished = true - deserStream.close() + if (deserializeStream != null) { + deserializeStream.close() + } } (k, c) } @@ -578,6 +613,17 @@ private[spark] class ExternalSorter[K, V, C]( item } } + + // Clean up our open streams and put us in a state where we can't read any more data + def cleanup() { + batchId = batchOffsets.length // Prevent reading any other batch + val ds = deserializeStream + deserializeStream = null + fileStream = null + ds.close() + // NOTE: We don't do file.delete() here because that is done in ExternalSorter.stop(). + // This should also be fixed in ExternalAppendOnlyMap. + } } /** diff --git a/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala index 7de5df6e1c8bd..04d7338488628 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala @@ -30,8 +30,19 @@ class ExternalAppendOnlyMapSuite extends FunSuite with LocalSparkContext { private def mergeValue(buffer: ArrayBuffer[Int], i: Int) = buffer += i private def mergeCombiners(buf1: ArrayBuffer[Int], buf2: ArrayBuffer[Int]) = buf1 ++= buf2 + private def createSparkConf(loadDefaults: Boolean): SparkConf = { + val conf = new SparkConf(loadDefaults) + // Make the Java serializer write a reset instruction (TC_RESET) after each object to test + // for a bug we had with bytes written past the last object in a batch (SPARK-2792) + conf.set("spark.serializer.objectStreamReset", "1") + conf.set("spark.serializer", "org.apache.spark.serializer.JavaSerializer") + // Ensure that we actually have multiple batches per spill file + conf.set("spark.shuffle.spill.batchSize", "10") + conf + } + test("simple insert") { - val conf = new SparkConf(false) + val conf = createSparkConf(false) sc = new SparkContext("local", "test", conf) val map = new ExternalAppendOnlyMap[Int, Int, ArrayBuffer[Int]](createCombiner, @@ -57,7 +68,7 @@ class ExternalAppendOnlyMapSuite extends FunSuite with LocalSparkContext { } test("insert with collision") { - val conf = new SparkConf(false) + val conf = createSparkConf(false) sc = new SparkContext("local", "test", conf) val map = new ExternalAppendOnlyMap[Int, Int, ArrayBuffer[Int]](createCombiner, @@ -80,7 +91,7 @@ class ExternalAppendOnlyMapSuite extends FunSuite with LocalSparkContext { } test("ordering") { - val conf = new SparkConf(false) + val conf = createSparkConf(false) sc = new SparkContext("local", "test", conf) val map1 = new ExternalAppendOnlyMap[Int, Int, ArrayBuffer[Int]](createCombiner, @@ -125,7 +136,7 @@ class ExternalAppendOnlyMapSuite extends FunSuite with LocalSparkContext { } test("null keys and values") { - val conf = new SparkConf(false) + val conf = createSparkConf(false) sc = new SparkContext("local", "test", conf) val map = new ExternalAppendOnlyMap[Int, Int, ArrayBuffer[Int]](createCombiner, @@ -166,7 +177,7 @@ class ExternalAppendOnlyMapSuite extends FunSuite with LocalSparkContext { } test("simple aggregator") { - val conf = new SparkConf(false) + val conf = createSparkConf(false) sc = new SparkContext("local", "test", conf) // reduceByKey @@ -181,7 +192,7 @@ class ExternalAppendOnlyMapSuite extends FunSuite with LocalSparkContext { } test("simple cogroup") { - val conf = new SparkConf(false) + val conf = createSparkConf(false) sc = new SparkContext("local", "test", conf) val rdd1 = sc.parallelize(1 to 4).map(i => (i, i)) val rdd2 = sc.parallelize(1 to 4).map(i => (i%2, i)) @@ -199,7 +210,7 @@ class ExternalAppendOnlyMapSuite extends FunSuite with LocalSparkContext { } test("spilling") { - val conf = new SparkConf(true) // Load defaults, otherwise SPARK_HOME is not found + val conf = createSparkConf(true) // Load defaults, otherwise SPARK_HOME is not found conf.set("spark.shuffle.memoryFraction", "0.001") sc = new SparkContext("local-cluster[1,1,512]", "test", conf) @@ -249,7 +260,7 @@ class ExternalAppendOnlyMapSuite extends FunSuite with LocalSparkContext { } test("spilling with hash collisions") { - val conf = new SparkConf(true) + val conf = createSparkConf(true) conf.set("spark.shuffle.memoryFraction", "0.001") sc = new SparkContext("local-cluster[1,1,512]", "test", conf) @@ -304,7 +315,7 @@ class ExternalAppendOnlyMapSuite extends FunSuite with LocalSparkContext { } test("spilling with many hash collisions") { - val conf = new SparkConf(true) + val conf = createSparkConf(true) conf.set("spark.shuffle.memoryFraction", "0.0001") sc = new SparkContext("local-cluster[1,1,512]", "test", conf) @@ -329,7 +340,7 @@ class ExternalAppendOnlyMapSuite extends FunSuite with LocalSparkContext { } test("spilling with hash collisions using the Int.MaxValue key") { - val conf = new SparkConf(true) + val conf = createSparkConf(true) conf.set("spark.shuffle.memoryFraction", "0.001") sc = new SparkContext("local-cluster[1,1,512]", "test", conf) @@ -347,7 +358,7 @@ class ExternalAppendOnlyMapSuite extends FunSuite with LocalSparkContext { } test("spilling with null keys and values") { - val conf = new SparkConf(true) + val conf = createSparkConf(true) conf.set("spark.shuffle.memoryFraction", "0.001") sc = new SparkContext("local-cluster[1,1,512]", "test", conf) diff --git a/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala index 65a71e5a83698..57dcb4ffabac1 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala @@ -25,6 +25,17 @@ import org.apache.spark._ import org.apache.spark.SparkContext._ class ExternalSorterSuite extends FunSuite with LocalSparkContext { + private def createSparkConf(loadDefaults: Boolean): SparkConf = { + val conf = new SparkConf(loadDefaults) + // Make the Java serializer write a reset instruction (TC_RESET) after each object to test + // for a bug we had with bytes written past the last object in a batch (SPARK-2792) + conf.set("spark.serializer.objectStreamReset", "1") + conf.set("spark.serializer", "org.apache.spark.serializer.JavaSerializer") + // Ensure that we actually have multiple batches per spill file + conf.set("spark.shuffle.spill.batchSize", "10") + conf + } + test("empty data stream") { val conf = new SparkConf(false) conf.set("spark.shuffle.memoryFraction", "0.001") @@ -60,7 +71,7 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext { } test("few elements per partition") { - val conf = new SparkConf(false) + val conf = createSparkConf(false) conf.set("spark.shuffle.memoryFraction", "0.001") conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager") sc = new SparkContext("local", "test", conf) @@ -102,7 +113,7 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext { } test("empty partitions with spilling") { - val conf = new SparkConf(false) + val conf = createSparkConf(false) conf.set("spark.shuffle.memoryFraction", "0.001") conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager") sc = new SparkContext("local", "test", conf) @@ -127,7 +138,7 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext { } test("spilling in local cluster") { - val conf = new SparkConf(true) // Load defaults, otherwise SPARK_HOME is not found + val conf = createSparkConf(true) // Load defaults, otherwise SPARK_HOME is not found conf.set("spark.shuffle.memoryFraction", "0.001") conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager") sc = new SparkContext("local-cluster[1,1,512]", "test", conf) @@ -198,7 +209,7 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext { } test("spilling in local cluster with many reduce tasks") { - val conf = new SparkConf(true) // Load defaults, otherwise SPARK_HOME is not found + val conf = createSparkConf(true) // Load defaults, otherwise SPARK_HOME is not found conf.set("spark.shuffle.memoryFraction", "0.001") conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager") sc = new SparkContext("local-cluster[2,1,512]", "test", conf) @@ -269,7 +280,7 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext { } test("cleanup of intermediate files in sorter") { - val conf = new SparkConf(true) // Load defaults, otherwise SPARK_HOME is not found + val conf = createSparkConf(true) // Load defaults, otherwise SPARK_HOME is not found conf.set("spark.shuffle.memoryFraction", "0.001") conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager") sc = new SparkContext("local", "test", conf) @@ -290,7 +301,7 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext { } test("cleanup of intermediate files in sorter if there are errors") { - val conf = new SparkConf(true) // Load defaults, otherwise SPARK_HOME is not found + val conf = createSparkConf(true) // Load defaults, otherwise SPARK_HOME is not found conf.set("spark.shuffle.memoryFraction", "0.001") conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager") sc = new SparkContext("local", "test", conf) @@ -311,7 +322,7 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext { } test("cleanup of intermediate files in shuffle") { - val conf = new SparkConf(false) + val conf = createSparkConf(false) conf.set("spark.shuffle.memoryFraction", "0.001") conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager") sc = new SparkContext("local", "test", conf) @@ -326,7 +337,7 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext { } test("cleanup of intermediate files in shuffle with errors") { - val conf = new SparkConf(false) + val conf = createSparkConf(false) conf.set("spark.shuffle.memoryFraction", "0.001") conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager") sc = new SparkContext("local", "test", conf) @@ -348,7 +359,7 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext { } test("no partial aggregation or sorting") { - val conf = new SparkConf(false) + val conf = createSparkConf(false) conf.set("spark.shuffle.memoryFraction", "0.001") conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager") sc = new SparkContext("local", "test", conf) @@ -363,7 +374,7 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext { } test("partial aggregation without spill") { - val conf = new SparkConf(false) + val conf = createSparkConf(false) conf.set("spark.shuffle.memoryFraction", "0.001") conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager") sc = new SparkContext("local", "test", conf) @@ -379,7 +390,7 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext { } test("partial aggregation with spill, no ordering") { - val conf = new SparkConf(false) + val conf = createSparkConf(false) conf.set("spark.shuffle.memoryFraction", "0.001") conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager") sc = new SparkContext("local", "test", conf) @@ -395,7 +406,7 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext { } test("partial aggregation with spill, with ordering") { - val conf = new SparkConf(false) + val conf = createSparkConf(false) conf.set("spark.shuffle.memoryFraction", "0.001") conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager") sc = new SparkContext("local", "test", conf) @@ -412,7 +423,7 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext { } test("sorting without aggregation, no spill") { - val conf = new SparkConf(false) + val conf = createSparkConf(false) conf.set("spark.shuffle.memoryFraction", "0.001") conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager") sc = new SparkContext("local", "test", conf) @@ -429,7 +440,7 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext { } test("sorting without aggregation, with spill") { - val conf = new SparkConf(false) + val conf = createSparkConf(false) conf.set("spark.shuffle.memoryFraction", "0.001") conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager") sc = new SparkContext("local", "test", conf) @@ -446,7 +457,7 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext { } test("spilling with hash collisions") { - val conf = new SparkConf(true) + val conf = createSparkConf(true) conf.set("spark.shuffle.memoryFraction", "0.001") sc = new SparkContext("local-cluster[1,1,512]", "test", conf) @@ -503,7 +514,7 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext { } test("spilling with many hash collisions") { - val conf = new SparkConf(true) + val conf = createSparkConf(true) conf.set("spark.shuffle.memoryFraction", "0.0001") sc = new SparkContext("local-cluster[1,1,512]", "test", conf) @@ -526,7 +537,7 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext { } test("spilling with hash collisions using the Int.MaxValue key") { - val conf = new SparkConf(true) + val conf = createSparkConf(true) conf.set("spark.shuffle.memoryFraction", "0.001") sc = new SparkContext("local-cluster[1,1,512]", "test", conf) @@ -547,7 +558,7 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext { } test("spilling with null keys and values") { - val conf = new SparkConf(true) + val conf = createSparkConf(true) conf.set("spark.shuffle.memoryFraction", "0.001") sc = new SparkContext("local-cluster[1,1,512]", "test", conf) diff --git a/docs/configuration.md b/docs/configuration.md index 2a71d7b820e5f..870343f1c0bd2 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -385,7 +385,7 @@ Apart from these, the following properties are also available, and may be useful When serializing using org.apache.spark.serializer.JavaSerializer, the serializer caches objects to prevent writing redundant data, however that stops garbage collection of those objects. By calling 'reset' you flush that info from the serializer, and allow old - objects to be collected. To turn off this periodic reset set it to a value <= 0. + objects to be collected. To turn off this periodic reset set it to -1. By default it will reset the serializer every 100 objects. From 9fd82dbbcb8b10debbe95f1acab53ae8b340f38e Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Mon, 4 Aug 2014 15:54:52 -0700 Subject: [PATCH 142/170] [SPARK-1687] [PySpark] fix unit tests related to pickable namedtuple serializer is imported multiple times during doctests, so it's better to make _hijack_namedtuple() safe to be called multiple times. Author: Davies Liu Closes #1771 from davies/fix and squashes the following commits: 1a9e336 [Davies Liu] fix unit tests --- python/pyspark/serializers.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index 1b52c144df087..a10f85b55ad30 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -297,8 +297,11 @@ def __reduce__(self): def _hijack_namedtuple(): """ Hack namedtuple() to make it picklable """ - global _old_namedtuple # or it will put in closure + # hijack only one time + if hasattr(collections.namedtuple, "__hijack"): + return + global _old_namedtuple # or it will put in closure def _copy_func(f): return types.FunctionType(f.func_code, f.func_globals, f.func_name, f.func_defaults, f.func_closure) @@ -313,6 +316,7 @@ def namedtuple(name, fields, verbose=False, rename=False): collections.namedtuple.func_globals["_old_namedtuple"] = _old_namedtuple collections.namedtuple.func_globals["_hack_namedtuple"] = _hack_namedtuple collections.namedtuple.func_code = namedtuple.func_code + collections.namedtuple.__hijack = 1 # hack the cls already generated by namedtuple # those created in other module can be pickled as normal, From 05bf4e4aff0d052a53d3e64c43688f07e27fec50 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Mon, 4 Aug 2014 20:39:18 -0700 Subject: [PATCH 143/170] [SPARK-2323] Exception in accumulator update should not crash DAGScheduler & SparkContext Author: Reynold Xin Closes #1772 from rxin/accumulator-dagscheduler and squashes the following commits: 6a58520 [Reynold Xin] [SPARK-2323] Exception in accumulator update should not crash DAGScheduler & SparkContext. --- .../org/apache/spark/scheduler/DAGScheduler.scala | 9 +++++++-- .../apache/spark/scheduler/DAGSchedulerSuite.scala | 11 +++-------- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index d87c3048985fc..9fa3a4e9c71ae 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -904,8 +904,13 @@ class DAGScheduler( event.reason match { case Success => if (event.accumUpdates != null) { - // TODO: fail the stage if the accumulator update fails... - Accumulators.add(event.accumUpdates) // TODO: do this only if task wasn't resubmitted + try { + Accumulators.add(event.accumUpdates) + } catch { + // If we see an exception during accumulator update, just log the error and move on. + case e: Exception => + logError(s"Failed to update accumulators for $task", e) + } } stage.pendingTasks -= task task match { diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index 36e238b4c9434..8c1b0fed11f72 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -622,8 +622,7 @@ class DAGSchedulerSuite extends TestKit(ActorSystem("DAGSchedulerSuite")) with F assertDataStructuresEmpty } - // TODO: Fix this and un-ignore the test. - ignore("misbehaved accumulator should not crash DAGScheduler and SparkContext") { + test("misbehaved accumulator should not crash DAGScheduler and SparkContext") { val acc = new Accumulator[Int](0, new AccumulatorParam[Int] { override def addAccumulator(t1: Int, t2: Int): Int = t1 + t2 override def zero(initialValue: Int): Int = 0 @@ -633,14 +632,10 @@ class DAGSchedulerSuite extends TestKit(ActorSystem("DAGSchedulerSuite")) with F }) // Run this on executors - intercept[SparkDriverExecutionException] { - sc.parallelize(1 to 10, 2).foreach { item => acc.add(1) } - } + sc.parallelize(1 to 10, 2).foreach { item => acc.add(1) } // Run this within a local thread - intercept[SparkDriverExecutionException] { - sc.parallelize(1 to 10, 2).map { item => acc.add(1) }.take(1) - } + sc.parallelize(1 to 10, 2).map { item => acc.add(1) }.take(1) // Make sure we can still run local commands as well as cluster commands. assert(sc.parallelize(1 to 10, 2).count() === 10) From 066765d60d21b6b9943862b788e4a4bd07396e6c Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Mon, 4 Aug 2014 23:27:53 -0700 Subject: [PATCH 144/170] SPARK-2685. Update ExternalAppendOnlyMap to avoid buffer.remove() Replaces this with an O(1) operation that does not have to shift over the whole tail of the array into the gap produced by the element removed. Author: Matei Zaharia Closes #1773 from mateiz/SPARK-2685 and squashes the following commits: 1ea028a [Matei Zaharia] Update comments in StreamBuffer and EAOM, and reuse ArrayBuffers eb1abfd [Matei Zaharia] Update ExternalAppendOnlyMap to avoid buffer.remove() --- .../collection/ExternalAppendOnlyMap.scala | 50 +++++++++++++------ 1 file changed, 35 insertions(+), 15 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala index 5d10a1f84493c..1f7d2dc838ebc 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala @@ -286,30 +286,32 @@ class ExternalAppendOnlyMap[K, V, C]( private val inputStreams = (Seq(sortedMap) ++ spilledMaps).map(it => it.buffered) inputStreams.foreach { it => - val kcPairs = getMorePairs(it) + val kcPairs = new ArrayBuffer[(K, C)] + readNextHashCode(it, kcPairs) if (kcPairs.length > 0) { mergeHeap.enqueue(new StreamBuffer(it, kcPairs)) } } /** - * Fetch from the given iterator until a key of different hash is retrieved. + * Fill a buffer with the next set of keys with the same hash code from a given iterator. We + * read streams one hash code at a time to ensure we don't miss elements when they are merged. + * + * Assumes the given iterator is in sorted order of hash code. * - * In the event of key hash collisions, this ensures no pairs are hidden from being merged. - * Assume the given iterator is in sorted order. + * @param it iterator to read from + * @param buf buffer to write the results into */ - private def getMorePairs(it: BufferedIterator[(K, C)]): ArrayBuffer[(K, C)] = { - val kcPairs = new ArrayBuffer[(K, C)] + private def readNextHashCode(it: BufferedIterator[(K, C)], buf: ArrayBuffer[(K, C)]): Unit = { if (it.hasNext) { var kc = it.next() - kcPairs += kc + buf += kc val minHash = hashKey(kc) while (it.hasNext && it.head._1.hashCode() == minHash) { kc = it.next() - kcPairs += kc + buf += kc } } - kcPairs } /** @@ -321,7 +323,9 @@ class ExternalAppendOnlyMap[K, V, C]( while (i < buffer.pairs.length) { val pair = buffer.pairs(i) if (pair._1 == key) { - buffer.pairs.remove(i) + // Note that there's at most one pair in the buffer with a given key, since we always + // merge stuff in a map before spilling, so it's safe to return after the first we find + removeFromBuffer(buffer.pairs, i) return mergeCombiners(baseCombiner, pair._2) } i += 1 @@ -329,6 +333,19 @@ class ExternalAppendOnlyMap[K, V, C]( baseCombiner } + /** + * Remove the index'th element from an ArrayBuffer in constant time, swapping another element + * into its place. This is more efficient than the ArrayBuffer.remove method because it does + * not have to shift all the elements in the array over. It works for our array buffers because + * we don't care about the order of elements inside, we just want to search them for a key. + */ + private def removeFromBuffer[T](buffer: ArrayBuffer[T], index: Int): T = { + val elem = buffer(index) + buffer(index) = buffer(buffer.size - 1) // This also works if index == buffer.size - 1 + buffer.reduceToSize(buffer.size - 1) + elem + } + /** * Return true if there exists an input stream that still has unvisited pairs. */ @@ -346,7 +363,7 @@ class ExternalAppendOnlyMap[K, V, C]( val minBuffer = mergeHeap.dequeue() val minPairs = minBuffer.pairs val minHash = minBuffer.minKeyHash - val minPair = minPairs.remove(0) + val minPair = removeFromBuffer(minPairs, 0) val minKey = minPair._1 var minCombiner = minPair._2 assert(hashKey(minPair) == minHash) @@ -363,7 +380,7 @@ class ExternalAppendOnlyMap[K, V, C]( // Repopulate each visited stream buffer and add it back to the queue if it is non-empty mergedBuffers.foreach { buffer => if (buffer.isEmpty) { - buffer.pairs ++= getMorePairs(buffer.iterator) + readNextHashCode(buffer.iterator, buffer.pairs) } if (!buffer.isEmpty) { mergeHeap.enqueue(buffer) @@ -375,10 +392,13 @@ class ExternalAppendOnlyMap[K, V, C]( /** * A buffer for streaming from a map iterator (in-memory or on-disk) sorted by key hash. - * Each buffer maintains the lowest-ordered keys in the corresponding iterator. Due to - * hash collisions, it is possible for multiple keys to be "tied" for being the lowest. + * Each buffer maintains all of the key-value pairs with what is currently the lowest hash + * code among keys in the stream. There may be multiple keys if there are hash collisions. + * Note that because when we spill data out, we only spill one value for each key, there is + * at most one element for each key. * - * StreamBuffers are ordered by the minimum key hash found across all of their own pairs. + * StreamBuffers are ordered by the minimum key hash currently available in their stream so + * that we can put them into a heap and sort that. */ private class StreamBuffer( val iterator: BufferedIterator[(K, C)], From 4fde28c2063f673ec7f51d514ba62a73321960a1 Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Mon, 4 Aug 2014 23:41:03 -0700 Subject: [PATCH 145/170] SPARK-2711. Create a ShuffleMemoryManager to track memory for all spilling collections This tracks memory properly if there are multiple spilling collections in the same task (which was a problem before), and also implements an algorithm that lets each thread grow up to 1 / 2N of the memory pool (where N is the number of threads) before spilling, which avoids an inefficiency with small spills we had before (some threads would spill many times at 0-1 MB because the pool was allocated elsewhere). Author: Matei Zaharia Closes #1707 from mateiz/spark-2711 and squashes the following commits: debf75b [Matei Zaharia] Review comments 24f28f3 [Matei Zaharia] Small rename c8f3a8b [Matei Zaharia] Update ShuffleMemoryManager to be able to partially grant requests 315e3a5 [Matei Zaharia] Some review comments b810120 [Matei Zaharia] Create central manager to track memory for all spilling collections --- .../scala/org/apache/spark/SparkEnv.scala | 10 +- .../org/apache/spark/executor/Executor.scala | 5 +- .../spark/shuffle/ShuffleMemoryManager.scala | 125 ++++++++ .../collection/ExternalAppendOnlyMap.scala | 48 +-- .../util/collection/ExternalSorter.scala | 49 +-- .../shuffle/ShuffleMemoryManagerSuite.scala | 294 ++++++++++++++++++ 6 files changed, 450 insertions(+), 81 deletions(-) create mode 100644 core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala create mode 100644 core/src/test/scala/org/apache/spark/shuffle/ShuffleMemoryManagerSuite.scala diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index 0bce531aaba3e..dd8e4ac66dc66 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -35,7 +35,7 @@ import org.apache.spark.metrics.MetricsSystem import org.apache.spark.network.ConnectionManager import org.apache.spark.scheduler.LiveListenerBus import org.apache.spark.serializer.Serializer -import org.apache.spark.shuffle.ShuffleManager +import org.apache.spark.shuffle.{ShuffleMemoryManager, ShuffleManager} import org.apache.spark.storage._ import org.apache.spark.util.{AkkaUtils, Utils} @@ -66,12 +66,9 @@ class SparkEnv ( val httpFileServer: HttpFileServer, val sparkFilesDir: String, val metricsSystem: MetricsSystem, + val shuffleMemoryManager: ShuffleMemoryManager, val conf: SparkConf) extends Logging { - // A mapping of thread ID to amount of memory, in bytes, used for shuffle aggregations - // All accesses should be manually synchronized - val shuffleMemoryMap = mutable.HashMap[Long, Long]() - private val pythonWorkers = mutable.HashMap[(String, Map[String, String]), PythonWorkerFactory]() // A general, soft-reference map for metadata needed during HadoopRDD split computation @@ -252,6 +249,8 @@ object SparkEnv extends Logging { val shuffleManager = instantiateClass[ShuffleManager]( "spark.shuffle.manager", "org.apache.spark.shuffle.hash.HashShuffleManager") + val shuffleMemoryManager = new ShuffleMemoryManager(conf) + // Warn about deprecated spark.cache.class property if (conf.contains("spark.cache.class")) { logWarning("The spark.cache.class property is no longer being used! Specify storage " + @@ -273,6 +272,7 @@ object SparkEnv extends Logging { httpFileServer, sparkFilesDir, metricsSystem, + shuffleMemoryManager, conf) } diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index 1bb1b4aae91bb..c2b9c660ddaec 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -276,10 +276,7 @@ private[spark] class Executor( } } finally { // Release memory used by this thread for shuffles - val shuffleMemoryMap = env.shuffleMemoryMap - shuffleMemoryMap.synchronized { - shuffleMemoryMap.remove(Thread.currentThread().getId) - } + env.shuffleMemoryManager.releaseMemoryForThisThread() // Release memory used by this thread for unrolling blocks env.blockManager.memoryStore.releaseUnrollMemoryForThisThread() runningTasks.remove(taskId) diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala new file mode 100644 index 0000000000000..ee91a368b76ea --- /dev/null +++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala @@ -0,0 +1,125 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle + +import scala.collection.mutable + +import org.apache.spark.{Logging, SparkException, SparkConf} + +/** + * Allocates a pool of memory to task threads for use in shuffle operations. Each disk-spilling + * collection (ExternalAppendOnlyMap or ExternalSorter) used by these tasks can acquire memory + * from this pool and release it as it spills data out. When a task ends, all its memory will be + * released by the Executor. + * + * This class tries to ensure that each thread gets a reasonable share of memory, instead of some + * thread ramping up to a large amount first and then causing others to spill to disk repeatedly. + * If there are N threads, it ensures that each thread can acquire at least 1 / 2N of the memory + * before it has to spill, and at most 1 / N. Because N varies dynamically, we keep track of the + * set of active threads and redo the calculations of 1 / 2N and 1 / N in waiting threads whenever + * this set changes. This is all done by synchronizing access on "this" to mutate state and using + * wait() and notifyAll() to signal changes. + */ +private[spark] class ShuffleMemoryManager(maxMemory: Long) extends Logging { + private val threadMemory = new mutable.HashMap[Long, Long]() // threadId -> memory bytes + + def this(conf: SparkConf) = this(ShuffleMemoryManager.getMaxMemory(conf)) + + /** + * Try to acquire up to numBytes memory for the current thread, and return the number of bytes + * obtained, or 0 if none can be allocated. This call may block until there is enough free memory + * in some situations, to make sure each thread has a chance to ramp up to at least 1 / 2N of the + * total memory pool (where N is the # of active threads) before it is forced to spill. This can + * happen if the number of threads increases but an older thread had a lot of memory already. + */ + def tryToAcquire(numBytes: Long): Long = synchronized { + val threadId = Thread.currentThread().getId + assert(numBytes > 0, "invalid number of bytes requested: " + numBytes) + + // Add this thread to the threadMemory map just so we can keep an accurate count of the number + // of active threads, to let other threads ramp down their memory in calls to tryToAcquire + if (!threadMemory.contains(threadId)) { + threadMemory(threadId) = 0L + notifyAll() // Will later cause waiting threads to wake up and check numThreads again + } + + // Keep looping until we're either sure that we don't want to grant this request (because this + // thread would have more than 1 / numActiveThreads of the memory) or we have enough free + // memory to give it (we always let each thread get at least 1 / (2 * numActiveThreads)). + while (true) { + val numActiveThreads = threadMemory.keys.size + val curMem = threadMemory(threadId) + val freeMemory = maxMemory - threadMemory.values.sum + + // How much we can grant this thread; don't let it grow to more than 1 / numActiveThreads + val maxToGrant = math.min(numBytes, (maxMemory / numActiveThreads) - curMem) + + if (curMem < maxMemory / (2 * numActiveThreads)) { + // We want to let each thread get at least 1 / (2 * numActiveThreads) before blocking; + // if we can't give it this much now, wait for other threads to free up memory + // (this happens if older threads allocated lots of memory before N grew) + if (freeMemory >= math.min(maxToGrant, maxMemory / (2 * numActiveThreads) - curMem)) { + val toGrant = math.min(maxToGrant, freeMemory) + threadMemory(threadId) += toGrant + return toGrant + } else { + logInfo(s"Thread $threadId waiting for at least 1/2N of shuffle memory pool to be free") + wait() + } + } else { + // Only give it as much memory as is free, which might be none if it reached 1 / numThreads + val toGrant = math.min(maxToGrant, freeMemory) + threadMemory(threadId) += toGrant + return toGrant + } + } + 0L // Never reached + } + + /** Release numBytes bytes for the current thread. */ + def release(numBytes: Long): Unit = synchronized { + val threadId = Thread.currentThread().getId + val curMem = threadMemory.getOrElse(threadId, 0L) + if (curMem < numBytes) { + throw new SparkException( + s"Internal error: release called on ${numBytes} bytes but thread only has ${curMem}") + } + threadMemory(threadId) -= numBytes + notifyAll() // Notify waiters who locked "this" in tryToAcquire that memory has been freed + } + + /** Release all memory for the current thread and mark it as inactive (e.g. when a task ends). */ + def releaseMemoryForThisThread(): Unit = synchronized { + val threadId = Thread.currentThread().getId + threadMemory.remove(threadId) + notifyAll() // Notify waiters who locked "this" in tryToAcquire that memory has been freed + } +} + +private object ShuffleMemoryManager { + /** + * Figure out the shuffle memory limit from a SparkConf. We currently have both a fraction + * of the memory pool and a safety factor since collections can sometimes grow bigger than + * the size we target before we estimate their sizes again. + */ + def getMaxMemory(conf: SparkConf): Long = { + val memoryFraction = conf.getDouble("spark.shuffle.memoryFraction", 0.2) + val safetyFraction = conf.getDouble("spark.shuffle.safetyFraction", 0.8) + (Runtime.getRuntime.maxMemory * memoryFraction * safetyFraction).toLong + } +} diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala index 1f7d2dc838ebc..cc0423856cefb 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala @@ -71,13 +71,7 @@ class ExternalAppendOnlyMap[K, V, C]( private val spilledMaps = new ArrayBuffer[DiskMapIterator] private val sparkConf = SparkEnv.get.conf private val diskBlockManager = blockManager.diskBlockManager - - // Collective memory threshold shared across all running tasks - private val maxMemoryThreshold = { - val memoryFraction = sparkConf.getDouble("spark.shuffle.memoryFraction", 0.2) - val safetyFraction = sparkConf.getDouble("spark.shuffle.safetyFraction", 0.8) - (Runtime.getRuntime.maxMemory * memoryFraction * safetyFraction).toLong - } + private val shuffleMemoryManager = SparkEnv.get.shuffleMemoryManager // Number of pairs inserted since last spill; note that we count them even if a value is merged // with a previous key in case we're doing something like groupBy where the result grows @@ -140,28 +134,15 @@ class ExternalAppendOnlyMap[K, V, C]( if (elementsRead > trackMemoryThreshold && elementsRead % 32 == 0 && currentMap.estimateSize() >= myMemoryThreshold) { - val currentSize = currentMap.estimateSize() - var shouldSpill = false - val shuffleMemoryMap = SparkEnv.get.shuffleMemoryMap - - // Atomically check whether there is sufficient memory in the global pool for - // this map to grow and, if possible, allocate the required amount - shuffleMemoryMap.synchronized { - val threadId = Thread.currentThread().getId - val previouslyOccupiedMemory = shuffleMemoryMap.get(threadId) - val availableMemory = maxMemoryThreshold - - (shuffleMemoryMap.values.sum - previouslyOccupiedMemory.getOrElse(0L)) - - // Try to allocate at least 2x more memory, otherwise spill - shouldSpill = availableMemory < currentSize * 2 - if (!shouldSpill) { - shuffleMemoryMap(threadId) = currentSize * 2 - myMemoryThreshold = currentSize * 2 - } - } - // Do not synchronize spills - if (shouldSpill) { - spill(currentSize) + // Claim up to double our current memory from the shuffle memory pool + val currentMemory = currentMap.estimateSize() + val amountToRequest = 2 * currentMemory - myMemoryThreshold + val granted = shuffleMemoryManager.tryToAcquire(amountToRequest) + myMemoryThreshold += granted + if (myMemoryThreshold <= currentMemory) { + // We were granted too little memory to grow further (either tryToAcquire returned 0, + // or we already had more memory than myMemoryThreshold); spill the current collection + spill(currentMemory) // Will also release memory back to ShuffleMemoryManager } } currentMap.changeValue(curEntry._1, update) @@ -245,12 +226,9 @@ class ExternalAppendOnlyMap[K, V, C]( currentMap = new SizeTrackingAppendOnlyMap[K, C] spilledMaps.append(new DiskMapIterator(file, blockId, batchSizes)) - // Reset the amount of shuffle memory used by this map in the global pool - val shuffleMemoryMap = SparkEnv.get.shuffleMemoryMap - shuffleMemoryMap.synchronized { - shuffleMemoryMap(Thread.currentThread().getId) = 0 - } - myMemoryThreshold = 0 + // Release our memory back to the shuffle pool so that other threads can grab it + shuffleMemoryManager.release(myMemoryThreshold) + myMemoryThreshold = 0L elementsRead = 0 _memoryBytesSpilled += mapSize diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala index b04c50bd3e196..101c83b264f63 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala @@ -78,6 +78,7 @@ private[spark] class ExternalSorter[K, V, C]( private val blockManager = SparkEnv.get.blockManager private val diskBlockManager = blockManager.diskBlockManager + private val shuffleMemoryManager = SparkEnv.get.shuffleMemoryManager private val ser = Serializer.getSerializer(serializer) private val serInstance = ser.newInstance() @@ -116,13 +117,6 @@ private[spark] class ExternalSorter[K, V, C]( private var _memoryBytesSpilled = 0L private var _diskBytesSpilled = 0L - // Collective memory threshold shared across all running tasks - private val maxMemoryThreshold = { - val memoryFraction = conf.getDouble("spark.shuffle.memoryFraction", 0.2) - val safetyFraction = conf.getDouble("spark.shuffle.safetyFraction", 0.8) - (Runtime.getRuntime.maxMemory * memoryFraction * safetyFraction).toLong - } - // How much of the shared memory pool this collection has claimed private var myMemoryThreshold = 0L @@ -218,31 +212,15 @@ private[spark] class ExternalSorter[K, V, C]( if (elementsRead > trackMemoryThreshold && elementsRead % 32 == 0 && collection.estimateSize() >= myMemoryThreshold) { - // TODO: This logic doesn't work if there are two external collections being used in the same - // task (e.g. to read shuffle output and write it out into another shuffle) [SPARK-2711] - - val currentSize = collection.estimateSize() - var shouldSpill = false - val shuffleMemoryMap = SparkEnv.get.shuffleMemoryMap - - // Atomically check whether there is sufficient memory in the global pool for - // us to double our threshold - shuffleMemoryMap.synchronized { - val threadId = Thread.currentThread().getId - val previouslyClaimedMemory = shuffleMemoryMap.get(threadId) - val availableMemory = maxMemoryThreshold - - (shuffleMemoryMap.values.sum - previouslyClaimedMemory.getOrElse(0L)) - - // Try to allocate at least 2x more memory, otherwise spill - shouldSpill = availableMemory < currentSize * 2 - if (!shouldSpill) { - shuffleMemoryMap(threadId) = currentSize * 2 - myMemoryThreshold = currentSize * 2 - } - } - // Do not hold lock during spills - if (shouldSpill) { - spill(currentSize, usingMap) + // Claim up to double our current memory from the shuffle memory pool + val currentMemory = collection.estimateSize() + val amountToRequest = 2 * currentMemory - myMemoryThreshold + val granted = shuffleMemoryManager.tryToAcquire(amountToRequest) + myMemoryThreshold += granted + if (myMemoryThreshold <= currentMemory) { + // We were granted too little memory to grow further (either tryToAcquire returned 0, + // or we already had more memory than myMemoryThreshold); spill the current collection + spill(currentMemory, usingMap) // Will also release memory back to ShuffleMemoryManager } } } @@ -327,11 +305,8 @@ private[spark] class ExternalSorter[K, V, C]( buffer = new SizeTrackingPairBuffer[(Int, K), C] } - // Reset the amount of shuffle memory used by this map in the global pool - val shuffleMemoryMap = SparkEnv.get.shuffleMemoryMap - shuffleMemoryMap.synchronized { - shuffleMemoryMap(Thread.currentThread().getId) = 0 - } + // Release our memory back to the shuffle pool so that other threads can grab it + shuffleMemoryManager.release(myMemoryThreshold) myMemoryThreshold = 0 spills.append(SpilledFile(file, blockId, batchSizes.toArray, elementsPerPartition)) diff --git a/core/src/test/scala/org/apache/spark/shuffle/ShuffleMemoryManagerSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/ShuffleMemoryManagerSuite.scala new file mode 100644 index 0000000000000..d31bc22ee74f7 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/shuffle/ShuffleMemoryManagerSuite.scala @@ -0,0 +1,294 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle + +import org.scalatest.FunSuite +import org.scalatest.concurrent.Timeouts +import org.scalatest.time.SpanSugar._ +import java.util.concurrent.atomic.AtomicBoolean +import java.util.concurrent.CountDownLatch + +class ShuffleMemoryManagerSuite extends FunSuite with Timeouts { + /** Launch a thread with the given body block and return it. */ + private def startThread(name: String)(body: => Unit): Thread = { + val thread = new Thread("ShuffleMemorySuite " + name) { + override def run() { + body + } + } + thread.start() + thread + } + + test("single thread requesting memory") { + val manager = new ShuffleMemoryManager(1000L) + + assert(manager.tryToAcquire(100L) === 100L) + assert(manager.tryToAcquire(400L) === 400L) + assert(manager.tryToAcquire(400L) === 400L) + assert(manager.tryToAcquire(200L) === 100L) + assert(manager.tryToAcquire(100L) === 0L) + assert(manager.tryToAcquire(100L) === 0L) + + manager.release(500L) + assert(manager.tryToAcquire(300L) === 300L) + assert(manager.tryToAcquire(300L) === 200L) + + manager.releaseMemoryForThisThread() + assert(manager.tryToAcquire(1000L) === 1000L) + assert(manager.tryToAcquire(100L) === 0L) + } + + test("two threads requesting full memory") { + // Two threads request 500 bytes first, wait for each other to get it, and then request + // 500 more; we should immediately return 0 as both are now at 1 / N + + val manager = new ShuffleMemoryManager(1000L) + + class State { + var t1Result1 = -1L + var t2Result1 = -1L + var t1Result2 = -1L + var t2Result2 = -1L + } + val state = new State + + val t1 = startThread("t1") { + val r1 = manager.tryToAcquire(500L) + state.synchronized { + state.t1Result1 = r1 + state.notifyAll() + while (state.t2Result1 === -1L) { + state.wait() + } + } + val r2 = manager.tryToAcquire(500L) + state.synchronized { state.t1Result2 = r2 } + } + + val t2 = startThread("t2") { + val r1 = manager.tryToAcquire(500L) + state.synchronized { + state.t2Result1 = r1 + state.notifyAll() + while (state.t1Result1 === -1L) { + state.wait() + } + } + val r2 = manager.tryToAcquire(500L) + state.synchronized { state.t2Result2 = r2 } + } + + failAfter(20 seconds) { + t1.join() + t2.join() + } + + assert(state.t1Result1 === 500L) + assert(state.t2Result1 === 500L) + assert(state.t1Result2 === 0L) + assert(state.t2Result2 === 0L) + } + + + test("threads cannot grow past 1 / N") { + // Two threads request 250 bytes first, wait for each other to get it, and then request + // 500 more; we should only grant 250 bytes to each of them on this second request + + val manager = new ShuffleMemoryManager(1000L) + + class State { + var t1Result1 = -1L + var t2Result1 = -1L + var t1Result2 = -1L + var t2Result2 = -1L + } + val state = new State + + val t1 = startThread("t1") { + val r1 = manager.tryToAcquire(250L) + state.synchronized { + state.t1Result1 = r1 + state.notifyAll() + while (state.t2Result1 === -1L) { + state.wait() + } + } + val r2 = manager.tryToAcquire(500L) + state.synchronized { state.t1Result2 = r2 } + } + + val t2 = startThread("t2") { + val r1 = manager.tryToAcquire(250L) + state.synchronized { + state.t2Result1 = r1 + state.notifyAll() + while (state.t1Result1 === -1L) { + state.wait() + } + } + val r2 = manager.tryToAcquire(500L) + state.synchronized { state.t2Result2 = r2 } + } + + failAfter(20 seconds) { + t1.join() + t2.join() + } + + assert(state.t1Result1 === 250L) + assert(state.t2Result1 === 250L) + assert(state.t1Result2 === 250L) + assert(state.t2Result2 === 250L) + } + + test("threads can block to get at least 1 / 2N memory") { + // t1 grabs 1000 bytes and then waits until t2 is ready to make a request. It sleeps + // for a bit and releases 250 bytes, which should then be greanted to t2. Further requests + // by t2 will return false right away because it now has 1 / 2N of the memory. + + val manager = new ShuffleMemoryManager(1000L) + + class State { + var t1Requested = false + var t2Requested = false + var t1Result = -1L + var t2Result = -1L + var t2Result2 = -1L + var t2WaitTime = 0L + } + val state = new State + + val t1 = startThread("t1") { + state.synchronized { + state.t1Result = manager.tryToAcquire(1000L) + state.t1Requested = true + state.notifyAll() + while (!state.t2Requested) { + state.wait() + } + } + // Sleep a bit before releasing our memory; this is hacky but it would be difficult to make + // sure the other thread blocks for some time otherwise + Thread.sleep(300) + manager.release(250L) + } + + val t2 = startThread("t2") { + state.synchronized { + while (!state.t1Requested) { + state.wait() + } + state.t2Requested = true + state.notifyAll() + } + val startTime = System.currentTimeMillis() + val result = manager.tryToAcquire(250L) + val endTime = System.currentTimeMillis() + state.synchronized { + state.t2Result = result + // A second call should return 0 because we're now already at 1 / 2N + state.t2Result2 = manager.tryToAcquire(100L) + state.t2WaitTime = endTime - startTime + } + } + + failAfter(20 seconds) { + t1.join() + t2.join() + } + + // Both threads should've been able to acquire their memory; the second one will have waited + // until the first one acquired 1000 bytes and then released 250 + state.synchronized { + assert(state.t1Result === 1000L, "t1 could not allocate memory") + assert(state.t2Result === 250L, "t2 could not allocate memory") + assert(state.t2WaitTime > 200, s"t2 waited less than 200 ms (${state.t2WaitTime})") + assert(state.t2Result2 === 0L, "t1 got extra memory the second time") + } + } + + test("releaseMemoryForThisThread") { + // t1 grabs 1000 bytes and then waits until t2 is ready to make a request. It sleeps + // for a bit and releases all its memory. t2 should now be able to grab all the memory. + + val manager = new ShuffleMemoryManager(1000L) + + class State { + var t1Requested = false + var t2Requested = false + var t1Result = -1L + var t2Result1 = -1L + var t2Result2 = -1L + var t2Result3 = -1L + var t2WaitTime = 0L + } + val state = new State + + val t1 = startThread("t1") { + state.synchronized { + state.t1Result = manager.tryToAcquire(1000L) + state.t1Requested = true + state.notifyAll() + while (!state.t2Requested) { + state.wait() + } + } + // Sleep a bit before releasing our memory; this is hacky but it would be difficult to make + // sure the other thread blocks for some time otherwise + Thread.sleep(300) + manager.releaseMemoryForThisThread() + } + + val t2 = startThread("t2") { + state.synchronized { + while (!state.t1Requested) { + state.wait() + } + state.t2Requested = true + state.notifyAll() + } + val startTime = System.currentTimeMillis() + val r1 = manager.tryToAcquire(500L) + val endTime = System.currentTimeMillis() + val r2 = manager.tryToAcquire(500L) + val r3 = manager.tryToAcquire(500L) + state.synchronized { + state.t2Result1 = r1 + state.t2Result2 = r2 + state.t2Result3 = r3 + state.t2WaitTime = endTime - startTime + } + } + + failAfter(20 seconds) { + t1.join() + t2.join() + } + + // Both threads should've been able to acquire their memory; the second one will have waited + // until the first one acquired 1000 bytes and then released all of it + state.synchronized { + assert(state.t1Result === 1000L, "t1 could not allocate memory") + assert(state.t2Result1 === 500L, "t2 didn't get 500 bytes the first time") + assert(state.t2Result2 === 500L, "t2 didn't get 500 bytes the second time") + assert(state.t2Result3 === 0L, s"t2 got more bytes a third time (${state.t2Result3})") + assert(state.t2WaitTime > 200, s"t2 waited less than 200 ms (${state.t2WaitTime})") + } + } +} From a646a365e3beb8d0cd7e492e625ce68ee9439a07 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Tue, 5 Aug 2014 00:39:07 -0700 Subject: [PATCH 146/170] [SPARK-2857] Correct properties to set Master / Worker ports `master.ui.port` and `worker.ui.port` were never picked up by SparkConf, simply because they are not prefixed with "spark." Unfortunately, this is also currently the documented way of setting these values. Author: Andrew Or Closes #1779 from andrewor14/master-worker-port and squashes the following commits: 8475e95 [Andrew Or] Update docs to reflect changes in configs 4db3d5d [Andrew Or] Stop using configs that don't actually work --- .../org/apache/spark/deploy/master/MasterArguments.scala | 4 ++-- .../scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala | 2 +- docs/spark-standalone.md | 4 ++-- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/master/MasterArguments.scala b/core/src/main/scala/org/apache/spark/deploy/master/MasterArguments.scala index a87781fb93850..4b0dbbe543d3f 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/MasterArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/MasterArguments.scala @@ -38,8 +38,8 @@ private[spark] class MasterArguments(args: Array[String], conf: SparkConf) { if (System.getenv("SPARK_MASTER_WEBUI_PORT") != null) { webUiPort = System.getenv("SPARK_MASTER_WEBUI_PORT").toInt } - if (conf.contains("master.ui.port")) { - webUiPort = conf.get("master.ui.port").toInt + if (conf.contains("spark.master.ui.port")) { + webUiPort = conf.get("spark.master.ui.port").toInt } parse(args.toList) diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala index 0ad2edba2227f..a9f531e9e4cae 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala @@ -58,6 +58,6 @@ private[spark] object WorkerWebUI { val STATIC_RESOURCE_BASE = SparkUI.STATIC_RESOURCE_DIR def getUIPort(requestedPort: Option[Int], conf: SparkConf): Int = { - requestedPort.getOrElse(conf.getInt("worker.ui.port", WorkerWebUI.DEFAULT_PORT)) + requestedPort.getOrElse(conf.getInt("spark.worker.ui.port", WorkerWebUI.DEFAULT_PORT)) } } diff --git a/docs/spark-standalone.md b/docs/spark-standalone.md index 2fb30765f35e8..293a7ac9bc9aa 100644 --- a/docs/spark-standalone.md +++ b/docs/spark-standalone.md @@ -314,7 +314,7 @@ configure those ports. Standalone Cluster Master 8080 Web UI - master.ui.port + spark.master.ui.port Jetty-based @@ -338,7 +338,7 @@ configure those ports. Worker 8081 Web UI - worker.ui.port + spark.worker.ui.port Jetty-based From 9862c614c06507aa7624208f1d7ed5bc027ca52e Mon Sep 17 00:00:00 2001 From: wangfei Date: Tue, 5 Aug 2014 00:51:07 -0700 Subject: [PATCH 147/170] [SPARK-1779] Throw an exception if memory fractions are not between 0 and 1 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Author: wangfei Author: wangfei Closes #714 from scwf/memoryFraction and squashes the following commits: 6e385b9 [wangfei] Update SparkConf.scala da6ee59 [wangfei] add configs 829a195 [wangfei] add indent 717c0ca [wangfei] updated to make more concise fc45476 [wangfei] validate memoryfraction in sparkconf 2e79b3d [wangfei] && => || 43621bd [wangfei] && => || cf38bcf [wangfei] throw IllegalArgumentException 14d18ac [wangfei] throw IllegalArgumentException dff1f0f [wangfei] Update BlockManager.scala 764965f [wangfei] Update ExternalAppendOnlyMap.scala a59d76b [wangfei] Throw exception when memoryFracton is out of range 7b899c2 [wangfei] 【SPARK-1779】 --- .../main/scala/org/apache/spark/SparkConf.scala | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala index 38700847c80f4..cce7a23d3b9fc 100644 --- a/core/src/main/scala/org/apache/spark/SparkConf.scala +++ b/core/src/main/scala/org/apache/spark/SparkConf.scala @@ -238,6 +238,20 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging { } } + // Validate memory fractions + val memoryKeys = Seq( + "spark.storage.memoryFraction", + "spark.shuffle.memoryFraction", + "spark.shuffle.safetyFraction", + "spark.storage.unrollFraction", + "spark.storage.safetyFraction") + for (key <- memoryKeys) { + val value = getDouble(key, 0.5) + if (value > 1 || value < 0) { + throw new IllegalArgumentException("$key should be between 0 and 1 (was '$value').") + } + } + // Check for legacy configs sys.env.get("SPARK_JAVA_OPTS").foreach { value => val warning = From 184048f80b6fa160c89d5bb47b937a0a89534a95 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Tue, 5 Aug 2014 01:30:46 -0700 Subject: [PATCH 148/170] [SPARK-2856] Decrease initial buffer size for Kryo to 64KB. Author: Reynold Xin Closes #1780 from rxin/kryo-init-size and squashes the following commits: 551b935 [Reynold Xin] [SPARK-2856] Decrease initial buffer size for Kryo to 64KB. --- .../scala/org/apache/spark/serializer/KryoSerializer.scala | 4 +++- docs/configuration.md | 2 +- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala index e60b802a86a14..407cb9db6ee9a 100644 --- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala @@ -47,7 +47,9 @@ class KryoSerializer(conf: SparkConf) with Logging with Serializable { - private val bufferSize = conf.getInt("spark.kryoserializer.buffer.mb", 2) * 1024 * 1024 + private val bufferSize = + (conf.getDouble("spark.kryoserializer.buffer.mb", 0.064) * 1024 * 1024).toInt + private val maxBufferSize = conf.getInt("spark.kryoserializer.buffer.max.mb", 64) * 1024 * 1024 private val referenceTracking = conf.getBoolean("spark.kryo.referenceTracking", true) private val registrationRequired = conf.getBoolean("spark.kryo.registrationRequired", false) diff --git a/docs/configuration.md b/docs/configuration.md index 870343f1c0bd2..b3dee3f131411 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -412,7 +412,7 @@ Apart from these, the following properties are also available, and may be useful spark.kryoserializer.buffer.mb - 2 + 0.064 Initial size of Kryo's serialization buffer, in megabytes. Note that there will be one buffer per core on each worker. This buffer will grow up to From e87075df977a539e4a1684045a7bd66c36285174 Mon Sep 17 00:00:00 2001 From: jerryshao Date: Tue, 5 Aug 2014 10:40:28 -0700 Subject: [PATCH 149/170] [SPARK-1022][Streaming] Add Kafka real unit test This PR is a updated version of (https://github.com/apache/spark/pull/557) to actually test sending and receiving data through Kafka, and fix previous flaky issues. @tdas, would you mind reviewing this PR? Thanks a lot. Author: jerryshao Closes #1751 from jerryshao/kafka-unit-test and squashes the following commits: b6a505f [jerryshao] code refactor according to comments 5222330 [jerryshao] Change JavaKafkaStreamSuite to better test it 5525f10 [jerryshao] Fix flaky issue of Kafka real unit test 4559310 [jerryshao] Minor changes for Kafka unit test 860f649 [jerryshao] Minor style changes, and tests ignored due to flakiness 796d4ca [jerryshao] Add real Kafka streaming test --- external/kafka/pom.xml | 6 + .../streaming/kafka/JavaKafkaStreamSuite.java | 125 +++++++++-- .../streaming/kafka/KafkaStreamSuite.scala | 197 ++++++++++++++++-- 3 files changed, 293 insertions(+), 35 deletions(-) diff --git a/external/kafka/pom.xml b/external/kafka/pom.xml index daf03360bc5f5..2aee99949223a 100644 --- a/external/kafka/pom.xml +++ b/external/kafka/pom.xml @@ -70,6 +70,12 @@ + + net.sf.jopt-simple + jopt-simple + 3.2 + test + org.scalatest scalatest_${scala.binary.version} diff --git a/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaStreamSuite.java b/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaStreamSuite.java index 9f8046bf00f8f..0571454c01dae 100644 --- a/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaStreamSuite.java +++ b/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaStreamSuite.java @@ -17,31 +17,118 @@ package org.apache.spark.streaming.kafka; +import java.io.Serializable; import java.util.HashMap; +import java.util.List; + +import scala.Predef; +import scala.Tuple2; +import scala.collection.JavaConverters; + +import junit.framework.Assert; -import org.apache.spark.streaming.api.java.JavaPairReceiverInputDStream; -import org.junit.Test; -import com.google.common.collect.Maps; import kafka.serializer.StringDecoder; + +import org.apache.spark.api.java.JavaPairRDD; +import org.apache.spark.api.java.function.Function; import org.apache.spark.storage.StorageLevel; +import org.apache.spark.streaming.Duration; import org.apache.spark.streaming.LocalJavaStreamingContext; +import org.apache.spark.streaming.api.java.JavaDStream; +import org.apache.spark.streaming.api.java.JavaPairDStream; +import org.apache.spark.streaming.api.java.JavaStreamingContext; + +import org.junit.Test; +import org.junit.After; +import org.junit.Before; + +public class JavaKafkaStreamSuite extends LocalJavaStreamingContext implements Serializable { + private transient KafkaStreamSuite testSuite = new KafkaStreamSuite(); + + @Before + @Override + public void setUp() { + testSuite.beforeFunction(); + System.clearProperty("spark.driver.port"); + //System.setProperty("spark.streaming.clock", "org.apache.spark.streaming.util.SystemClock"); + ssc = new JavaStreamingContext("local[2]", "test", new Duration(1000)); + } + + @After + @Override + public void tearDown() { + ssc.stop(); + ssc = null; + System.clearProperty("spark.driver.port"); + testSuite.afterFunction(); + } -public class JavaKafkaStreamSuite extends LocalJavaStreamingContext { @Test - public void testKafkaStream() { - HashMap topics = Maps.newHashMap(); - - // tests the API, does not actually test data receiving - JavaPairReceiverInputDStream test1 = - KafkaUtils.createStream(ssc, "localhost:12345", "group", topics); - JavaPairReceiverInputDStream test2 = KafkaUtils.createStream(ssc, "localhost:12345", "group", topics, - StorageLevel.MEMORY_AND_DISK_SER_2()); - - HashMap kafkaParams = Maps.newHashMap(); - kafkaParams.put("zookeeper.connect", "localhost:12345"); - kafkaParams.put("group.id","consumer-group"); - JavaPairReceiverInputDStream test3 = KafkaUtils.createStream(ssc, - String.class, String.class, StringDecoder.class, StringDecoder.class, - kafkaParams, topics, StorageLevel.MEMORY_AND_DISK_SER_2()); + public void testKafkaStream() throws InterruptedException { + String topic = "topic1"; + HashMap topics = new HashMap(); + topics.put(topic, 1); + + HashMap sent = new HashMap(); + sent.put("a", 5); + sent.put("b", 3); + sent.put("c", 10); + + testSuite.createTopic(topic); + HashMap tmp = new HashMap(sent); + testSuite.produceAndSendMessage(topic, + JavaConverters.mapAsScalaMapConverter(tmp).asScala().toMap( + Predef.>conforms())); + + HashMap kafkaParams = new HashMap(); + kafkaParams.put("zookeeper.connect", testSuite.zkConnect()); + kafkaParams.put("group.id", "test-consumer-" + KafkaTestUtils.random().nextInt(10000)); + kafkaParams.put("auto.offset.reset", "smallest"); + + JavaPairDStream stream = KafkaUtils.createStream(ssc, + String.class, + String.class, + StringDecoder.class, + StringDecoder.class, + kafkaParams, + topics, + StorageLevel.MEMORY_ONLY_SER()); + + final HashMap result = new HashMap(); + + JavaDStream words = stream.map( + new Function, String>() { + @Override + public String call(Tuple2 tuple2) throws Exception { + return tuple2._2(); + } + } + ); + + words.countByValue().foreachRDD( + new Function, Void>() { + @Override + public Void call(JavaPairRDD rdd) throws Exception { + List> ret = rdd.collect(); + for (Tuple2 r : ret) { + if (result.containsKey(r._1())) { + result.put(r._1(), result.get(r._1()) + r._2()); + } else { + result.put(r._1(), r._2()); + } + } + + return null; + } + } + ); + + ssc.start(); + ssc.awaitTermination(3000); + + Assert.assertEquals(sent.size(), result.size()); + for (String k : sent.keySet()) { + Assert.assertEquals(sent.get(k).intValue(), result.get(k).intValue()); + } } } diff --git a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaStreamSuite.scala b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaStreamSuite.scala index e6f2c4a5cf5d1..c0b55e9340253 100644 --- a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaStreamSuite.scala +++ b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaStreamSuite.scala @@ -17,28 +17,193 @@ package org.apache.spark.streaming.kafka -import kafka.serializer.StringDecoder +import java.io.File +import java.net.InetSocketAddress +import java.util.{Properties, Random} + +import scala.collection.mutable + +import kafka.admin.CreateTopicCommand +import kafka.common.TopicAndPartition +import kafka.producer.{KeyedMessage, ProducerConfig, Producer} +import kafka.utils.ZKStringSerializer +import kafka.serializer.{StringDecoder, StringEncoder} +import kafka.server.{KafkaConfig, KafkaServer} + +import org.I0Itec.zkclient.ZkClient + +import org.apache.zookeeper.server.ZooKeeperServer +import org.apache.zookeeper.server.NIOServerCnxnFactory + import org.apache.spark.streaming.{StreamingContext, TestSuiteBase} import org.apache.spark.storage.StorageLevel -import org.apache.spark.streaming.dstream.ReceiverInputDStream +import org.apache.spark.util.Utils class KafkaStreamSuite extends TestSuiteBase { + import KafkaTestUtils._ + + val zkConnect = "localhost:2181" + val zkConnectionTimeout = 6000 + val zkSessionTimeout = 6000 + + val brokerPort = 9092 + val brokerProps = getBrokerConfig(brokerPort, zkConnect) + val brokerConf = new KafkaConfig(brokerProps) + + protected var zookeeper: EmbeddedZookeeper = _ + protected var zkClient: ZkClient = _ + protected var server: KafkaServer = _ + protected var producer: Producer[String, String] = _ + + override def useManualClock = false + + override def beforeFunction() { + // Zookeeper server startup + zookeeper = new EmbeddedZookeeper(zkConnect) + logInfo("==================== 0 ====================") + zkClient = new ZkClient(zkConnect, zkSessionTimeout, zkConnectionTimeout, ZKStringSerializer) + logInfo("==================== 1 ====================") - test("kafka input stream") { + // Kafka broker startup + server = new KafkaServer(brokerConf) + logInfo("==================== 2 ====================") + server.startup() + logInfo("==================== 3 ====================") + Thread.sleep(2000) + logInfo("==================== 4 ====================") + super.beforeFunction() + } + + override def afterFunction() { + producer.close() + server.shutdown() + brokerConf.logDirs.foreach { f => Utils.deleteRecursively(new File(f)) } + + zkClient.close() + zookeeper.shutdown() + + super.afterFunction() + } + + test("Kafka input stream") { val ssc = new StreamingContext(master, framework, batchDuration) - val topics = Map("my-topic" -> 1) - - // tests the API, does not actually test data receiving - val test1: ReceiverInputDStream[(String, String)] = - KafkaUtils.createStream(ssc, "localhost:1234", "group", topics) - val test2: ReceiverInputDStream[(String, String)] = - KafkaUtils.createStream(ssc, "localhost:12345", "group", topics, StorageLevel.MEMORY_AND_DISK_SER_2) - val kafkaParams = Map("zookeeper.connect"->"localhost:12345","group.id"->"consumer-group") - val test3: ReceiverInputDStream[(String, String)] = - KafkaUtils.createStream[String, String, StringDecoder, StringDecoder]( - ssc, kafkaParams, topics, StorageLevel.MEMORY_AND_DISK_SER_2) - - // TODO: Actually test receiving data + val topic = "topic1" + val sent = Map("a" -> 5, "b" -> 3, "c" -> 10) + createTopic(topic) + produceAndSendMessage(topic, sent) + + val kafkaParams = Map("zookeeper.connect" -> zkConnect, + "group.id" -> s"test-consumer-${random.nextInt(10000)}", + "auto.offset.reset" -> "smallest") + + val stream = KafkaUtils.createStream[String, String, StringDecoder, StringDecoder]( + ssc, + kafkaParams, + Map(topic -> 1), + StorageLevel.MEMORY_ONLY) + val result = new mutable.HashMap[String, Long]() + stream.map { case (k, v) => v } + .countByValue() + .foreachRDD { r => + val ret = r.collect() + ret.toMap.foreach { kv => + val count = result.getOrElseUpdate(kv._1, 0) + kv._2 + result.put(kv._1, count) + } + } + ssc.start() + ssc.awaitTermination(3000) + + assert(sent.size === result.size) + sent.keys.foreach { k => assert(sent(k) === result(k).toInt) } + ssc.stop() } + + private def createTestMessage(topic: String, sent: Map[String, Int]) + : Seq[KeyedMessage[String, String]] = { + val messages = for ((s, freq) <- sent; i <- 0 until freq) yield { + new KeyedMessage[String, String](topic, s) + } + messages.toSeq + } + + def createTopic(topic: String) { + CreateTopicCommand.createTopic(zkClient, topic, 1, 1, "0") + logInfo("==================== 5 ====================") + // wait until metadata is propagated + waitUntilMetadataIsPropagated(Seq(server), topic, 0, 1000) + } + + def produceAndSendMessage(topic: String, sent: Map[String, Int]) { + val brokerAddr = brokerConf.hostName + ":" + brokerConf.port + producer = new Producer[String, String](new ProducerConfig(getProducerConfig(brokerAddr))) + producer.send(createTestMessage(topic, sent): _*) + logInfo("==================== 6 ====================") + } +} + +object KafkaTestUtils { + val random = new Random() + + def getBrokerConfig(port: Int, zkConnect: String): Properties = { + val props = new Properties() + props.put("broker.id", "0") + props.put("host.name", "localhost") + props.put("port", port.toString) + props.put("log.dir", Utils.createTempDir().getAbsolutePath) + props.put("zookeeper.connect", zkConnect) + props.put("log.flush.interval.messages", "1") + props.put("replica.socket.timeout.ms", "1500") + props + } + + def getProducerConfig(brokerList: String): Properties = { + val props = new Properties() + props.put("metadata.broker.list", brokerList) + props.put("serializer.class", classOf[StringEncoder].getName) + props + } + + def waitUntilTrue(condition: () => Boolean, waitTime: Long): Boolean = { + val startTime = System.currentTimeMillis() + while (true) { + if (condition()) + return true + if (System.currentTimeMillis() > startTime + waitTime) + return false + Thread.sleep(waitTime.min(100L)) + } + // Should never go to here + throw new RuntimeException("unexpected error") + } + + def waitUntilMetadataIsPropagated(servers: Seq[KafkaServer], topic: String, partition: Int, + timeout: Long) { + assert(waitUntilTrue(() => + servers.foldLeft(true)(_ && _.apis.leaderCache.keySet.contains( + TopicAndPartition(topic, partition))), timeout), + s"Partition [$topic, $partition] metadata not propagated after timeout") + } + + class EmbeddedZookeeper(val zkConnect: String) { + val random = new Random() + val snapshotDir = Utils.createTempDir() + val logDir = Utils.createTempDir() + + val zookeeper = new ZooKeeperServer(snapshotDir, logDir, 500) + val (ip, port) = { + val splits = zkConnect.split(":") + (splits(0), splits(1).toInt) + } + val factory = new NIOServerCnxnFactory() + factory.configure(new InetSocketAddress(ip, port), 16) + factory.startup(zookeeper) + + def shutdown() { + factory.shutdown() + Utils.deleteRecursively(snapshotDir) + Utils.deleteRecursively(logDir) + } + } } From 2c0f705e26ca3dfc43a1e9a0722c0e57f67c970a Mon Sep 17 00:00:00 2001 From: Thomas Graves Date: Tue, 5 Aug 2014 12:48:26 -0500 Subject: [PATCH 150/170] SPARK-1528 - spark on yarn, add support for accessing remote HDFS Add a config (spark.yarn.access.namenodes) to allow applications running on yarn to access other secure HDFS cluster. User just specifies the namenodes of the other clusters and we get Tokens for those and ship them with the spark application. Author: Thomas Graves Closes #1159 from tgravescs/spark-1528 and squashes the following commits: ddbcd16 [Thomas Graves] review comments 0ac8501 [Thomas Graves] SPARK-1528 - add support for accessing remote HDFS --- docs/running-on-yarn.md | 7 +++ .../apache/spark/deploy/yarn/ClientBase.scala | 56 +++++++++++++------ .../spark/deploy/yarn/ClientBaseSuite.scala | 55 +++++++++++++++++- 3 files changed, 101 insertions(+), 17 deletions(-) diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md index 0362f5a223319..573930dbf4e54 100644 --- a/docs/running-on-yarn.md +++ b/docs/running-on-yarn.md @@ -106,6 +106,13 @@ Most of the configs are the same for Spark on YARN as for other deployment modes set this configuration to "hdfs:///some/path". + + spark.yarn.access.namenodes + (none) + + A list of secure HDFS namenodes your Spark application is going to access. For example, `spark.yarn.access.namenodes=hdfs://nn1.com:8032,hdfs://nn2.com:8032`. The Spark application must have acess to the namenodes listed and Kerberos must be properly configured to be able to access them (either in the same realm or in a trusted realm). Spark acquires security tokens for each of the namenodes so that the Spark application can access those remote HDFS clusters. + + # Launching Spark on YARN diff --git a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala index b7e8636e02eb2..ed8f56ab8b75e 100644 --- a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala +++ b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala @@ -29,7 +29,7 @@ import org.apache.hadoop.fs._ import org.apache.hadoop.fs.permission.FsPermission import org.apache.hadoop.mapred.Master import org.apache.hadoop.mapreduce.MRJobConfig -import org.apache.hadoop.security.UserGroupInformation +import org.apache.hadoop.security.{Credentials, UserGroupInformation} import org.apache.hadoop.util.StringUtils import org.apache.hadoop.yarn.api._ import org.apache.hadoop.yarn.api.ApplicationConstants.Environment @@ -191,23 +191,11 @@ trait ClientBase extends Logging { // Upload Spark and the application JAR to the remote file system if necessary. Add them as // local resources to the application master. val fs = FileSystem.get(conf) - - val delegTokenRenewer = Master.getMasterPrincipal(conf) - if (UserGroupInformation.isSecurityEnabled()) { - if (delegTokenRenewer == null || delegTokenRenewer.length() == 0) { - val errorMessage = "Can't get Master Kerberos principal for use as renewer" - logError(errorMessage) - throw new SparkException(errorMessage) - } - } val dst = new Path(fs.getHomeDirectory(), appStagingDir) - val replication = sparkConf.getInt("spark.yarn.submit.file.replication", 3).toShort - - if (UserGroupInformation.isSecurityEnabled()) { - val dstFs = dst.getFileSystem(conf) - dstFs.addDelegationTokens(delegTokenRenewer, credentials) - } + val nns = ClientBase.getNameNodesToAccess(sparkConf) + dst + ClientBase.obtainTokensForNamenodes(nns, conf, credentials) + val replication = sparkConf.getInt("spark.yarn.submit.file.replication", 3).toShort val localResources = HashMap[String, LocalResource]() FileSystem.mkdirs(fs, dst, new FsPermission(STAGING_DIR_PERMISSION)) @@ -614,4 +602,40 @@ object ClientBase extends Logging { YarnSparkHadoopUtil.addToEnvironment(env, Environment.CLASSPATH.name, path, File.pathSeparator) + /** + * Get the list of namenodes the user may access. + */ + private[yarn] def getNameNodesToAccess(sparkConf: SparkConf): Set[Path] = { + sparkConf.get("spark.yarn.access.namenodes", "").split(",").map(_.trim()).filter(!_.isEmpty) + .map(new Path(_)).toSet + } + + private[yarn] def getTokenRenewer(conf: Configuration): String = { + val delegTokenRenewer = Master.getMasterPrincipal(conf) + logDebug("delegation token renewer is: " + delegTokenRenewer) + if (delegTokenRenewer == null || delegTokenRenewer.length() == 0) { + val errorMessage = "Can't get Master Kerberos principal for use as renewer" + logError(errorMessage) + throw new SparkException(errorMessage) + } + delegTokenRenewer + } + + /** + * Obtains tokens for the namenodes passed in and adds them to the credentials. + */ + private[yarn] def obtainTokensForNamenodes(paths: Set[Path], conf: Configuration, + creds: Credentials) { + if (UserGroupInformation.isSecurityEnabled()) { + val delegTokenRenewer = getTokenRenewer(conf) + + paths.foreach { + dst => + val dstFs = dst.getFileSystem(conf) + logDebug("getting token for namenode: " + dst) + dstFs.addDelegationTokens(delegTokenRenewer, creds) + } + } + } + } diff --git a/yarn/common/src/test/scala/org/apache/spark/deploy/yarn/ClientBaseSuite.scala b/yarn/common/src/test/scala/org/apache/spark/deploy/yarn/ClientBaseSuite.scala index 686714dc36488..68cc2890f3a22 100644 --- a/yarn/common/src/test/scala/org/apache/spark/deploy/yarn/ClientBaseSuite.scala +++ b/yarn/common/src/test/scala/org/apache/spark/deploy/yarn/ClientBaseSuite.scala @@ -31,6 +31,8 @@ import org.apache.hadoop.yarn.api.records.ContainerLaunchContext import org.apache.hadoop.yarn.conf.YarnConfiguration import org.mockito.Matchers._ import org.mockito.Mockito._ + + import org.scalatest.FunSuite import org.scalatest.Matchers @@ -38,7 +40,7 @@ import scala.collection.JavaConversions._ import scala.collection.mutable.{ HashMap => MutableHashMap } import scala.util.Try -import org.apache.spark.SparkConf +import org.apache.spark.{SparkException, SparkConf} import org.apache.spark.util.Utils class ClientBaseSuite extends FunSuite with Matchers { @@ -138,6 +140,57 @@ class ClientBaseSuite extends FunSuite with Matchers { } } + test("check access nns empty") { + val sparkConf = new SparkConf() + sparkConf.set("spark.yarn.access.namenodes", "") + val nns = ClientBase.getNameNodesToAccess(sparkConf) + nns should be(Set()) + } + + test("check access nns unset") { + val sparkConf = new SparkConf() + val nns = ClientBase.getNameNodesToAccess(sparkConf) + nns should be(Set()) + } + + test("check access nns") { + val sparkConf = new SparkConf() + sparkConf.set("spark.yarn.access.namenodes", "hdfs://nn1:8032") + val nns = ClientBase.getNameNodesToAccess(sparkConf) + nns should be(Set(new Path("hdfs://nn1:8032"))) + } + + test("check access nns space") { + val sparkConf = new SparkConf() + sparkConf.set("spark.yarn.access.namenodes", "hdfs://nn1:8032, ") + val nns = ClientBase.getNameNodesToAccess(sparkConf) + nns should be(Set(new Path("hdfs://nn1:8032"))) + } + + test("check access two nns") { + val sparkConf = new SparkConf() + sparkConf.set("spark.yarn.access.namenodes", "hdfs://nn1:8032,hdfs://nn2:8032") + val nns = ClientBase.getNameNodesToAccess(sparkConf) + nns should be(Set(new Path("hdfs://nn1:8032"), new Path("hdfs://nn2:8032"))) + } + + test("check token renewer") { + val hadoopConf = new Configuration() + hadoopConf.set("yarn.resourcemanager.address", "myrm:8033") + hadoopConf.set("yarn.resourcemanager.principal", "yarn/myrm:8032@SPARKTEST.COM") + val renewer = ClientBase.getTokenRenewer(hadoopConf) + renewer should be ("yarn/myrm:8032@SPARKTEST.COM") + } + + test("check token renewer default") { + val hadoopConf = new Configuration() + val caught = + intercept[SparkException] { + ClientBase.getTokenRenewer(hadoopConf) + } + assert(caught.getMessage === "Can't get Master Kerberos principal for use as renewer") + } + object Fixtures { val knownDefYarnAppCP: Seq[String] = From 1c5555a23d3aa40423d658cfbf2c956ad415a6b1 Mon Sep 17 00:00:00 2001 From: Thomas Graves Date: Tue, 5 Aug 2014 12:52:52 -0500 Subject: [PATCH 151/170] SPARK-1890 and SPARK-1891- add admin and modify acls It was easier to combine these 2 jira since they touch many of the same places. This pr adds the following: - adds modify acls - adds admin acls (list of admins/users that get added to both view and modify acls) - modify Kill button on UI to take modify acls into account - changes config name of spark.ui.acls.enable to spark.acls.enable since I choose poorly in original name. We keep backwards compatibility so people can still use spark.ui.acls.enable. The acls should apply to any web ui as well as any CLI interfaces. - send view and modify acls information on to YARN so that YARN interfaces can use (yarn cli for killing applications for example). Author: Thomas Graves Closes #1196 from tgravescs/SPARK-1890 and squashes the following commits: 8292eb1 [Thomas Graves] review comments b92ec89 [Thomas Graves] remove unneeded variable from applistener 4c765f4 [Thomas Graves] Add in admin acls 72eb0ac [Thomas Graves] Add modify acls --- .../org/apache/spark/SecurityManager.scala | 107 +++++++++++++++--- .../deploy/history/FsHistoryProvider.scala | 4 +- .../scheduler/ApplicationEventListener.scala | 4 +- .../apache/spark/ui/jobs/JobProgressTab.scala | 2 +- .../apache/spark/SecurityManagerSuite.scala | 83 ++++++++++++-- docs/configuration.md | 27 ++++- docs/security.md | 7 +- .../apache/spark/deploy/yarn/ClientBase.scala | 9 +- 8 files changed, 206 insertions(+), 37 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SecurityManager.scala b/core/src/main/scala/org/apache/spark/SecurityManager.scala index 74aa441619bd2..25c2c9fc6af7c 100644 --- a/core/src/main/scala/org/apache/spark/SecurityManager.scala +++ b/core/src/main/scala/org/apache/spark/SecurityManager.scala @@ -41,10 +41,19 @@ import org.apache.spark.deploy.SparkHadoopUtil * secure the UI if it has data that other users should not be allowed to see. The javax * servlet filter specified by the user can authenticate the user and then once the user * is logged in, Spark can compare that user versus the view acls to make sure they are - * authorized to view the UI. The configs 'spark.ui.acls.enable' and 'spark.ui.view.acls' + * authorized to view the UI. The configs 'spark.acls.enable' and 'spark.ui.view.acls' * control the behavior of the acls. Note that the person who started the application * always has view access to the UI. * + * Spark has a set of modify acls (`spark.modify.acls`) that controls which users have permission + * to modify a single application. This would include things like killing the application. By + * default the person who started the application has modify access. For modify access through + * the UI, you must have a filter that does authentication in place for the modify acls to work + * properly. + * + * Spark also has a set of admin acls (`spark.admin.acls`) which is a set of users/administrators + * who always have permission to view or modify the Spark application. + * * Spark does not currently support encryption after authentication. * * At this point spark has multiple communication protocols that need to be secured and @@ -137,18 +146,32 @@ private[spark] class SecurityManager(sparkConf: SparkConf) extends Logging { private val sparkSecretLookupKey = "sparkCookie" private val authOn = sparkConf.getBoolean("spark.authenticate", false) - private var uiAclsOn = sparkConf.getBoolean("spark.ui.acls.enable", false) + // keep spark.ui.acls.enable for backwards compatibility with 1.0 + private var aclsOn = sparkConf.getOption("spark.acls.enable").getOrElse( + sparkConf.get("spark.ui.acls.enable", "false")).toBoolean + + // admin acls should be set before view or modify acls + private var adminAcls: Set[String] = + stringToSet(sparkConf.get("spark.admin.acls", "")) private var viewAcls: Set[String] = _ + + // list of users who have permission to modify the application. This should + // apply to both UI and CLI for things like killing the application. + private var modifyAcls: Set[String] = _ + // always add the current user and SPARK_USER to the viewAcls - private val defaultAclUsers = Seq[String](System.getProperty("user.name", ""), + private val defaultAclUsers = Set[String](System.getProperty("user.name", ""), Option(System.getenv("SPARK_USER")).getOrElse("")) + setViewAcls(defaultAclUsers, sparkConf.get("spark.ui.view.acls", "")) + setModifyAcls(defaultAclUsers, sparkConf.get("spark.modify.acls", "")) private val secretKey = generateSecretKey() logInfo("SecurityManager: authentication " + (if (authOn) "enabled" else "disabled") + - "; ui acls " + (if (uiAclsOn) "enabled" else "disabled") + - "; users with view permissions: " + viewAcls.toString()) + "; ui acls " + (if (aclsOn) "enabled" else "disabled") + + "; users with view permissions: " + viewAcls.toString() + + "; users with modify permissions: " + modifyAcls.toString()) // Set our own authenticator to properly negotiate user/password for HTTP connections. // This is needed by the HTTP client fetching from the HttpServer. Put here so its @@ -169,18 +192,51 @@ private[spark] class SecurityManager(sparkConf: SparkConf) extends Logging { ) } - private[spark] def setViewAcls(defaultUsers: Seq[String], allowedUsers: String) { - viewAcls = (defaultUsers ++ allowedUsers.split(',')).map(_.trim()).filter(!_.isEmpty).toSet + /** + * Split a comma separated String, filter out any empty items, and return a Set of strings + */ + private def stringToSet(list: String): Set[String] = { + list.split(',').map(_.trim).filter(!_.isEmpty).toSet + } + + /** + * Admin acls should be set before the view or modify acls. If you modify the admin + * acls you should also set the view and modify acls again to pick up the changes. + */ + def setViewAcls(defaultUsers: Set[String], allowedUsers: String) { + viewAcls = (adminAcls ++ defaultUsers ++ stringToSet(allowedUsers)) logInfo("Changing view acls to: " + viewAcls.mkString(",")) } - private[spark] def setViewAcls(defaultUser: String, allowedUsers: String) { - setViewAcls(Seq[String](defaultUser), allowedUsers) + def setViewAcls(defaultUser: String, allowedUsers: String) { + setViewAcls(Set[String](defaultUser), allowedUsers) + } + + def getViewAcls: String = viewAcls.mkString(",") + + /** + * Admin acls should be set before the view or modify acls. If you modify the admin + * acls you should also set the view and modify acls again to pick up the changes. + */ + def setModifyAcls(defaultUsers: Set[String], allowedUsers: String) { + modifyAcls = (adminAcls ++ defaultUsers ++ stringToSet(allowedUsers)) + logInfo("Changing modify acls to: " + modifyAcls.mkString(",")) + } + + def getModifyAcls: String = modifyAcls.mkString(",") + + /** + * Admin acls should be set before the view or modify acls. If you modify the admin + * acls you should also set the view and modify acls again to pick up the changes. + */ + def setAdminAcls(adminUsers: String) { + adminAcls = stringToSet(adminUsers) + logInfo("Changing admin acls to: " + adminAcls.mkString(",")) } - private[spark] def setUIAcls(aclSetting: Boolean) { - uiAclsOn = aclSetting - logInfo("Changing acls enabled to: " + uiAclsOn) + def setAcls(aclSetting: Boolean) { + aclsOn = aclSetting + logInfo("Changing acls enabled to: " + aclsOn) } /** @@ -224,22 +280,39 @@ private[spark] class SecurityManager(sparkConf: SparkConf) extends Logging { * Check to see if Acls for the UI are enabled * @return true if UI authentication is enabled, otherwise false */ - def uiAclsEnabled(): Boolean = uiAclsOn + def aclsEnabled(): Boolean = aclsOn /** * Checks the given user against the view acl list to see if they have - * authorization to view the UI. If the UI acls must are disabled - * via spark.ui.acls.enable, all users have view access. + * authorization to view the UI. If the UI acls are disabled + * via spark.acls.enable, all users have view access. If the user is null + * it is assumed authentication is off and all users have access. * * @param user to see if is authorized * @return true is the user has permission, otherwise false */ def checkUIViewPermissions(user: String): Boolean = { - logDebug("user=" + user + " uiAclsEnabled=" + uiAclsEnabled() + " viewAcls=" + + logDebug("user=" + user + " aclsEnabled=" + aclsEnabled() + " viewAcls=" + viewAcls.mkString(",")) - if (uiAclsEnabled() && (user != null) && (!viewAcls.contains(user))) false else true + if (aclsEnabled() && (user != null) && (!viewAcls.contains(user))) false else true } + /** + * Checks the given user against the modify acl list to see if they have + * authorization to modify the application. If the UI acls are disabled + * via spark.acls.enable, all users have modify access. If the user is null + * it is assumed authentication isn't turned on and all users have access. + * + * @param user to see if is authorized + * @return true is the user has permission, otherwise false + */ + def checkModifyPermissions(user: String): Boolean = { + logDebug("user=" + user + " aclsEnabled=" + aclsEnabled() + " modifyAcls=" + + modifyAcls.mkString(",")) + if (aclsEnabled() && (user != null) && (!modifyAcls.contains(user))) false else true + } + + /** * Check to see if authentication for the Spark communication protocols is enabled * @return true if authentication is enabled, otherwise false diff --git a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala index 6d2d4cef1ee46..cc06540ee0647 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala @@ -189,7 +189,9 @@ private[history] class FsHistoryProvider(conf: SparkConf) extends ApplicationHis if (ui != null) { val uiAclsEnabled = conf.getBoolean("spark.history.ui.acls.enable", false) - ui.getSecurityManager.setUIAcls(uiAclsEnabled) + ui.getSecurityManager.setAcls(uiAclsEnabled) + // make sure to set admin acls before view acls so properly picked up + ui.getSecurityManager.setAdminAcls(appListener.adminAcls) ui.getSecurityManager.setViewAcls(appListener.sparkUser, appListener.viewAcls) } (appInfo, ui) diff --git a/core/src/main/scala/org/apache/spark/scheduler/ApplicationEventListener.scala b/core/src/main/scala/org/apache/spark/scheduler/ApplicationEventListener.scala index cd5d44ad4a7e6..162158babc35b 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ApplicationEventListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ApplicationEventListener.scala @@ -29,7 +29,7 @@ private[spark] class ApplicationEventListener extends SparkListener { var startTime = -1L var endTime = -1L var viewAcls = "" - var enableViewAcls = false + var adminAcls = "" def applicationStarted = startTime != -1 @@ -55,7 +55,7 @@ private[spark] class ApplicationEventListener extends SparkListener { val environmentDetails = environmentUpdate.environmentDetails val allProperties = environmentDetails("Spark Properties").toMap viewAcls = allProperties.getOrElse("spark.ui.view.acls", "") - enableViewAcls = allProperties.getOrElse("spark.ui.acls.enable", "false").toBoolean + adminAcls = allProperties.getOrElse("spark.admin.acls", "") } } } diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressTab.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressTab.scala index 3308c8c8a3d37..8a01ec80c9dd6 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressTab.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressTab.scala @@ -41,7 +41,7 @@ private[ui] class JobProgressTab(parent: SparkUI) extends WebUITab(parent, "stag def isFairScheduler = listener.schedulingMode.exists(_ == SchedulingMode.FAIR) def handleKillRequest(request: HttpServletRequest) = { - if (killEnabled) { + if ((killEnabled) && (parent.securityManager.checkModifyPermissions(request.getRemoteUser))) { val killFlag = Option(request.getParameter("terminate")).getOrElse("false").toBoolean val stageId = Option(request.getParameter("id")).getOrElse("-1").toInt if (stageId >= 0 && killFlag && listener.activeStages.contains(stageId)) { diff --git a/core/src/test/scala/org/apache/spark/SecurityManagerSuite.scala b/core/src/test/scala/org/apache/spark/SecurityManagerSuite.scala index e39093e24d68a..fcca0867b8072 100644 --- a/core/src/test/scala/org/apache/spark/SecurityManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/SecurityManagerSuite.scala @@ -31,7 +31,7 @@ class SecurityManagerSuite extends FunSuite { conf.set("spark.ui.view.acls", "user1,user2") val securityManager = new SecurityManager(conf); assert(securityManager.isAuthenticationEnabled() === true) - assert(securityManager.uiAclsEnabled() === true) + assert(securityManager.aclsEnabled() === true) assert(securityManager.checkUIViewPermissions("user1") === true) assert(securityManager.checkUIViewPermissions("user2") === true) assert(securityManager.checkUIViewPermissions("user3") === false) @@ -41,17 +41,17 @@ class SecurityManagerSuite extends FunSuite { val conf = new SparkConf conf.set("spark.ui.view.acls", "user1,user2") val securityManager = new SecurityManager(conf); - securityManager.setUIAcls(true) - assert(securityManager.uiAclsEnabled() === true) - securityManager.setUIAcls(false) - assert(securityManager.uiAclsEnabled() === false) + securityManager.setAcls(true) + assert(securityManager.aclsEnabled() === true) + securityManager.setAcls(false) + assert(securityManager.aclsEnabled() === false) // acls are off so doesn't matter what view acls set to assert(securityManager.checkUIViewPermissions("user4") === true) - securityManager.setUIAcls(true) - assert(securityManager.uiAclsEnabled() === true) - securityManager.setViewAcls(ArrayBuffer[String]("user5"), "user6,user7") + securityManager.setAcls(true) + assert(securityManager.aclsEnabled() === true) + securityManager.setViewAcls(Set[String]("user5"), "user6,user7") assert(securityManager.checkUIViewPermissions("user1") === false) assert(securityManager.checkUIViewPermissions("user5") === true) assert(securityManager.checkUIViewPermissions("user6") === true) @@ -59,5 +59,72 @@ class SecurityManagerSuite extends FunSuite { assert(securityManager.checkUIViewPermissions("user8") === false) assert(securityManager.checkUIViewPermissions(null) === true) } + + test("set security modify acls") { + val conf = new SparkConf + conf.set("spark.modify.acls", "user1,user2") + + val securityManager = new SecurityManager(conf); + securityManager.setAcls(true) + assert(securityManager.aclsEnabled() === true) + securityManager.setAcls(false) + assert(securityManager.aclsEnabled() === false) + + // acls are off so doesn't matter what view acls set to + assert(securityManager.checkModifyPermissions("user4") === true) + + securityManager.setAcls(true) + assert(securityManager.aclsEnabled() === true) + securityManager.setModifyAcls(Set("user5"), "user6,user7") + assert(securityManager.checkModifyPermissions("user1") === false) + assert(securityManager.checkModifyPermissions("user5") === true) + assert(securityManager.checkModifyPermissions("user6") === true) + assert(securityManager.checkModifyPermissions("user7") === true) + assert(securityManager.checkModifyPermissions("user8") === false) + assert(securityManager.checkModifyPermissions(null) === true) + } + + test("set security admin acls") { + val conf = new SparkConf + conf.set("spark.admin.acls", "user1,user2") + conf.set("spark.ui.view.acls", "user3") + conf.set("spark.modify.acls", "user4") + + val securityManager = new SecurityManager(conf); + securityManager.setAcls(true) + assert(securityManager.aclsEnabled() === true) + + assert(securityManager.checkModifyPermissions("user1") === true) + assert(securityManager.checkModifyPermissions("user2") === true) + assert(securityManager.checkModifyPermissions("user4") === true) + assert(securityManager.checkModifyPermissions("user3") === false) + assert(securityManager.checkModifyPermissions("user5") === false) + assert(securityManager.checkModifyPermissions(null) === true) + assert(securityManager.checkUIViewPermissions("user1") === true) + assert(securityManager.checkUIViewPermissions("user2") === true) + assert(securityManager.checkUIViewPermissions("user3") === true) + assert(securityManager.checkUIViewPermissions("user4") === false) + assert(securityManager.checkUIViewPermissions("user5") === false) + assert(securityManager.checkUIViewPermissions(null) === true) + + securityManager.setAdminAcls("user6") + securityManager.setViewAcls(Set[String]("user8"), "user9") + securityManager.setModifyAcls(Set("user11"), "user9") + assert(securityManager.checkModifyPermissions("user6") === true) + assert(securityManager.checkModifyPermissions("user11") === true) + assert(securityManager.checkModifyPermissions("user9") === true) + assert(securityManager.checkModifyPermissions("user1") === false) + assert(securityManager.checkModifyPermissions("user4") === false) + assert(securityManager.checkModifyPermissions(null) === true) + assert(securityManager.checkUIViewPermissions("user6") === true) + assert(securityManager.checkUIViewPermissions("user8") === true) + assert(securityManager.checkUIViewPermissions("user9") === true) + assert(securityManager.checkUIViewPermissions("user1") === false) + assert(securityManager.checkUIViewPermissions("user3") === false) + assert(securityManager.checkUIViewPermissions(null) === true) + + } + + } diff --git a/docs/configuration.md b/docs/configuration.md index b3dee3f131411..25adea210cba0 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -815,13 +815,13 @@ Apart from these, the following properties are also available, and may be useful - spark.ui.acls.enable + spark.acls.enable false - Whether Spark web ui acls should are enabled. If enabled, this checks to see if the user has - access permissions to view the web ui. See spark.ui.view.acls for more details. - Also note this requires the user to be known, if the user comes across as null no checks - are done. Filters can be used to authenticate and set the user. + Whether Spark acls should are enabled. If enabled, this checks to see if the user has + access permissions to view or modify the job. Note this requires the user to be known, + so if the user comes across as null no checks are done. Filters can be used with the UI + to authenticate and set the user. @@ -832,6 +832,23 @@ Apart from these, the following properties are also available, and may be useful user that started the Spark job has view access. + + spark.modify.acls + Empty + + Comma separated list of users that have modify access to the Spark job. By default only the + user that started the Spark job has access to modify it (kill it for example). + + + + spark.admin.acls + Empty + + Comma separated list of users/administrators that have view and modify access to all Spark jobs. + This can be used if you run on a shared cluster and have a set of administrators or devs who + help debug when things work. + + #### Spark Streaming diff --git a/docs/security.md b/docs/security.md index 90ba678033b19..8312f8d017e1f 100644 --- a/docs/security.md +++ b/docs/security.md @@ -8,8 +8,11 @@ Spark currently supports authentication via a shared secret. Authentication can * For Spark on [YARN](running-on-yarn.html) deployments, configuring `spark.authenticate` to `true` will automatically handle generating and distributing the shared secret. Each application will use a unique shared secret. * For other types of Spark deployments, the Spark parameter `spark.authenticate.secret` should be configured on each of the nodes. This secret will be used by all the Master/Workers and applications. -The Spark UI can also be secured by using [javax servlet filters](http://docs.oracle.com/javaee/6/api/javax/servlet/Filter.html) via the `spark.ui.filters` setting. A user may want to secure the UI if it has data that other users should not be allowed to see. The javax servlet filter specified by the user can authenticate the user and then once the user is logged in, Spark can compare that user versus the view ACLs to make sure they are authorized to view the UI. The configs `spark.ui.acls.enable` and `spark.ui.view.acls` control the behavior of the ACLs. Note that the user who started the application always has view access to the UI. -On YARN, the Spark UI uses the standard YARN web application proxy mechanism and will authenticate via any installed Hadoop filters. +The Spark UI can also be secured by using [javax servlet filters](http://docs.oracle.com/javaee/6/api/javax/servlet/Filter.html) via the `spark.ui.filters` setting. A user may want to secure the UI if it has data that other users should not be allowed to see. The javax servlet filter specified by the user can authenticate the user and then once the user is logged in, Spark can compare that user versus the view ACLs to make sure they are authorized to view the UI. The configs `spark.acls.enable` and `spark.ui.view.acls` control the behavior of the ACLs. Note that the user who started the application always has view access to the UI. On YARN, the Spark UI uses the standard YARN web application proxy mechanism and will authenticate via any installed Hadoop filters. + +Spark also supports modify ACLs to control who has access to modify a running Spark application. This includes things like killing the application or a task. This is controlled by the configs `spark.acls.enable` and `spark.modify.acls`. Note that if you are authenticating the web UI, in order to use the kill button on the web UI it might be necessary to add the users in the modify acls to the view acls also. On YARN, the modify acls are passed in and control who has modify access via YARN interfaces. + +Spark allows for a set of administrators to be specified in the acls who always have view and modify permissions to all the applications. is controlled by the config `spark.admin.acls`. This is useful on a shared cluster where you might have administrators or support staff who help users debug applications. If your applications are using event logging, the directory where the event logs go (`spark.eventLog.dir`) should be manually created and have the proper permissions set on it. If you want those log files secured, the permissions should be set to `drwxrwxrwxt` for that directory. The owner of the directory should be the super user who is running the history server and the group permissions should be restricted to super user group. This will allow all users to write to the directory but will prevent unprivileged users from removing or renaming a file unless they own the file or directory. The event log files will be created by Spark with permissions such that only the user and group have read and write access. diff --git a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala index ed8f56ab8b75e..44e025b8f60ba 100644 --- a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala +++ b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala @@ -37,7 +37,7 @@ import org.apache.hadoop.yarn.api.protocolrecords._ import org.apache.hadoop.yarn.api.records._ import org.apache.hadoop.yarn.conf.YarnConfiguration import org.apache.hadoop.yarn.util.Records -import org.apache.spark.{SparkException, Logging, SparkConf, SparkContext} +import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkContext, SparkException} /** * The entry point (starting in Client#main() and Client#run()) for launching Spark on YARN. The @@ -405,6 +405,13 @@ trait ClientBase extends Logging { amContainer.setCommands(printableCommands) setupSecurityToken(amContainer) + + // send the acl settings into YARN to control who has access via YARN interfaces + val securityManager = new SecurityManager(sparkConf) + val acls = Map[ApplicationAccessType, String] ( + ApplicationAccessType.VIEW_APP -> securityManager.getViewAcls, + ApplicationAccessType.MODIFY_APP -> securityManager.getModifyAcls) + amContainer.setApplicationACLs(acls) amContainer } } From 6e821e3d1ae1ed23459bc7f1098510b968130152 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Tue, 5 Aug 2014 11:17:50 -0700 Subject: [PATCH 152/170] [SPARK-2860][SQL] Fix coercion of CASE WHEN. Author: Michael Armbrust Closes #1785 from marmbrus/caseNull and squashes the following commits: 126006d [Michael Armbrust] better error message 2fe357f [Michael Armbrust] Fix coercion of CASE WHEN. --- .../catalyst/analysis/HiveTypeCoercion.scala | 56 +++++++++++-------- ...ll case-0-581cdfe70091e546414b202da2cebdcb | 1 + .../sql/hive/execution/HiveQuerySuite.scala | 3 + 3 files changed, 36 insertions(+), 24 deletions(-) create mode 100644 sql/hive/src/test/resources/golden/null case-0-581cdfe70091e546414b202da2cebdcb diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index e94f2a3bea63e..15eb5982a4a91 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -49,10 +49,21 @@ trait HiveTypeCoercion { BooleanCasts :: StringToIntegralCasts :: FunctionArgumentConversion :: - CastNulls :: + CaseWhenCoercion :: Division :: Nil + trait TypeWidening { + def findTightestCommonType(t1: DataType, t2: DataType): Option[DataType] = { + // Try and find a promotion rule that contains both types in question. + val applicableConversion = + HiveTypeCoercion.allPromotions.find(p => p.contains(t1) && p.contains(t2)) + + // If found return the widest common type, otherwise None + applicableConversion.map(_.filter(t => t == t1 || t == t2).last) + } + } + /** * Applies any changes to [[AttributeReference]] data types that are made by other rules to * instances higher in the query tree. @@ -133,16 +144,7 @@ trait HiveTypeCoercion { * - LongType to FloatType * - LongType to DoubleType */ - object WidenTypes extends Rule[LogicalPlan] { - - def findTightestCommonType(t1: DataType, t2: DataType): Option[DataType] = { - // Try and find a promotion rule that contains both types in question. - val applicableConversion = - HiveTypeCoercion.allPromotions.find(p => p.contains(t1) && p.contains(t2)) - - // If found return the widest common type, otherwise None - applicableConversion.map(_.filter(t => t == t1 || t == t2).last) - } + object WidenTypes extends Rule[LogicalPlan] with TypeWidening { def apply(plan: LogicalPlan): LogicalPlan = plan transform { case u @ Union(left, right) if u.childrenResolved && !u.resolved => @@ -336,28 +338,34 @@ trait HiveTypeCoercion { } /** - * Ensures that NullType gets casted to some other types under certain circumstances. + * Coerces the type of different branches of a CASE WHEN statement to a common type. */ - object CastNulls extends Rule[LogicalPlan] { + object CaseWhenCoercion extends Rule[LogicalPlan] with TypeWidening { def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { - case cw @ CaseWhen(branches) => + case cw @ CaseWhen(branches) if !cw.resolved && !branches.exists(!_.resolved) => val valueTypes = branches.sliding(2, 2).map { - case Seq(_, value) if value.resolved => Some(value.dataType) - case Seq(elseVal) if elseVal.resolved => Some(elseVal.dataType) - case _ => None + case Seq(_, value) => value.dataType + case Seq(elseVal) => elseVal.dataType }.toSeq - if (valueTypes.distinct.size == 2 && valueTypes.exists(_ == Some(NullType))) { - val otherType = valueTypes.filterNot(_ == Some(NullType))(0).get + + logDebug(s"Input values for null casting ${valueTypes.mkString(",")}") + + if (valueTypes.distinct.size > 1) { + val commonType = valueTypes.reduce { (v1, v2) => + findTightestCommonType(v1, v2) + .getOrElse(sys.error( + s"Types in CASE WHEN must be the same or coercible to a common type: $v1 != $v2")) + } val transformedBranches = branches.sliding(2, 2).map { - case Seq(cond, value) if value.resolved && value.dataType == NullType => - Seq(cond, Cast(value, otherType)) - case Seq(elseVal) if elseVal.resolved && elseVal.dataType == NullType => - Seq(Cast(elseVal, otherType)) + case Seq(cond, value) if value.dataType != commonType => + Seq(cond, Cast(value, commonType)) + case Seq(elseVal) if elseVal.dataType != commonType => + Seq(Cast(elseVal, commonType)) case s => s }.reduce(_ ++ _) CaseWhen(transformedBranches) } else { - // It is possible to have more types due to the possibility of short-circuiting. + // Types match up. Hopefully some other rule fixes whatever is wrong with resolution. cw } } diff --git a/sql/hive/src/test/resources/golden/null case-0-581cdfe70091e546414b202da2cebdcb b/sql/hive/src/test/resources/golden/null case-0-581cdfe70091e546414b202da2cebdcb new file mode 100644 index 0000000000000..d00491fd7e5bb --- /dev/null +++ b/sql/hive/src/test/resources/golden/null case-0-581cdfe70091e546414b202da2cebdcb @@ -0,0 +1 @@ +1 diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index aa810a291231a..2f0be49b6a6d7 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -32,6 +32,9 @@ case class TestData(a: Int, b: String) */ class HiveQuerySuite extends HiveComparisonTest { + createQueryTest("null case", + "SELECT case when(true) then 1 else null end FROM src LIMIT 1") + createQueryTest("single case", """SELECT case when true then 1 else 2 end FROM src LIMIT 1""") From ac3440f4f3c4b79070ffec7db0b08ad062b4df90 Mon Sep 17 00:00:00 2001 From: "Guancheng (G.C.) Chen" Date: Tue, 5 Aug 2014 11:50:08 -0700 Subject: [PATCH 153/170] [SPARK-2859] Update url of Kryo project in related docs JIRA Issue: https://issues.apache.org/jira/browse/SPARK-2859 Kryo project has been migrated from googlecode to github, hence we need to update its URL in related docs such as tuning.md. Author: Guancheng (G.C.) Chen Closes #1782 from gchen/kryo-docs and squashes the following commits: b62543c [Guancheng (G.C.) Chen] update url of Kryo project --- docs/tuning.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/tuning.md b/docs/tuning.md index 4917c11bc1147..8fb2a0433b1a8 100644 --- a/docs/tuning.md +++ b/docs/tuning.md @@ -32,7 +32,7 @@ in your operations) and performance. It provides two serialization libraries: [`java.io.Externalizable`](http://docs.oracle.com/javase/6/docs/api/java/io/Externalizable.html). Java serialization is flexible but often quite slow, and leads to large serialized formats for many classes. -* [Kryo serialization](http://code.google.com/p/kryo/): Spark can also use +* [Kryo serialization](https://github.com/EsotericSoftware/kryo): Spark can also use the Kryo library (version 2) to serialize objects more quickly. Kryo is significantly faster and more compact than Java serialization (often as much as 10x), but does not support all `Serializable` types and requires you to *register* the classes you'll use in the program in advance @@ -68,7 +68,7 @@ conf.set("spark.kryo.registrator", "mypackage.MyRegistrator") val sc = new SparkContext(conf) {% endhighlight %} -The [Kryo documentation](http://code.google.com/p/kryo/) describes more advanced +The [Kryo documentation](https://github.com/EsotericSoftware/kryo) describes more advanced registration options, such as adding custom serialization code. If your objects are large, you may also need to increase the `spark.kryoserializer.buffer.mb` From 74f82c71b03d265a7d0c98ce196ca8c44de002e8 Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Tue, 5 Aug 2014 13:08:23 -0700 Subject: [PATCH 154/170] SPARK-2380: Support displaying accumulator values in the web UI This patch adds support for giving accumulators user-visible names and displaying accumulator values in the web UI. This allows users to create custom counters that can display in the UI. The current approach displays both the accumulator deltas caused by each task and a "current" value of the accumulator totals for each stage, which gets update as tasks finish. Currently in Spark developers have been extending the `TaskMetrics` functionality to provide custom instrumentation for RDD's. This provides a potentially nicer alternative of going through the existing accumulator framework (actually `TaskMetrics` and accumulators are on an awkward collision course as we add more features to the former). The current patch demo's how we can use the feature to provide instrumentation for RDD input sizes. The nice thing about going through accumulators is that users can read the current value of the data being tracked in their programs. This could be useful to e.g. decide to short-circuit a Spark stage depending on how things are going. ![counters](https://cloud.githubusercontent.com/assets/320616/3488815/6ee7bc34-0505-11e4-84ce-e36d9886e2cf.png) Author: Patrick Wendell Closes #1309 from pwendell/metrics and squashes the following commits: 8815308 [Patrick Wendell] Merge remote-tracking branch 'apache/master' into HEAD 93fbe0f [Patrick Wendell] Other minor fixes cc43f68 [Patrick Wendell] Updating unit tests c991b1b [Patrick Wendell] Moving some code into the Accumulators class 9a9ba3c [Patrick Wendell] More merge fixes c5ace9e [Patrick Wendell] More merge conflicts 1da15e3 [Patrick Wendell] Merge remote-tracking branch 'apache/master' into metrics 9860c55 [Patrick Wendell] Potential solution to posting listener events 0bb0e33 [Patrick Wendell] Remove "display" variable and assume display = name.isDefined 0ec4ac7 [Patrick Wendell] Java API's e95bf69 [Patrick Wendell] Stash be97261 [Patrick Wendell] Style fix 8407308 [Patrick Wendell] Removing examples in Hadoop and RDD class 64d405f [Patrick Wendell] Adding missing file 5d8b156 [Patrick Wendell] Changes based on Kay's review. 9f18bad [Patrick Wendell] Minor style changes and tests 7a63abc [Patrick Wendell] Adding Json serialization and responding to Reynold's feedback ad85076 [Patrick Wendell] Example of using named accumulators for custom RDD metrics. 0b72660 [Patrick Wendell] Initial WIP example of supporing globally named accumulators. --- .../scala/org/apache/spark/Accumulators.scala | 19 ++++-- .../scala/org/apache/spark/SparkContext.scala | 19 ++++++ .../spark/api/java/JavaSparkContext.scala | 59 ++++++++++++++++++ .../spark/scheduler/AccumulableInfo.scala | 46 ++++++++++++++ .../apache/spark/scheduler/DAGScheduler.scala | 24 ++++++- .../apache/spark/scheduler/StageInfo.scala | 4 ++ .../org/apache/spark/scheduler/TaskInfo.scala | 9 +++ .../spark/ui/jobs/JobProgressListener.scala | 10 ++- .../org/apache/spark/ui/jobs/StagePage.scala | 21 ++++++- .../org/apache/spark/ui/jobs/UIData.scala | 3 +- .../org/apache/spark/util/JsonProtocol.scala | 39 +++++++++++- .../apache/spark/util/JsonProtocolSuite.scala | 62 +++++++++++++++---- docs/programming-guide.md | 6 +- 13 files changed, 294 insertions(+), 27 deletions(-) create mode 100644 core/src/main/scala/org/apache/spark/scheduler/AccumulableInfo.scala diff --git a/core/src/main/scala/org/apache/spark/Accumulators.scala b/core/src/main/scala/org/apache/spark/Accumulators.scala index 9c55bfbb47626..12f2fe031cb1d 100644 --- a/core/src/main/scala/org/apache/spark/Accumulators.scala +++ b/core/src/main/scala/org/apache/spark/Accumulators.scala @@ -36,15 +36,21 @@ import org.apache.spark.serializer.JavaSerializer * * @param initialValue initial value of accumulator * @param param helper object defining how to add elements of type `R` and `T` + * @param name human-readable name for use in Spark's web UI * @tparam R the full accumulated data (result type) * @tparam T partial data that can be added in */ class Accumulable[R, T] ( @transient initialValue: R, - param: AccumulableParam[R, T]) + param: AccumulableParam[R, T], + val name: Option[String]) extends Serializable { - val id = Accumulators.newId + def this(@transient initialValue: R, param: AccumulableParam[R, T]) = + this(initialValue, param, None) + + val id: Long = Accumulators.newId + @transient private var value_ = initialValue // Current value on master val zero = param.zero(initialValue) // Zero value to be passed to workers private var deserialized = false @@ -219,8 +225,10 @@ GrowableAccumulableParam[R <% Growable[T] with TraversableOnce[T] with Serializa * @param param helper object defining how to add elements of type `T` * @tparam T result type */ -class Accumulator[T](@transient initialValue: T, param: AccumulatorParam[T]) - extends Accumulable[T,T](initialValue, param) +class Accumulator[T](@transient initialValue: T, param: AccumulatorParam[T], name: Option[String]) + extends Accumulable[T,T](initialValue, param, name) { + def this(initialValue: T, param: AccumulatorParam[T]) = this(initialValue, param, None) +} /** * A simpler version of [[org.apache.spark.AccumulableParam]] where the only data type you can add @@ -281,4 +289,7 @@ private object Accumulators { } } } + + def stringifyPartialValue(partialValue: Any) = "%s".format(partialValue) + def stringifyValue(value: Any) = "%s".format(value) } diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 9ba21cfcde01a..e132955f0f850 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -760,6 +760,15 @@ class SparkContext(config: SparkConf) extends Logging { def accumulator[T](initialValue: T)(implicit param: AccumulatorParam[T]) = new Accumulator(initialValue, param) + /** + * Create an [[org.apache.spark.Accumulator]] variable of a given type, with a name for display + * in the Spark UI. Tasks can "add" values to the accumulator using the `+=` method. Only the + * driver can access the accumulator's `value`. + */ + def accumulator[T](initialValue: T, name: String)(implicit param: AccumulatorParam[T]) = { + new Accumulator(initialValue, param, Some(name)) + } + /** * Create an [[org.apache.spark.Accumulable]] shared variable, to which tasks can add values * with `+=`. Only the driver can access the accumuable's `value`. @@ -769,6 +778,16 @@ class SparkContext(config: SparkConf) extends Logging { def accumulable[T, R](initialValue: T)(implicit param: AccumulableParam[T, R]) = new Accumulable(initialValue, param) + /** + * Create an [[org.apache.spark.Accumulable]] shared variable, with a name for display in the + * Spark UI. Tasks can add values to the accumuable using the `+=` operator. Only the driver can + * access the accumuable's `value`. + * @tparam T accumulator type + * @tparam R type that can be added to the accumulator + */ + def accumulable[T, R](initialValue: T, name: String)(implicit param: AccumulableParam[T, R]) = + new Accumulable(initialValue, param, Some(name)) + /** * Create an accumulator from a "mutable collection" type. * diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala index d9d1c5955ca99..e0a4815940db3 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala @@ -429,6 +429,16 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork def intAccumulator(initialValue: Int): Accumulator[java.lang.Integer] = sc.accumulator(initialValue)(IntAccumulatorParam).asInstanceOf[Accumulator[java.lang.Integer]] + /** + * Create an [[org.apache.spark.Accumulator]] integer variable, which tasks can "add" values + * to using the `add` method. Only the master can access the accumulator's `value`. + * + * This version supports naming the accumulator for display in Spark's web UI. + */ + def intAccumulator(initialValue: Int, name: String): Accumulator[java.lang.Integer] = + sc.accumulator(initialValue, name)(IntAccumulatorParam) + .asInstanceOf[Accumulator[java.lang.Integer]] + /** * Create an [[org.apache.spark.Accumulator]] double variable, which tasks can "add" values * to using the `add` method. Only the master can access the accumulator's `value`. @@ -436,12 +446,31 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork def doubleAccumulator(initialValue: Double): Accumulator[java.lang.Double] = sc.accumulator(initialValue)(DoubleAccumulatorParam).asInstanceOf[Accumulator[java.lang.Double]] + /** + * Create an [[org.apache.spark.Accumulator]] double variable, which tasks can "add" values + * to using the `add` method. Only the master can access the accumulator's `value`. + * + * This version supports naming the accumulator for display in Spark's web UI. + */ + def doubleAccumulator(initialValue: Double, name: String): Accumulator[java.lang.Double] = + sc.accumulator(initialValue, name)(DoubleAccumulatorParam) + .asInstanceOf[Accumulator[java.lang.Double]] + /** * Create an [[org.apache.spark.Accumulator]] integer variable, which tasks can "add" values * to using the `add` method. Only the master can access the accumulator's `value`. */ def accumulator(initialValue: Int): Accumulator[java.lang.Integer] = intAccumulator(initialValue) + /** + * Create an [[org.apache.spark.Accumulator]] integer variable, which tasks can "add" values + * to using the `add` method. Only the master can access the accumulator's `value`. + * + * This version supports naming the accumulator for display in Spark's web UI. + */ + def accumulator(initialValue: Int, name: String): Accumulator[java.lang.Integer] = + intAccumulator(initialValue, name) + /** * Create an [[org.apache.spark.Accumulator]] double variable, which tasks can "add" values * to using the `add` method. Only the master can access the accumulator's `value`. @@ -449,6 +478,16 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork def accumulator(initialValue: Double): Accumulator[java.lang.Double] = doubleAccumulator(initialValue) + + /** + * Create an [[org.apache.spark.Accumulator]] double variable, which tasks can "add" values + * to using the `add` method. Only the master can access the accumulator's `value`. + * + * This version supports naming the accumulator for display in Spark's web UI. + */ + def accumulator(initialValue: Double, name: String): Accumulator[java.lang.Double] = + doubleAccumulator(initialValue, name) + /** * Create an [[org.apache.spark.Accumulator]] variable of a given type, which tasks can "add" * values to using the `add` method. Only the master can access the accumulator's `value`. @@ -456,6 +495,16 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork def accumulator[T](initialValue: T, accumulatorParam: AccumulatorParam[T]): Accumulator[T] = sc.accumulator(initialValue)(accumulatorParam) + /** + * Create an [[org.apache.spark.Accumulator]] variable of a given type, which tasks can "add" + * values to using the `add` method. Only the master can access the accumulator's `value`. + * + * This version supports naming the accumulator for display in Spark's web UI. + */ + def accumulator[T](initialValue: T, name: String, accumulatorParam: AccumulatorParam[T]) + : Accumulator[T] = + sc.accumulator(initialValue, name)(accumulatorParam) + /** * Create an [[org.apache.spark.Accumulable]] shared variable of the given type, to which tasks * can "add" values with `add`. Only the master can access the accumuable's `value`. @@ -463,6 +512,16 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork def accumulable[T, R](initialValue: T, param: AccumulableParam[T, R]): Accumulable[T, R] = sc.accumulable(initialValue)(param) + /** + * Create an [[org.apache.spark.Accumulable]] shared variable of the given type, to which tasks + * can "add" values with `add`. Only the master can access the accumuable's `value`. + * + * This version supports naming the accumulator for display in Spark's web UI. + */ + def accumulable[T, R](initialValue: T, name: String, param: AccumulableParam[T, R]) + : Accumulable[T, R] = + sc.accumulable(initialValue, name)(param) + /** * Broadcast a read-only variable to the cluster, returning a * [[org.apache.spark.broadcast.Broadcast]] object for reading it in distributed functions. diff --git a/core/src/main/scala/org/apache/spark/scheduler/AccumulableInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/AccumulableInfo.scala new file mode 100644 index 0000000000000..fa83372bb4d11 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/scheduler/AccumulableInfo.scala @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler + +import org.apache.spark.annotation.DeveloperApi + +/** + * :: DeveloperApi :: + * Information about an [[org.apache.spark.Accumulable]] modified during a task or stage. + */ +@DeveloperApi +class AccumulableInfo ( + val id: Long, + val name: String, + val update: Option[String], // represents a partial update within a task + val value: String) { + + override def equals(other: Any): Boolean = other match { + case acc: AccumulableInfo => + this.id == acc.id && this.name == acc.name && + this.update == acc.update && this.value == acc.value + case _ => false + } +} + +object AccumulableInfo { + def apply(id: Long, name: String, update: Option[String], value: String) = + new AccumulableInfo(id, name, update, value) + + def apply(id: Long, name: String, value: String) = new AccumulableInfo(id, name, None, value) +} diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 9fa3a4e9c71ae..430e45ada5808 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -883,8 +883,14 @@ class DAGScheduler( val task = event.task val stageId = task.stageId val taskType = Utils.getFormattedClassName(task) - listenerBus.post(SparkListenerTaskEnd(stageId, taskType, event.reason, event.taskInfo, - event.taskMetrics)) + + // The success case is dealt with separately below, since we need to compute accumulator + // updates before posting. + if (event.reason != Success) { + listenerBus.post(SparkListenerTaskEnd(stageId, taskType, event.reason, event.taskInfo, + event.taskMetrics)) + } + if (!stageIdToStage.contains(task.stageId)) { // Skip all the actions if the stage has been cancelled. return @@ -906,12 +912,26 @@ class DAGScheduler( if (event.accumUpdates != null) { try { Accumulators.add(event.accumUpdates) + event.accumUpdates.foreach { case (id, partialValue) => + val acc = Accumulators.originals(id).asInstanceOf[Accumulable[Any, Any]] + // To avoid UI cruft, ignore cases where value wasn't updated + if (acc.name.isDefined && partialValue != acc.zero) { + val name = acc.name.get + val stringPartialValue = Accumulators.stringifyPartialValue(partialValue) + val stringValue = Accumulators.stringifyValue(acc.value) + stage.info.accumulables(id) = AccumulableInfo(id, name, stringValue) + event.taskInfo.accumulables += + AccumulableInfo(id, name, Some(stringPartialValue), stringValue) + } + } } catch { // If we see an exception during accumulator update, just log the error and move on. case e: Exception => logError(s"Failed to update accumulators for $task", e) } } + listenerBus.post(SparkListenerTaskEnd(stageId, taskType, event.reason, event.taskInfo, + event.taskMetrics)) stage.pendingTasks -= task task match { case rt: ResultTask[_, _] => diff --git a/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala index 480891550eb60..2a407e47a05bd 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala @@ -17,6 +17,8 @@ package org.apache.spark.scheduler +import scala.collection.mutable.HashMap + import org.apache.spark.annotation.DeveloperApi import org.apache.spark.storage.RDDInfo @@ -37,6 +39,8 @@ class StageInfo( var completionTime: Option[Long] = None /** If the stage failed, the reason why. */ var failureReason: Option[String] = None + /** Terminal values of accumulables updated during this stage. */ + val accumulables = HashMap[Long, AccumulableInfo]() def stageFailed(reason: String) { failureReason = Some(reason) diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala index ca0595f35143e..6fa1f2c880f7a 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala @@ -17,6 +17,8 @@ package org.apache.spark.scheduler +import scala.collection.mutable.ListBuffer + import org.apache.spark.annotation.DeveloperApi /** @@ -41,6 +43,13 @@ class TaskInfo( */ var gettingResultTime: Long = 0 + /** + * Intermediate updates to accumulables during this task. Note that it is valid for the same + * accumulable to be updated multiple times in a single task or for two accumulables with the + * same name but different IDs to exist in a task. + */ + val accumulables = ListBuffer[AccumulableInfo]() + /** * The time when the task has completed successfully (including the time to remotely fetch * results, if necessary). diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala index da2f5d3172fe2..a57a354620163 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala @@ -17,7 +17,7 @@ package org.apache.spark.ui.jobs -import scala.collection.mutable.{HashMap, ListBuffer} +import scala.collection.mutable.{HashMap, ListBuffer, Map} import org.apache.spark._ import org.apache.spark.annotation.DeveloperApi @@ -65,6 +65,10 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { new StageUIData }) + for ((id, info) <- stageCompleted.stageInfo.accumulables) { + stageData.accumulables(id) = info + } + poolToActiveStages.get(stageData.schedulingPool).foreach(_.remove(stageId)) activeStages.remove(stageId) if (stage.failureReason.isEmpty) { @@ -130,6 +134,10 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { new StageUIData }) + for (accumulableInfo <- info.accumulables) { + stageData.accumulables(accumulableInfo.id) = accumulableInfo + } + val execSummaryMap = stageData.executorSummary val execSummary = execSummaryMap.getOrElseUpdate(info.executorId, new ExecutorSummary) diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala index cab26b9e2f7d3..8bc1ba758cf77 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala @@ -20,11 +20,12 @@ package org.apache.spark.ui.jobs import java.util.Date import javax.servlet.http.HttpServletRequest -import scala.xml.Node +import scala.xml.{Node, Unparsed} import org.apache.spark.ui.{ToolTips, WebUIPage, UIUtils} import org.apache.spark.ui.jobs.UIData._ import org.apache.spark.util.{Utils, Distribution} +import org.apache.spark.scheduler.AccumulableInfo /** Page showing statistics and task list for a given stage */ private[ui] class StagePage(parent: JobProgressTab) extends WebUIPage("stage") { @@ -51,6 +52,7 @@ private[ui] class StagePage(parent: JobProgressTab) extends WebUIPage("stage") { val tasks = stageData.taskData.values.toSeq.sortBy(_.taskInfo.launchTime) val numCompleted = tasks.count(_.taskInfo.finished) + val accumulables = listener.stageIdToData(stageId).accumulables val hasInput = stageData.inputBytes > 0 val hasShuffleRead = stageData.shuffleReadBytes > 0 val hasShuffleWrite = stageData.shuffleWriteBytes > 0 @@ -95,10 +97,15 @@ private[ui] class StagePage(parent: JobProgressTab) extends WebUIPage("stage") { // scalastyle:on + val accumulableHeaders: Seq[String] = Seq("Accumulable", "Value") + def accumulableRow(acc: AccumulableInfo) = {acc.name}{acc.value} + val accumulableTable = UIUtils.listingTable(accumulableHeaders, accumulableRow, + accumulables.values.toSeq) + val taskHeaders: Seq[String] = Seq( "Index", "ID", "Attempt", "Status", "Locality Level", "Executor", - "Launch Time", "Duration", "GC Time") ++ + "Launch Time", "Duration", "GC Time", "Accumulators") ++ {if (hasInput) Seq("Input") else Nil} ++ {if (hasShuffleRead) Seq("Shuffle Read") else Nil} ++ {if (hasShuffleWrite) Seq("Write Time", "Shuffle Write") else Nil} ++ @@ -208,11 +215,16 @@ private[ui] class StagePage(parent: JobProgressTab) extends WebUIPage("stage") { Some(UIUtils.listingTable(quantileHeaders, quantileRow, listings, fixedWidth = true)) } val executorTable = new ExecutorTable(stageId, parent) + + val maybeAccumulableTable: Seq[Node] = + if (accumulables.size > 0) {

    Accumulators

    ++ accumulableTable } else Seq() + val content = summary ++

    Summary Metrics for {numCompleted} Completed Tasks

    ++
    {summaryTable.getOrElse("No tasks have reported metrics yet.")}
    ++

    Aggregated Metrics by Executor

    ++ executorTable.toNodeSeq ++ + maybeAccumulableTable ++

    Tasks

    ++ taskTable UIUtils.headerSparkPage(content, basePath, appName, "Details for Stage %d".format(stageId), @@ -279,6 +291,11 @@ private[ui] class StagePage(parent: JobProgressTab) extends WebUIPage("stage") { {if (gcTime > 0) UIUtils.formatDuration(gcTime) else ""} + + {Unparsed( + info.accumulables.map{acc => s"${acc.name}: ${acc.update.get}"}.mkString("
    ") + )} + - - Browser - Standalone Cluster Master - 8080 - Web UI - spark.master.ui.port - Jetty-based - - - Browser - Driver - 4040 - Web UI - spark.ui.port - Jetty-based - - - Browser - History Server - 18080 - Web UI - spark.history.ui.port - Jetty-based - - - Browser - Worker - 8081 - Web UI - spark.worker.ui.port - Jetty-based - - - - Application - Standalone Cluster Master - 7077 - Submit job to cluster - spark.driver.port - Akka-based. Set to "0" to choose a port randomly - - - Worker - Standalone Cluster Master - 7077 - Join cluster - spark.driver.port - Akka-based. Set to "0" to choose a port randomly - - - Application - Worker - (random) - Join cluster - SPARK_WORKER_PORT (standalone cluster) - Akka-based - - - - - Driver and other Workers - Worker - (random) - -
      -
    • File server for file and jars
    • -
    • Http Broadcast
    • -
    • Class file server (Spark Shell only)
    • -
    - - None - Jetty-based. Each of these services starts on a random port that cannot be configured - - - +Spark makes heavy use of the network, and some environments have strict requirements for using +tight firewall settings. For a complete list of ports to configure, see the +[security page](security.html#configuring-ports-for-network-security). # High Availability By default, standalone scheduling clusters are resilient to Worker failures (insofar as Spark itself is resilient to losing work by moving it to other workers). However, the scheduler uses a Master to make scheduling decisions, and this (by default) creates a single point of failure: if the Master crashes, no new applications can be created. In order to circumvent this, we have two high availability schemes, detailed below. -## Standby Masters with ZooKeeper +# Standby Masters with ZooKeeper **Overview** @@ -429,7 +347,7 @@ There's an important distinction to be made between "registering with a Master" Due to this property, new Masters can be created at any time, and the only thing you need to worry about is that _new_ applications and Workers can find it to register with in case it becomes the leader. Once registered, you're taken care of. -## Single-Node Recovery with Local File System +# Single-Node Recovery with Local File System **Overview** diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index aac621fe53938..40b588512ff08 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -330,6 +330,8 @@ object TestSettings { fork := true, javaOptions in Test += "-Dspark.test.home=" + sparkHome, javaOptions in Test += "-Dspark.testing=1", + javaOptions in Test += "-Dspark.ports.maxRetries=100", + javaOptions in Test += "-Dspark.ui.port=0", javaOptions in Test += "-Dsun.io.serialization.extendedDebugInfo=true", javaOptions in Test ++= System.getProperties.filter(_._1 startsWith "spark") .map { case (k,v) => s"-D$k=$v" }.toSeq, diff --git a/repl/src/main/scala/org/apache/spark/repl/SparkIMain.scala b/repl/src/main/scala/org/apache/spark/repl/SparkIMain.scala index f60bbb4662af1..84b57cd2dc1af 100644 --- a/repl/src/main/scala/org/apache/spark/repl/SparkIMain.scala +++ b/repl/src/main/scala/org/apache/spark/repl/SparkIMain.scala @@ -102,7 +102,8 @@ import org.apache.spark.util.Utils val virtualDirectory = new PlainFile(outputDir) // "directory" for classfiles /** Jetty server that will serve our classes to worker nodes */ - val classServer = new HttpServer(outputDir, new SecurityManager(conf)) + val classServerPort = conf.getInt("spark.replClassServer.port", 0) + val classServer = new HttpServer(outputDir, new SecurityManager(conf), classServerPort, "HTTP class server") private var currentSettings: Settings = initialSettings var printResults = true // whether to print result lines var totalSilence = false // whether to print anything