From ab9128fb7ec7ca479dc91e7cc7c775e8d071eafa Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Tue, 21 Apr 2015 00:08:18 -0700 Subject: [PATCH 001/110] [SPARK-6949] [SQL] [PySpark] Support Date/Timestamp in Column expression This PR enable auto_convert in JavaGateway, then we could register a converter for a given types, for example, date and datetime. There are two bugs related to auto_convert, see [1] and [2], we workaround it in this PR. [1] https://github.com/bartdag/py4j/issues/160 [2] https://github.com/bartdag/py4j/issues/161 cc rxin JoshRosen Author: Davies Liu Closes #5570 from davies/py4j_date and squashes the following commits: eb4fa53 [Davies Liu] fix tests in python 3 d17d634 [Davies Liu] rollback changes in mllib 2e7566d [Davies Liu] convert tuple into ArrayList ceb3779 [Davies Liu] Update rdd.py 3c373f3 [Davies Liu] support date and datetime by auto_convert cb094ff [Davies Liu] enable auto convert --- python/pyspark/context.py | 6 +----- python/pyspark/java_gateway.py | 15 ++++++++++++++- python/pyspark/rdd.py | 3 +++ python/pyspark/sql/_types.py | 27 +++++++++++++++++++++++++++ python/pyspark/sql/context.py | 13 ++++--------- python/pyspark/sql/dataframe.py | 18 ++++-------------- python/pyspark/sql/tests.py | 11 +++++++++++ python/pyspark/streaming/context.py | 11 +++-------- python/pyspark/streaming/kafka.py | 7 ++----- python/pyspark/streaming/tests.py | 6 +----- 10 files changed, 70 insertions(+), 47 deletions(-) diff --git a/python/pyspark/context.py b/python/pyspark/context.py index 6a743ac8bd600..b006120eb266d 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -23,8 +23,6 @@ from threading import Lock from tempfile import NamedTemporaryFile -from py4j.java_collections import ListConverter - from pyspark import accumulators from pyspark.accumulators import Accumulator from pyspark.broadcast import Broadcast @@ -643,7 +641,6 @@ def union(self, rdds): rdds = [x._reserialize() for x in rdds] first = rdds[0]._jrdd rest = [x._jrdd for x in rdds[1:]] - rest = ListConverter().convert(rest, self._gateway._gateway_client) return RDD(self._jsc.union(first, rest), self, rdds[0]._jrdd_deserializer) def broadcast(self, value): @@ -846,13 +843,12 @@ def runJob(self, rdd, partitionFunc, partitions=None, allowLocal=False): """ if partitions is None: partitions = range(rdd._jrdd.partitions().size()) - javaPartitions = ListConverter().convert(partitions, self._gateway._gateway_client) # Implementation note: This is implemented as a mapPartitions followed # by runJob() in order to avoid having to pass a Python lambda into # SparkContext#runJob. 
mappedRDD = rdd.mapPartitions(partitionFunc) - port = self._jvm.PythonRDD.runJob(self._jsc.sc(), mappedRDD._jrdd, javaPartitions, + port = self._jvm.PythonRDD.runJob(self._jsc.sc(), mappedRDD._jrdd, partitions, allowLocal) return list(_load_from_socket(port, mappedRDD._jrdd_deserializer)) diff --git a/python/pyspark/java_gateway.py b/python/pyspark/java_gateway.py index 45bc38f7e61f8..3cee4ea6e3a35 100644 --- a/python/pyspark/java_gateway.py +++ b/python/pyspark/java_gateway.py @@ -17,17 +17,30 @@ import atexit import os +import sys import select import signal import shlex import socket import platform from subprocess import Popen, PIPE + +if sys.version >= '3': + xrange = range + from py4j.java_gateway import java_import, JavaGateway, GatewayClient +from py4j.java_collections import ListConverter from pyspark.serializers import read_int +# patching ListConverter, or it will convert bytearray into Java ArrayList +def can_convert_list(self, obj): + return isinstance(obj, (list, tuple, xrange)) + +ListConverter.can_convert = can_convert_list + + def launch_gateway(): if "PYSPARK_GATEWAY_PORT" in os.environ: gateway_port = int(os.environ["PYSPARK_GATEWAY_PORT"]) @@ -92,7 +105,7 @@ def killChild(): atexit.register(killChild) # Connect to the gateway - gateway = JavaGateway(GatewayClient(port=gateway_port), auto_convert=False) + gateway = JavaGateway(GatewayClient(port=gateway_port), auto_convert=True) # Import the classes used by PySpark java_import(gateway.jvm, "org.apache.spark.SparkConf") diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index d9cdbb666f92a..d254deb527d10 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -2267,6 +2267,9 @@ def _prepare_for_python_RDD(sc, command, obj=None): # The broadcast will have same life cycle as created PythonRDD broadcast = sc.broadcast(pickled_command) pickled_command = ser.dumps(broadcast) + # There is a bug in py4j.java_gateway.JavaClass with auto_convert + # https://github.com/bartdag/py4j/issues/161 + # TODO: use auto_convert once py4j fix the bug broadcast_vars = ListConverter().convert( [x._jbroadcast for x in sc._pickled_broadcast_vars], sc._gateway._gateway_client) diff --git a/python/pyspark/sql/_types.py b/python/pyspark/sql/_types.py index 110d1152fbdb6..95fb91ad43457 100644 --- a/python/pyspark/sql/_types.py +++ b/python/pyspark/sql/_types.py @@ -17,6 +17,7 @@ import sys import decimal +import time import datetime import keyword import warnings @@ -30,6 +31,9 @@ long = int unicode = str +from py4j.protocol import register_input_converter +from py4j.java_gateway import JavaClass + __all__ = [ "DataType", "NullType", "StringType", "BinaryType", "BooleanType", "DateType", "TimestampType", "DecimalType", "DoubleType", "FloatType", "ByteType", "IntegerType", @@ -1237,6 +1241,29 @@ def __repr__(self): return "" % ", ".join(self) +class DateConverter(object): + def can_convert(self, obj): + return isinstance(obj, datetime.date) + + def convert(self, obj, gateway_client): + Date = JavaClass("java.sql.Date", gateway_client) + return Date.valueOf(obj.strftime("%Y-%m-%d")) + + +class DatetimeConverter(object): + def can_convert(self, obj): + return isinstance(obj, datetime.datetime) + + def convert(self, obj, gateway_client): + Timestamp = JavaClass("java.sql.Timestamp", gateway_client) + return Timestamp(int(time.mktime(obj.timetuple())) * 1000 + obj.microsecond // 1000) + + +# datetime is a subclass of date, we should register DatetimeConverter first +register_input_converter(DatetimeConverter()) 
+register_input_converter(DateConverter()) + + def _test(): import doctest from pyspark.context import SparkContext diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py index acf3c114548c0..f6f107ca32d2f 100644 --- a/python/pyspark/sql/context.py +++ b/python/pyspark/sql/context.py @@ -25,7 +25,6 @@ from itertools import imap as map from py4j.protocol import Py4JError -from py4j.java_collections import MapConverter from pyspark.rdd import RDD, _prepare_for_python_RDD, ignore_unicode_prefix from pyspark.serializers import AutoBatchedSerializer, PickleSerializer @@ -442,15 +441,13 @@ def load(self, path=None, source=None, schema=None, **options): if source is None: source = self.getConf("spark.sql.sources.default", "org.apache.spark.sql.parquet") - joptions = MapConverter().convert(options, - self._sc._gateway._gateway_client) if schema is None: - df = self._ssql_ctx.load(source, joptions) + df = self._ssql_ctx.load(source, options) else: if not isinstance(schema, StructType): raise TypeError("schema should be StructType") scala_datatype = self._ssql_ctx.parseDataType(schema.json()) - df = self._ssql_ctx.load(source, scala_datatype, joptions) + df = self._ssql_ctx.load(source, scala_datatype, options) return DataFrame(df, self) def createExternalTable(self, tableName, path=None, source=None, @@ -471,16 +468,14 @@ def createExternalTable(self, tableName, path=None, source=None, if source is None: source = self.getConf("spark.sql.sources.default", "org.apache.spark.sql.parquet") - joptions = MapConverter().convert(options, - self._sc._gateway._gateway_client) if schema is None: - df = self._ssql_ctx.createExternalTable(tableName, source, joptions) + df = self._ssql_ctx.createExternalTable(tableName, source, options) else: if not isinstance(schema, StructType): raise TypeError("schema should be StructType") scala_datatype = self._ssql_ctx.parseDataType(schema.json()) df = self._ssql_ctx.createExternalTable(tableName, source, scala_datatype, - joptions) + options) return DataFrame(df, self) @ignore_unicode_prefix diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 75c181c0c7f5e..ca9bf8efb945c 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -25,8 +25,6 @@ else: from itertools import imap as map -from py4j.java_collections import ListConverter, MapConverter - from pyspark.context import SparkContext from pyspark.rdd import RDD, _load_from_socket, ignore_unicode_prefix from pyspark.serializers import BatchedSerializer, PickleSerializer, UTF8Deserializer @@ -186,9 +184,7 @@ def saveAsTable(self, tableName, source=None, mode="error", **options): source = self.sql_ctx.getConf("spark.sql.sources.default", "org.apache.spark.sql.parquet") jmode = self._java_save_mode(mode) - joptions = MapConverter().convert(options, - self.sql_ctx._sc._gateway._gateway_client) - self._jdf.saveAsTable(tableName, source, jmode, joptions) + self._jdf.saveAsTable(tableName, source, jmode, options) def save(self, path=None, source=None, mode="error", **options): """Saves the contents of the :class:`DataFrame` to a data source. 
@@ -211,9 +207,7 @@ def save(self, path=None, source=None, mode="error", **options): source = self.sql_ctx.getConf("spark.sql.sources.default", "org.apache.spark.sql.parquet") jmode = self._java_save_mode(mode) - joptions = MapConverter().convert(options, - self._sc._gateway._gateway_client) - self._jdf.save(source, jmode, joptions) + self._jdf.save(source, jmode, options) @property def schema(self): @@ -819,7 +813,6 @@ def fillna(self, value, subset=None): value = float(value) if isinstance(value, dict): - value = MapConverter().convert(value, self.sql_ctx._sc._gateway._gateway_client) return DataFrame(self._jdf.na().fill(value), self.sql_ctx) elif subset is None: return DataFrame(self._jdf.na().fill(value), self.sql_ctx) @@ -932,9 +925,7 @@ def agg(self, *exprs): """ assert exprs, "exprs should not be empty" if len(exprs) == 1 and isinstance(exprs[0], dict): - jmap = MapConverter().convert(exprs[0], - self.sql_ctx._sc._gateway._gateway_client) - jdf = self._jdf.agg(jmap) + jdf = self._jdf.agg(exprs[0]) else: # Columns assert all(isinstance(c, Column) for c in exprs), "all exprs should be Column" @@ -1040,8 +1031,7 @@ def _to_seq(sc, cols, converter=None): """ if converter: cols = [converter(c) for c in cols] - jcols = ListConverter().convert(cols, sc._gateway._gateway_client) - return sc._jvm.PythonUtils.toSeq(jcols) + return sc._jvm.PythonUtils.toSeq(cols) def _unary_op(name, doc="unary operator"): diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index aa3aa1d164d9f..23e84283679e1 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -26,6 +26,7 @@ import tempfile import pickle import functools +import datetime import py4j @@ -464,6 +465,16 @@ def test_infer_long_type(self): self.assertEqual(_infer_type(2**61), LongType()) self.assertEqual(_infer_type(2**71), LongType()) + def test_filter_with_datetime(self): + time = datetime.datetime(2015, 4, 17, 23, 1, 2, 3000) + date = time.date() + row = Row(date=date, time=time) + df = self.sqlCtx.createDataFrame([row]) + self.assertEqual(1, df.filter(df.date == date).count()) + self.assertEqual(1, df.filter(df.time == time).count()) + self.assertEqual(0, df.filter(df.date > date).count()) + self.assertEqual(0, df.filter(df.time > time).count()) + def test_dropna(self): schema = StructType([ StructField("name", StringType(), True), diff --git a/python/pyspark/streaming/context.py b/python/pyspark/streaming/context.py index 4590c58839266..ac5ba69e8dbbb 100644 --- a/python/pyspark/streaming/context.py +++ b/python/pyspark/streaming/context.py @@ -20,7 +20,6 @@ import os import sys -from py4j.java_collections import ListConverter from py4j.java_gateway import java_import, JavaObject from pyspark import RDD, SparkConf @@ -305,9 +304,7 @@ def queueStream(self, rdds, oneAtATime=True, default=None): rdds = [self._sc.parallelize(input) for input in rdds] self._check_serializers(rdds) - jrdds = ListConverter().convert([r._jrdd for r in rdds], - SparkContext._gateway._gateway_client) - queue = self._jvm.PythonDStream.toRDDQueue(jrdds) + queue = self._jvm.PythonDStream.toRDDQueue([r._jrdd for r in rdds]) if default: default = default._reserialize(rdds[0]._jrdd_deserializer) jdstream = self._jssc.queueStream(queue, oneAtATime, default._jrdd) @@ -322,8 +319,7 @@ def transform(self, dstreams, transformFunc): the transform function parameter will be the same as the order of corresponding DStreams in the list. 
""" - jdstreams = ListConverter().convert([d._jdstream for d in dstreams], - SparkContext._gateway._gateway_client) + jdstreams = [d._jdstream for d in dstreams] # change the final serializer to sc.serializer func = TransformFunction(self._sc, lambda t, *rdds: transformFunc(rdds).map(lambda x: x), @@ -346,6 +342,5 @@ def union(self, *dstreams): if len(set(s._slideDuration for s in dstreams)) > 1: raise ValueError("All DStreams should have same slide duration") first = dstreams[0] - jrest = ListConverter().convert([d._jdstream for d in dstreams[1:]], - SparkContext._gateway._gateway_client) + jrest = [d._jdstream for d in dstreams[1:]] return DStream(self._jssc.union(first._jdstream, jrest), self, first._jrdd_deserializer) diff --git a/python/pyspark/streaming/kafka.py b/python/pyspark/streaming/kafka.py index 7a7b6e1d9a527..8d610d6569b4a 100644 --- a/python/pyspark/streaming/kafka.py +++ b/python/pyspark/streaming/kafka.py @@ -15,8 +15,7 @@ # limitations under the License. # -from py4j.java_collections import MapConverter -from py4j.java_gateway import java_import, Py4JError, Py4JJavaError +from py4j.java_gateway import Py4JJavaError from pyspark.storagelevel import StorageLevel from pyspark.serializers import PairDeserializer, NoOpSerializer @@ -57,8 +56,6 @@ def createStream(ssc, zkQuorum, groupId, topics, kafkaParams={}, }) if not isinstance(topics, dict): raise TypeError("topics should be dict") - jtopics = MapConverter().convert(topics, ssc.sparkContext._gateway._gateway_client) - jparam = MapConverter().convert(kafkaParams, ssc.sparkContext._gateway._gateway_client) jlevel = ssc._sc._getJavaStorageLevel(storageLevel) try: @@ -66,7 +63,7 @@ def createStream(ssc, zkQuorum, groupId, topics, kafkaParams={}, helperClass = ssc._jvm.java.lang.Thread.currentThread().getContextClassLoader()\ .loadClass("org.apache.spark.streaming.kafka.KafkaUtilsPythonHelper") helper = helperClass.newInstance() - jstream = helper.createStream(ssc._jssc, jparam, jtopics, jlevel) + jstream = helper.createStream(ssc._jssc, kafkaParams, topics, jlevel) except Py4JJavaError as e: # TODO: use --jar once it also work on driver if 'ClassNotFoundException' in str(e.java_exception): diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py index 06d22154373bc..33f958a601f3a 100644 --- a/python/pyspark/streaming/tests.py +++ b/python/pyspark/streaming/tests.py @@ -24,8 +24,6 @@ import struct from functools import reduce -from py4j.java_collections import MapConverter - from pyspark.context import SparkConf, SparkContext, RDD from pyspark.streaming.context import StreamingContext from pyspark.streaming.kafka import KafkaUtils @@ -581,11 +579,9 @@ def test_kafka_stream(self): """Test the Python Kafka stream API.""" topic = "topic1" sendData = {"a": 3, "b": 5, "c": 10} - jSendData = MapConverter().convert(sendData, - self.ssc.sparkContext._gateway._gateway_client) self._kafkaTestUtils.createTopic(topic) - self._kafkaTestUtils.sendMessages(topic, jSendData) + self._kafkaTestUtils.sendMessages(topic, sendData) stream = KafkaUtils.createStream(self.ssc, self._kafkaTestUtils.zkAddress(), "test-streaming-consumer", {topic: 1}, From 1f2f723b0daacbb9e70ec42c19a84470af1d7765 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Tue, 21 Apr 2015 00:14:16 -0700 Subject: [PATCH 002/110] [SPARK-5990] [MLLIB] Model import/export for IsotonicRegression Model import/export for IsotonicRegression Author: Yanbo Liang Closes #5270 from yanboliang/spark-5990 and squashes the following commits: 872028d [Yanbo Liang] fix 
code style f80ec1b [Yanbo Liang] address comments 49600cc [Yanbo Liang] address comments 429ff7d [Yanbo Liang] store each interval as a record 2b2f5a1 [Yanbo Liang] Model import/export for IsotonicRegression --- .../mllib/regression/IsotonicRegression.scala | 78 ++++++++++++++++++- .../regression/IsotonicRegressionSuite.scala | 21 +++++ 2 files changed, 98 insertions(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala index cb70852e3cc8d..1d7617046b6c7 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala @@ -23,9 +23,16 @@ import java.util.Arrays.binarySearch import scala.collection.mutable.ArrayBuffer +import org.json4s._ +import org.json4s.JsonDSL._ +import org.json4s.jackson.JsonMethods._ + import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.{JavaDoubleRDD, JavaRDD} +import org.apache.spark.mllib.util.{Loader, Saveable} import org.apache.spark.rdd.RDD +import org.apache.spark.SparkContext +import org.apache.spark.sql.{DataFrame, SQLContext} /** * :: Experimental :: @@ -42,7 +49,7 @@ import org.apache.spark.rdd.RDD class IsotonicRegressionModel ( val boundaries: Array[Double], val predictions: Array[Double], - val isotonic: Boolean) extends Serializable { + val isotonic: Boolean) extends Serializable with Saveable { private val predictionOrd = if (isotonic) Ordering[Double] else Ordering[Double].reverse @@ -124,6 +131,75 @@ class IsotonicRegressionModel ( predictions(foundIndex) } } + + override def save(sc: SparkContext, path: String): Unit = { + IsotonicRegressionModel.SaveLoadV1_0.save(sc, path, boundaries, predictions, isotonic) + } + + override protected def formatVersion: String = "1.0" +} + +object IsotonicRegressionModel extends Loader[IsotonicRegressionModel] { + + import org.apache.spark.mllib.util.Loader._ + + private object SaveLoadV1_0 { + + def thisFormatVersion: String = "1.0" + + /** Hard-code class name string in case it changes in the future */ + def thisClassName: String = "org.apache.spark.mllib.regression.IsotonicRegressionModel" + + /** Model data for model import/export */ + case class Data(boundary: Double, prediction: Double) + + def save( + sc: SparkContext, + path: String, + boundaries: Array[Double], + predictions: Array[Double], + isotonic: Boolean): Unit = { + val sqlContext = new SQLContext(sc) + + val metadata = compact(render( + ("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~ + ("isotonic" -> isotonic))) + sc.parallelize(Seq(metadata), 1).saveAsTextFile(metadataPath(path)) + + sqlContext.createDataFrame( + boundaries.toSeq.zip(predictions).map { case (b, p) => Data(b, p) } + ).saveAsParquetFile(dataPath(path)) + } + + def load(sc: SparkContext, path: String): (Array[Double], Array[Double]) = { + val sqlContext = new SQLContext(sc) + val dataRDD = sqlContext.parquetFile(dataPath(path)) + + checkSchema[Data](dataRDD.schema) + val dataArray = dataRDD.select("boundary", "prediction").collect() + val (boundaries, predictions) = dataArray.map { x => + (x.getDouble(0), x.getDouble(1)) + }.toList.sortBy(_._1).unzip + (boundaries.toArray, predictions.toArray) + } + } + + override def load(sc: SparkContext, path: String): IsotonicRegressionModel = { + implicit val formats = DefaultFormats + val (loadedClassName, version, metadata) = loadMetadata(sc, path) + 
val isotonic = (metadata \ "isotonic").extract[Boolean] + val classNameV1_0 = SaveLoadV1_0.thisClassName + (loadedClassName, version) match { + case (className, "1.0") if className == classNameV1_0 => + val (boundaries, predictions) = SaveLoadV1_0.load(sc, path) + new IsotonicRegressionModel(boundaries, predictions, isotonic) + case _ => throw new Exception( + s"IsotonicRegressionModel.load did not recognize model with (className, format version):" + + s"($loadedClassName, $version). Supported:\n" + + s" ($classNameV1_0, 1.0)" + ) + } + } } /** diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/IsotonicRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/IsotonicRegressionSuite.scala index 7ef45248281e9..8e12340bbd9d6 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/regression/IsotonicRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/IsotonicRegressionSuite.scala @@ -21,6 +21,7 @@ import org.scalatest.{Matchers, FunSuite} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ +import org.apache.spark.util.Utils class IsotonicRegressionSuite extends FunSuite with MLlibTestSparkContext with Matchers { @@ -73,6 +74,26 @@ class IsotonicRegressionSuite extends FunSuite with MLlibTestSparkContext with M assert(model.isotonic) } + test("model save/load") { + val boundaries = Array(0.0, 1.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0) + val predictions = Array(1, 2, 2, 6, 16.5, 16.5, 17.0, 18.0) + val model = new IsotonicRegressionModel(boundaries, predictions, true) + + val tempDir = Utils.createTempDir() + val path = tempDir.toURI.toString + + // Save model, load it back, and compare. + try { + model.save(sc, path) + val sameModel = IsotonicRegressionModel.load(sc, path) + assert(model.boundaries === sameModel.boundaries) + assert(model.predictions === sameModel.predictions) + assert(model.isotonic === model.isotonic) + } finally { + Utils.deleteRecursively(tempDir) + } + } + test("isotonic regression with size 0") { val model = runIsotonicRegression(Seq(), true) From 5fea3e5c36450658d8b767dd3c06dac2251a0e0c Mon Sep 17 00:00:00 2001 From: David McGuire Date: Tue, 21 Apr 2015 07:21:10 -0400 Subject: [PATCH 003/110] [SPARK-6985][streaming] Receiver maxRate over 1000 causes a StackOverflowError A simple truncation in integer division (on rates over 1000 messages / second) causes the existing implementation to sleep for 0 milliseconds, then call itself recursively; this causes what is essentially an infinite recursion, since the base case of the calculated amount of time having elapsed can't be reached before available stack space is exhausted. A fix to this truncation error is included in this patch. However, even with the defect patched, the accuracy of the existing implementation is abysmal (the error bounds of the original test were effectively [-30%, +10%], although this fact was obscured by hard-coded error margins); as such, when the error bounds were tightened down to [-5%, +5%], the existing implementation failed to meet the new, tightened, requirements. Therefore, an industry-vetted solution (from Guava) was used to get the adapted tests to pass. 
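To make the failure mode concrete, here is a minimal sketch of the arithmetic and of the Guava-based replacement (the values and the RateLimiterSketch object are invented for illustration; they are not part of the patch):

    import com.google.common.util.concurrent.{RateLimiter => GuavaRateLimiter}

    object RateLimiterSketch extends App {
      val desiredRate = 1001                 // messages per second, just over the 1000/s threshold
      val messagesWrittenSinceSync = 1L

      // Old computation: 1 * 1000 / 1001 == 0 under integer division, so the previous
      // implementation slept for 0 ms and recursed into waitToPush() until the stack overflowed.
      val targetTimeInMillis = messagesWrittenSinceSync * 1000 / desiredRate
      println(s"sleep computed by the old code: $targetTimeInMillis ms")   // prints 0

      // The patch delegates pacing to Guava's RateLimiter instead:
      val limiter = GuavaRateLimiter.create(desiredRate)
      limiter.acquire()   // blocks just long enough to hold the average rate at desiredRate
    }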
Author: David McGuire Closes #5559 from dmcguire81/master and squashes the following commits: d29d2e0 [David McGuire] Back out to +/-5% error margins, for flexibility in timing 8be6934 [David McGuire] Fix spacing per code review 90e98b9 [David McGuire] Address scalastyle errors 29011bd [David McGuire] Further ratchet down the error margins b33b796 [David McGuire] Eliminate dependency on even distribution by BlockGenerator 8f2934b [David McGuire] Remove arbitrary thread timing / cooperation code 70ee310 [David McGuire] Use Thread.yield(), since Thread.sleep(0) is system-dependent 82ee46d [David McGuire] Replace guard clause with nested conditional 2794717 [David McGuire] Replace the RateLimiter with the Guava implementation 38f3ca8 [David McGuire] Ratchet down the error rate to +/- 5%; tests fail 24b1bc0 [David McGuire] Fix truncation in integer division causing infinite recursion d6e1079 [David McGuire] Stack overflow error in RateLimiter on rates over 1000/s --- .../streaming/receiver/RateLimiter.scala | 33 +++---------------- .../spark/streaming/ReceiverSuite.scala | 29 +++++++++------- 2 files changed, 21 insertions(+), 41 deletions(-) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/RateLimiter.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/RateLimiter.scala index e4f6ba626ebbf..97db9ded83367 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/RateLimiter.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/RateLimiter.scala @@ -18,7 +18,7 @@ package org.apache.spark.streaming.receiver import org.apache.spark.{Logging, SparkConf} -import java.util.concurrent.TimeUnit._ +import com.google.common.util.concurrent.{RateLimiter=>GuavaRateLimiter} /** Provides waitToPush() method to limit the rate at which receivers consume data. * @@ -33,37 +33,12 @@ import java.util.concurrent.TimeUnit._ */ private[receiver] abstract class RateLimiter(conf: SparkConf) extends Logging { - private var lastSyncTime = System.nanoTime - private var messagesWrittenSinceSync = 0L private val desiredRate = conf.getInt("spark.streaming.receiver.maxRate", 0) - private val SYNC_INTERVAL = NANOSECONDS.convert(10, SECONDS) + private lazy val rateLimiter = GuavaRateLimiter.create(desiredRate) def waitToPush() { - if( desiredRate <= 0 ) { - return - } - val now = System.nanoTime - val elapsedNanosecs = math.max(now - lastSyncTime, 1) - val rate = messagesWrittenSinceSync.toDouble * 1000000000 / elapsedNanosecs - if (rate < desiredRate) { - // It's okay to write; just update some variables and return - messagesWrittenSinceSync += 1 - if (now > lastSyncTime + SYNC_INTERVAL) { - // Sync interval has passed; let's resync - lastSyncTime = now - messagesWrittenSinceSync = 1 - } - } else { - // Calculate how much time we should sleep to bring ourselves to the desired rate. 
- val targetTimeInMillis = messagesWrittenSinceSync * 1000 / desiredRate - val elapsedTimeInMillis = elapsedNanosecs / 1000000 - val sleepTimeInMillis = targetTimeInMillis - elapsedTimeInMillis - if (sleepTimeInMillis > 0) { - logTrace("Natural rate is " + rate + " per second but desired rate is " + - desiredRate + ", sleeping for " + sleepTimeInMillis + " ms to compensate.") - Thread.sleep(sleepTimeInMillis) - } - waitToPush() + if (desiredRate > 0) { + rateLimiter.acquire() } } } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala index 91261a9db7360..e7aee6eadbfc7 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala @@ -158,7 +158,7 @@ class ReceiverSuite extends TestSuiteBase with Timeouts with Serializable { test("block generator throttling") { val blockGeneratorListener = new FakeBlockGeneratorListener val blockIntervalMs = 100 - val maxRate = 100 + val maxRate = 1001 val conf = new SparkConf().set("spark.streaming.blockInterval", s"${blockIntervalMs}ms"). set("spark.streaming.receiver.maxRate", maxRate.toString) val blockGenerator = new BlockGenerator(blockGeneratorListener, 1, conf) @@ -176,7 +176,6 @@ class ReceiverSuite extends TestSuiteBase with Timeouts with Serializable { blockGenerator.addData(count) generatedData += count count += 1 - Thread.sleep(1) } blockGenerator.stop() @@ -185,25 +184,31 @@ class ReceiverSuite extends TestSuiteBase with Timeouts with Serializable { assert(blockGeneratorListener.arrayBuffers.size > 0, "No blocks received") assert(recordedData.toSet === generatedData.toSet, "Received data not same") - // recordedData size should be close to the expected rate - val minExpectedMessages = expectedMessages - 3 - val maxExpectedMessages = expectedMessages + 1 + // recordedData size should be close to the expected rate; use an error margin proportional to + // the value, so that rate changes don't cause a brittle test + val minExpectedMessages = expectedMessages - 0.05 * expectedMessages + val maxExpectedMessages = expectedMessages + 0.05 * expectedMessages val numMessages = recordedData.size assert( numMessages >= minExpectedMessages && numMessages <= maxExpectedMessages, s"#records received = $numMessages, not between $minExpectedMessages and $maxExpectedMessages" ) - val minExpectedMessagesPerBlock = expectedMessagesPerBlock - 3 - val maxExpectedMessagesPerBlock = expectedMessagesPerBlock + 1 + // XXX Checking every block would require an even distribution of messages across blocks, + // which throttling code does not control. Therefore, test against the average. 
+ val minExpectedMessagesPerBlock = expectedMessagesPerBlock - 0.05 * expectedMessagesPerBlock + val maxExpectedMessagesPerBlock = expectedMessagesPerBlock + 0.05 * expectedMessagesPerBlock val receivedBlockSizes = recordedBlocks.map { _.size }.mkString(",") + + // the first and last block may be incomplete, so we slice them out + val validBlocks = recordedBlocks.drop(1).dropRight(1) + val averageBlockSize = validBlocks.map(block => block.size).sum / validBlocks.size + assert( - // the first and last block may be incomplete, so we slice them out - recordedBlocks.drop(1).dropRight(1).forall { block => - block.size >= minExpectedMessagesPerBlock && block.size <= maxExpectedMessagesPerBlock - }, + averageBlockSize >= minExpectedMessagesPerBlock && + averageBlockSize <= maxExpectedMessagesPerBlock, s"# records in received blocks = [$receivedBlockSizes], not between " + - s"$minExpectedMessagesPerBlock and $maxExpectedMessagesPerBlock" + s"$minExpectedMessagesPerBlock and $maxExpectedMessagesPerBlock, on average" ) } From c035c0f2d72f2a303b86fe0037ec43d756fff060 Mon Sep 17 00:00:00 2001 From: Kay Ousterhout Date: Tue, 21 Apr 2015 11:01:18 -0700 Subject: [PATCH 004/110] [SPARK-5360] [SPARK-6606] Eliminate duplicate objects in serialized CoGroupedRDD CoGroupPartition, part of CoGroupedRDD, includes references to each RDD that the CoGroupedRDD narrowly depends on, and a reference to the ShuffleHandle. The partition is serialized separately from the RDD, so when the RDD and partition arrive on the worker, the references in the partition and in the RDD no longer point to the same object. This is a relatively minor performance issue (the closure can be 2x larger than it needs to be because the rdds and partitions are serialized twice; see numbers below) but is more annoying as a developer issue (this is where I ran into): if any state is stored in the RDD or ShuffleHandle on the worker side, subtle bugs can appear due to the fact that the references to the RDD / ShuffleHandle in the RDD and in the partition point to separate objects. I'm not sure if this is enough of a potential future problem to fix this old and central part of the code, so hoping to get input from others here. I did some simple experiments to see how much this effects closure size. For this example: $ val a = sc.parallelize(1 to 10).map((_, 1)) $ val b = sc.parallelize(1 to 2).map(x => (x, 2*x)) $ a.cogroup(b).collect() the closure was 1902 bytes with current Spark, and 1129 bytes after my change. The difference comes from eliminating duplicate serialization of the shuffle handle. For this example: $ val sortedA = a.sortByKey() $ val sortedB = b.sortByKey() $ sortedA.cogroup(sortedB).collect() the closure was 3491 bytes with current Spark, and 1333 bytes after my change. Here, the difference comes from eliminating duplicate serialization of the two RDDs for the narrow dependencies. The ShuffleHandle includes the ShuffleDependency, so this difference will get larger if a ShuffleDependency includes a serializer, a key ordering, or an aggregator (all set to None by default). It would also get bigger for a big RDD -- although I can't think of any examples where the RDD object gets large. The difference is not affected by the size of the function the user specifies, which (based on my understanding) is typically the source of large task closures. 
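To illustrate the mechanism outside of Spark, here is a self-contained sketch (the Handle and *Part classes are stand-ins invented for this example, not Spark classes): when two holders of the same reference are serialized separately, the referenced object is duplicated and the copies no longer alias each other, while marking the partition-side field @transient avoids both the duplication and the aliasing problem.

    import java.io._

    // Stand-ins for ShuffleHandle and CoGroupPartition, reduced to their serialization behaviour.
    class Handle extends Serializable
    class EagerPart(val h: Handle) extends Serializable            // old partition shape
    class LazyPart(@transient val h: Handle) extends Serializable  // patched partition shape

    object ClosureDuplicationSketch {
      def roundTrip[T](x: T): T = {
        val buf = new ByteArrayOutputStream()
        new ObjectOutputStream(buf).writeObject(x)
        new ObjectInputStream(new ByteArrayInputStream(buf.toByteArray))
          .readObject().asInstanceOf[T]
      }

      def main(args: Array[String]): Unit = {
        val handle = new Handle

        // Serialized separately (as Spark serializes an RDD and its partitions), the old
        // layout yields two distinct copies of the handle that no longer point to the same object:
        val fromRdd = roundTrip(handle)
        val fromPartition = roundTrip(new EagerPart(handle)).h
        assert(!(fromRdd eq fromPartition))

        // With @transient, the partition drops its copy; only the RDD-side reference survives.
        assert(roundTrip(new LazyPart(handle)).h == null)
      }
    }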
Author: Kay Ousterhout Closes #4145 from kayousterhout/SPARK-5360 and squashes the following commits: 85156c3 [Kay Ousterhout] Better comment the narrowDeps parameter cff0209 [Kay Ousterhout] Fixed spelling issue 658e1af [Kay Ousterhout] [SPARK-5360] Eliminate duplicate objects in serialized CoGroupedRDD --- .../org/apache/spark/rdd/CoGroupedRDD.scala | 43 +++++++++++-------- .../org/apache/spark/rdd/SubtractedRDD.scala | 30 +++++++------ 2 files changed, 44 insertions(+), 29 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala index 7021a339e879b..658e8c8b89318 100644 --- a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala @@ -29,15 +29,16 @@ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.util.collection.{ExternalAppendOnlyMap, AppendOnlyMap, CompactBuffer} import org.apache.spark.util.Utils import org.apache.spark.serializer.Serializer -import org.apache.spark.shuffle.ShuffleHandle - -private[spark] sealed trait CoGroupSplitDep extends Serializable +/** The references to rdd and splitIndex are transient because redundant information is stored + * in the CoGroupedRDD object. Because CoGroupedRDD is serialized separately from + * CoGroupPartition, if rdd and splitIndex aren't transient, they'll be included twice in the + * task closure. */ private[spark] case class NarrowCoGroupSplitDep( - rdd: RDD[_], - splitIndex: Int, + @transient rdd: RDD[_], + @transient splitIndex: Int, var split: Partition - ) extends CoGroupSplitDep { + ) extends Serializable { @throws(classOf[IOException]) private def writeObject(oos: ObjectOutputStream): Unit = Utils.tryOrIOException { @@ -47,9 +48,16 @@ private[spark] case class NarrowCoGroupSplitDep( } } -private[spark] case class ShuffleCoGroupSplitDep(handle: ShuffleHandle) extends CoGroupSplitDep - -private[spark] class CoGroupPartition(idx: Int, val deps: Array[CoGroupSplitDep]) +/** + * Stores information about the narrow dependencies used by a CoGroupedRdd. + * + * @param narrowDeps maps to the dependencies variable in the parent RDD: for each one to one + * dependency in dependencies, narrowDeps has a NarrowCoGroupSplitDep (describing + * the partition for that dependency) at the corresponding index. The size of + * narrowDeps should always be equal to the number of parents. 
+ */ +private[spark] class CoGroupPartition( + idx: Int, val narrowDeps: Array[Option[NarrowCoGroupSplitDep]]) extends Partition with Serializable { override val index: Int = idx override def hashCode(): Int = idx @@ -105,9 +113,9 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part: // Assume each RDD contributed a single dependency, and get it dependencies(j) match { case s: ShuffleDependency[_, _, _] => - new ShuffleCoGroupSplitDep(s.shuffleHandle) + None case _ => - new NarrowCoGroupSplitDep(rdd, i, rdd.partitions(i)) + Some(new NarrowCoGroupSplitDep(rdd, i, rdd.partitions(i))) } }.toArray) } @@ -120,20 +128,21 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part: val sparkConf = SparkEnv.get.conf val externalSorting = sparkConf.getBoolean("spark.shuffle.spill", true) val split = s.asInstanceOf[CoGroupPartition] - val numRdds = split.deps.length + val numRdds = dependencies.length // A list of (rdd iterator, dependency number) pairs val rddIterators = new ArrayBuffer[(Iterator[Product2[K, Any]], Int)] - for ((dep, depNum) <- split.deps.zipWithIndex) dep match { - case NarrowCoGroupSplitDep(rdd, _, itsSplit) => + for ((dep, depNum) <- dependencies.zipWithIndex) dep match { + case oneToOneDependency: OneToOneDependency[Product2[K, Any]] => + val dependencyPartition = split.narrowDeps(depNum).get.split // Read them from the parent - val it = rdd.iterator(itsSplit, context).asInstanceOf[Iterator[Product2[K, Any]]] + val it = oneToOneDependency.rdd.iterator(dependencyPartition, context) rddIterators += ((it, depNum)) - case ShuffleCoGroupSplitDep(handle) => + case shuffleDependency: ShuffleDependency[_, _, _] => // Read map outputs of shuffle val it = SparkEnv.get.shuffleManager - .getReader(handle, split.index, split.index + 1, context) + .getReader(shuffleDependency.shuffleHandle, split.index, split.index + 1, context) .read() rddIterators += ((it, depNum)) } diff --git a/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala index e9d745588ee9a..633aeba3bbae6 100644 --- a/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala @@ -81,9 +81,9 @@ private[spark] class SubtractedRDD[K: ClassTag, V: ClassTag, W: ClassTag]( array(i) = new CoGroupPartition(i, Seq(rdd1, rdd2).zipWithIndex.map { case (rdd, j) => dependencies(j) match { case s: ShuffleDependency[_, _, _] => - new ShuffleCoGroupSplitDep(s.shuffleHandle) + None case _ => - new NarrowCoGroupSplitDep(rdd, i, rdd.partitions(i)) + Some(new NarrowCoGroupSplitDep(rdd, i, rdd.partitions(i))) } }.toArray) } @@ -105,20 +105,26 @@ private[spark] class SubtractedRDD[K: ClassTag, V: ClassTag, W: ClassTag]( seq } } - def integrate(dep: CoGroupSplitDep, op: Product2[K, V] => Unit): Unit = dep match { - case NarrowCoGroupSplitDep(rdd, _, itsSplit) => - rdd.iterator(itsSplit, context).asInstanceOf[Iterator[Product2[K, V]]].foreach(op) + def integrate(depNum: Int, op: Product2[K, V] => Unit) = { + dependencies(depNum) match { + case oneToOneDependency: OneToOneDependency[_] => + val dependencyPartition = partition.narrowDeps(depNum).get.split + oneToOneDependency.rdd.iterator(dependencyPartition, context) + .asInstanceOf[Iterator[Product2[K, V]]].foreach(op) - case ShuffleCoGroupSplitDep(handle) => - val iter = SparkEnv.get.shuffleManager - .getReader(handle, partition.index, partition.index + 1, context) - .read() - iter.foreach(op) + case 
shuffleDependency: ShuffleDependency[_, _, _] => + val iter = SparkEnv.get.shuffleManager + .getReader( + shuffleDependency.shuffleHandle, partition.index, partition.index + 1, context) + .read() + iter.foreach(op) + } } + // the first dep is rdd1; add all values to the map - integrate(partition.deps(0), t => getSeq(t._1) += t._2) + integrate(0, t => getSeq(t._1) += t._2) // the second dep is rdd2; remove all of its keys - integrate(partition.deps(1), t => map.remove(t._1)) + integrate(1, t => map.remove(t._1)) map.iterator.map { t => t._2.iterator.map { (t._1, _) } }.flatten } From c25ca7c5a1f2a4f88f40b0c5cdbfa927c186cfa8 Mon Sep 17 00:00:00 2001 From: emres Date: Tue, 21 Apr 2015 16:39:56 -0400 Subject: [PATCH 005/110] SPARK-3276 Added a new configuration spark.streaming.minRememberDuration SPARK-3276 Added a new configuration parameter ``spark.streaming.minRememberDuration``, with a default value of 1 minute. So that when a Spark Streaming application is started, an arbitrary number of minutes can be taken as threshold for remembering. Author: emres Closes #5438 from emres/SPARK-3276 and squashes the following commits: 766f938 [emres] SPARK-3276 Switched to using newly added getTimeAsSeconds method. affee1d [emres] SPARK-3276 Changed the property name and variable name for minRememberDuration c9d58ca [emres] SPARK-3276 Minor code re-formatting. 1c53ba9 [emres] SPARK-3276 Started to use ssc.conf rather than ssc.sparkContext.getConf, and also getLong method directly. bfe0acb [emres] SPARK-3276 Moved the minRememberDurationMin to the class daccc82 [emres] SPARK-3276 Changed the property name to reflect the unit of value and reduced number of fields. 43cc1ce [emres] SPARK-3276 Added a new configuration parameter spark.streaming.minRemember duration, with a default value of 1 minute. --- .../streaming/dstream/FileInputDStream.scala | 30 +++++++++++-------- 1 file changed, 17 insertions(+), 13 deletions(-) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala index 66d519171fd76..eca69f00188e4 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala @@ -26,7 +26,7 @@ import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileSystem, Path, PathFilter} import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat} -import org.apache.spark.SerializableWritable +import org.apache.spark.{SparkConf, SerializableWritable} import org.apache.spark.rdd.{RDD, UnionRDD} import org.apache.spark.streaming._ import org.apache.spark.util.{TimeStampedHashMap, Utils} @@ -63,7 +63,7 @@ import org.apache.spark.util.{TimeStampedHashMap, Utils} * the streaming app. * - If a file is to be visible in the directory listings, it must be visible within a certain * duration of the mod time of the file. This duration is the "remember window", which is set to - * 1 minute (see `FileInputDStream.MIN_REMEMBER_DURATION`). Otherwise, the file will never be + * 1 minute (see `FileInputDStream.minRememberDuration`). Otherwise, the file will never be * selected as the mod time will be less than the ignore threshold when it becomes visible. * - Once a file is visible, the mod time cannot change. If it does due to appends, then the * processing semantics are undefined. 
@@ -80,6 +80,15 @@ class FileInputDStream[K, V, F <: NewInputFormat[K,V]]( private val serializableConfOpt = conf.map(new SerializableWritable(_)) + /** + * Minimum duration of remembering the information of selected files. Defaults to 60 seconds. + * + * Files with mod times older than this "window" of remembering will be ignored. So if new + * files are visible within this window, then the file will get selected in the next batch. + */ + private val minRememberDurationS = + Seconds(ssc.conf.getTimeAsSeconds("spark.streaming.minRememberDuration", "60s")) + // This is a def so that it works during checkpoint recovery: private def clock = ssc.scheduler.clock @@ -95,7 +104,8 @@ class FileInputDStream[K, V, F <: NewInputFormat[K,V]]( * This would allow us to filter away not-too-old files which have already been recently * selected and processed. */ - private val numBatchesToRemember = FileInputDStream.calculateNumBatchesToRemember(slideDuration) + private val numBatchesToRemember = FileInputDStream + .calculateNumBatchesToRemember(slideDuration, minRememberDurationS) private val durationToRemember = slideDuration * numBatchesToRemember remember(durationToRemember) @@ -330,20 +340,14 @@ class FileInputDStream[K, V, F <: NewInputFormat[K,V]]( private[streaming] object FileInputDStream { - /** - * Minimum duration of remembering the information of selected files. Files with mod times - * older than this "window" of remembering will be ignored. So if new files are visible - * within this window, then the file will get selected in the next batch. - */ - private val MIN_REMEMBER_DURATION = Minutes(1) - def defaultFilter(path: Path): Boolean = !path.getName().startsWith(".") /** * Calculate the number of last batches to remember, such that all the files selected in - * at least last MIN_REMEMBER_DURATION duration can be remembered. + * at least last minRememberDurationS duration can be remembered. */ - def calculateNumBatchesToRemember(batchDuration: Duration): Int = { - math.ceil(MIN_REMEMBER_DURATION.milliseconds.toDouble / batchDuration.milliseconds).toInt + def calculateNumBatchesToRemember(batchDuration: Duration, + minRememberDurationS: Duration): Int = { + math.ceil(minRememberDurationS.milliseconds.toDouble / batchDuration.milliseconds).toInt } } From 45c47fa4176ea75886a58f5d73c44afcb29aa629 Mon Sep 17 00:00:00 2001 From: MechCoder Date: Tue, 21 Apr 2015 14:36:50 -0700 Subject: [PATCH 006/110] [SPARK-6845] [MLlib] [PySpark] Add isTranposed flag to DenseMatrix Since sparse matrices now support a isTransposed flag for row major data, DenseMatrices should do the same. 
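For illustration only (this snippet is not part of the patch; it uses the Scala-side constructor that the updated serializer calls, with the same values as the new Python test), the flag changes the storage order but not the logical matrix:

    import org.apache.spark.mllib.linalg.DenseMatrix

    object IsTransposedSketch {
      def main(args: Array[String]): Unit = {
        // The same logical 3 x 2 matrix [[0, 4], [1, 6], [3, 9]], stored two ways.
        val columnMajor = new DenseMatrix(3, 2, Array(0.0, 1.0, 3.0, 4.0, 6.0, 9.0))
        val rowMajor = new DenseMatrix(3, 2, Array(0.0, 4.0, 1.0, 6.0, 3.0, 9.0), isTransposed = true)

        // Element lookup accounts for the layout, so the two agree on every (i, j).
        assert(columnMajor(1, 1) == 6.0 && rowMajor(1, 1) == 6.0)
      }
    }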
Author: MechCoder Closes #5455 from MechCoder/spark-6845 and squashes the following commits: 525c370 [MechCoder] minor 004a37f [MechCoder] Cast boolean to int 151f3b6 [MechCoder] [WIP] Add isTransposed to pickle DenseMatrix cc0b90a [MechCoder] [SPARK-6845] Add isTranposed flag to DenseMatrix --- .../mllib/api/python/PythonMLLibAPI.scala | 13 +++-- python/pyspark/mllib/linalg.py | 49 +++++++++++++------ python/pyspark/mllib/tests.py | 16 ++++++ 3 files changed, 58 insertions(+), 20 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala index f976d2f97b043..6237b64c8f984 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala @@ -985,8 +985,10 @@ private[spark] object SerDe extends Serializable { val m: DenseMatrix = obj.asInstanceOf[DenseMatrix] val bytes = new Array[Byte](8 * m.values.size) val order = ByteOrder.nativeOrder() + val isTransposed = if (m.isTransposed) 1 else 0 ByteBuffer.wrap(bytes).order(order).asDoubleBuffer().put(m.values) + out.write(Opcodes.MARK) out.write(Opcodes.BININT) out.write(PickleUtils.integer_to_bytes(m.numRows)) out.write(Opcodes.BININT) @@ -994,19 +996,22 @@ private[spark] object SerDe extends Serializable { out.write(Opcodes.BINSTRING) out.write(PickleUtils.integer_to_bytes(bytes.length)) out.write(bytes) - out.write(Opcodes.TUPLE3) + out.write(Opcodes.BININT) + out.write(PickleUtils.integer_to_bytes(isTransposed)) + out.write(Opcodes.TUPLE) } def construct(args: Array[Object]): Object = { - if (args.length != 3) { - throw new PickleException("should be 3") + if (args.length != 4) { + throw new PickleException("should be 4") } val bytes = getBytes(args(2)) val n = bytes.length / 8 val values = new Array[Double](n) val order = ByteOrder.nativeOrder() ByteBuffer.wrap(bytes).order(order).asDoubleBuffer().get(values) - new DenseMatrix(args(0).asInstanceOf[Int], args(1).asInstanceOf[Int], values) + val isTransposed = args(3).asInstanceOf[Int] == 1 + new DenseMatrix(args(0).asInstanceOf[Int], args(1).asInstanceOf[Int], values, isTransposed) } } diff --git a/python/pyspark/mllib/linalg.py b/python/pyspark/mllib/linalg.py index ec8c879ea9389..cc9a4cf8ba170 100644 --- a/python/pyspark/mllib/linalg.py +++ b/python/pyspark/mllib/linalg.py @@ -638,9 +638,10 @@ class Matrix(object): Represents a local matrix. """ - def __init__(self, numRows, numCols): + def __init__(self, numRows, numCols, isTransposed=False): self.numRows = numRows self.numCols = numCols + self.isTransposed = isTransposed def toArray(self): """ @@ -662,14 +663,16 @@ class DenseMatrix(Matrix): """ Column-major dense matrix. 
""" - def __init__(self, numRows, numCols, values): - Matrix.__init__(self, numRows, numCols) + def __init__(self, numRows, numCols, values, isTransposed=False): + Matrix.__init__(self, numRows, numCols, isTransposed) values = self._convert_to_array(values, np.float64) assert len(values) == numRows * numCols self.values = values def __reduce__(self): - return DenseMatrix, (self.numRows, self.numCols, self.values.tostring()) + return DenseMatrix, ( + self.numRows, self.numCols, self.values.tostring(), + int(self.isTransposed)) def toArray(self): """ @@ -680,15 +683,23 @@ def toArray(self): array([[ 0., 2.], [ 1., 3.]]) """ - return self.values.reshape((self.numRows, self.numCols), order='F') + if self.isTransposed: + return np.asfortranarray( + self.values.reshape((self.numRows, self.numCols))) + else: + return self.values.reshape((self.numRows, self.numCols), order='F') def toSparse(self): """Convert to SparseMatrix""" - indices = np.nonzero(self.values)[0] + if self.isTransposed: + values = np.ravel(self.toArray(), order='F') + else: + values = self.values + indices = np.nonzero(values)[0] colCounts = np.bincount(indices // self.numRows) colPtrs = np.cumsum(np.hstack( (0, colCounts, np.zeros(self.numCols - colCounts.size)))) - values = self.values[indices] + values = values[indices] rowIndices = indices % self.numRows return SparseMatrix(self.numRows, self.numCols, colPtrs, rowIndices, values) @@ -701,21 +712,28 @@ def __getitem__(self, indices): if j >= self.numCols or j < 0: raise ValueError("Column index %d is out of range [0, %d)" % (j, self.numCols)) - return self.values[i + j * self.numRows] + + if self.isTransposed: + return self.values[i * self.numCols + j] + else: + return self.values[i + j * self.numRows] def __eq__(self, other): - return (isinstance(other, DenseMatrix) and - self.numRows == other.numRows and - self.numCols == other.numCols and - all(self.values == other.values)) + if (not isinstance(other, DenseMatrix) or + self.numRows != other.numRows or + self.numCols != other.numCols): + return False + + self_values = np.ravel(self.toArray(), order='F') + other_values = np.ravel(other.toArray(), order='F') + return all(self_values == other_values) class SparseMatrix(Matrix): """Sparse Matrix stored in CSC format.""" def __init__(self, numRows, numCols, colPtrs, rowIndices, values, isTransposed=False): - Matrix.__init__(self, numRows, numCols) - self.isTransposed = isTransposed + Matrix.__init__(self, numRows, numCols, isTransposed) self.colPtrs = self._convert_to_array(colPtrs, np.int32) self.rowIndices = self._convert_to_array(rowIndices, np.int32) self.values = self._convert_to_array(values, np.float64) @@ -777,8 +795,7 @@ def toArray(self): return A def toDense(self): - densevals = np.reshape( - self.toArray(), (self.numRows * self.numCols), order='F') + densevals = np.ravel(self.toArray(), order='F') return DenseMatrix(self.numRows, self.numCols, densevals) # TODO: More efficient implementation: diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py index 849c88341a967..8f89e2cee0592 100644 --- a/python/pyspark/mllib/tests.py +++ b/python/pyspark/mllib/tests.py @@ -195,6 +195,22 @@ def test_sparse_matrix(self): self.assertEquals(expected[i][j], sm1t[i, j]) self.assertTrue(array_equal(sm1t.toArray(), expected)) + def test_dense_matrix_is_transposed(self): + mat1 = DenseMatrix(3, 2, [0, 4, 1, 6, 3, 9], isTransposed=True) + mat = DenseMatrix(3, 2, [0, 1, 3, 4, 6, 9]) + self.assertEquals(mat1, mat) + + expected = [[0, 4], [1, 6], [3, 9]] + for i in 
range(3): + for j in range(2): + self.assertEquals(mat1[i, j], expected[i][j]) + self.assertTrue(array_equal(mat1.toArray(), expected)) + + sm = mat1.toSparse() + self.assertTrue(array_equal(sm.rowIndices, [1, 2, 0, 1, 2])) + self.assertTrue(array_equal(sm.colPtrs, [0, 2, 5])) + self.assertTrue(array_equal(sm.values, [1, 3, 4, 6, 9])) + class ListTests(PySparkTestCase): From 04bf34e34f22e3d7e972fe755251774fc6a6d52e Mon Sep 17 00:00:00 2001 From: Prashant Sharma Date: Tue, 21 Apr 2015 14:43:46 -0700 Subject: [PATCH 007/110] [SPARK-7011] Build(compilation) fails with scala 2.11 option, because a protected[sql] type is accessed in ml package. [This](https://github.com/apache/spark/blob/master/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala#L58) is where it is used and fails compilations at. Author: Prashant Sharma Closes #5593 from ScrapCodes/SPARK-7011/build-fix and squashes the following commits: e6d57a3 [Prashant Sharma] [SPARK-7011] Build fails with scala 2.11 option, because a protected[sql] type is accessed in ml package. --- .../src/main/scala/org/apache/spark/sql/types/dataTypes.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala index c6fb22c26bd3c..a108413497829 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala @@ -299,7 +299,7 @@ class NullType private() extends DataType { case object NullType extends NullType -protected[sql] object NativeType { +protected[spark] object NativeType { val all = Seq( IntegerType, BooleanType, LongType, DoubleType, FloatType, ShortType, ByteType, StringType) @@ -327,7 +327,7 @@ protected[sql] object PrimitiveType { } } -protected[sql] abstract class NativeType extends DataType { +protected[spark] abstract class NativeType extends DataType { private[sql] type JvmType @transient private[sql] val tag: TypeTag[JvmType] private[sql] val ordering: Ordering[JvmType] From 2e8c6ca47df14681c1110f0736234ce76a3eca9b Mon Sep 17 00:00:00 2001 From: vidmantas zemleris Date: Tue, 21 Apr 2015 14:47:09 -0700 Subject: [PATCH 008/110] [SPARK-6994] Allow to fetch field values by name in sql.Row It looked weird that up to now there was no way in Spark's Scala API to access fields of `DataFrame/sql.Row` by name, only by their index. This tries to solve this issue. 
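A short sketch of the added accessors (field names are illustrative; this mirrors the RowTest added by this patch):

    import org.apache.spark.sql.Row
    import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
    import org.apache.spark.sql.types._

    object RowByNameSketch {
      def main(args: Array[String]): Unit = {
        val schema = StructType(Seq(
          StructField("name", StringType),
          StructField("age", IntegerType)))
        val row: Row = new GenericRowWithSchema(Array("alice", 30), schema)

        // New in this patch: field access by name instead of by position.
        assert(row.getAs[String]("name") == "alice")
        assert(row.fieldIndex("age") == 1)
        assert(row.getValuesMap[Any](Seq("name", "age")) == Map("name" -> "alice", "age" -> 30))
      }
    }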
Author: vidmantas zemleris Closes #5573 from vidma/features/row-with-named-fields and squashes the following commits: 6145ae3 [vidmantas zemleris] [SPARK-6994][SQL] Allow to fetch field values by name on Row 9564ebb [vidmantas zemleris] [SPARK-6994][SQL] Add fieldIndex to schema (StructType) --- .../main/scala/org/apache/spark/sql/Row.scala | 32 +++++++++ .../spark/sql/catalyst/expressions/rows.scala | 2 + .../apache/spark/sql/types/dataTypes.scala | 9 +++ .../scala/org/apache/spark/sql/RowTest.scala | 71 +++++++++++++++++++ .../spark/sql/types/DataTypeSuite.scala | 13 ++++ .../scala/org/apache/spark/sql/RowSuite.scala | 10 +++ 6 files changed, 137 insertions(+) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/RowTest.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala index ac8a782976465..4190b7ffe1c8f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala @@ -306,6 +306,38 @@ trait Row extends Serializable { */ def getAs[T](i: Int): T = apply(i).asInstanceOf[T] + /** + * Returns the value of a given fieldName. + * + * @throws UnsupportedOperationException when schema is not defined. + * @throws IllegalArgumentException when fieldName do not exist. + * @throws ClassCastException when data type does not match. + */ + def getAs[T](fieldName: String): T = getAs[T](fieldIndex(fieldName)) + + /** + * Returns the index of a given field name. + * + * @throws UnsupportedOperationException when schema is not defined. + * @throws IllegalArgumentException when fieldName do not exist. + */ + def fieldIndex(name: String): Int = { + throw new UnsupportedOperationException("fieldIndex on a Row without schema is undefined.") + } + + /** + * Returns a Map(name -> value) for the requested fieldNames + * + * @throws UnsupportedOperationException when schema is not defined. + * @throws IllegalArgumentException when fieldName do not exist. + * @throws ClassCastException when data type does not match. + */ + def getValuesMap[T](fieldNames: Seq[String]): Map[String, T] = { + fieldNames.map { name => + name -> getAs[T](name) + }.toMap + } + override def toString(): String = s"[${this.mkString(",")}]" /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala index b6ec7d3417ef8..981373477a4bc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala @@ -181,6 +181,8 @@ class GenericRowWithSchema(values: Array[Any], override val schema: StructType) /** No-arg constructor for serialization. 
*/ protected def this() = this(null, null) + + override def fieldIndex(name: String): Int = schema.fieldIndex(name) } class GenericMutableRow(v: Array[Any]) extends GenericRow(v) with MutableRow { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala index a108413497829..7cd7bd1914c95 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala @@ -1025,6 +1025,7 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru private lazy val fieldNamesSet: Set[String] = fieldNames.toSet private lazy val nameToField: Map[String, StructField] = fields.map(f => f.name -> f).toMap + private lazy val nameToIndex: Map[String, Int] = fieldNames.zipWithIndex.toMap /** * Extracts a [[StructField]] of the given name. If the [[StructType]] object does not @@ -1049,6 +1050,14 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru StructType(fields.filter(f => names.contains(f.name))) } + /** + * Returns index of a given field + */ + def fieldIndex(name: String): Int = { + nameToIndex.getOrElse(name, + throw new IllegalArgumentException(s"""Field "$name" does not exist.""")) + } + protected[sql] def toAttributes: Seq[AttributeReference] = map(f => AttributeReference(f.name, f.dataType, f.nullable, f.metadata)()) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/RowTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/RowTest.scala new file mode 100644 index 0000000000000..bbb9739e9cc76 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/RowTest.scala @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql + +import org.apache.spark.sql.catalyst.expressions.{GenericRow, GenericRowWithSchema} +import org.apache.spark.sql.types._ +import org.scalatest.{Matchers, FunSpec} + +class RowTest extends FunSpec with Matchers { + + val schema = StructType( + StructField("col1", StringType) :: + StructField("col2", StringType) :: + StructField("col3", IntegerType) :: Nil) + val values = Array("value1", "value2", 1) + + val sampleRow: Row = new GenericRowWithSchema(values, schema) + val noSchemaRow: Row = new GenericRow(values) + + describe("Row (without schema)") { + it("throws an exception when accessing by fieldName") { + intercept[UnsupportedOperationException] { + noSchemaRow.fieldIndex("col1") + } + intercept[UnsupportedOperationException] { + noSchemaRow.getAs("col1") + } + } + } + + describe("Row (with schema)") { + it("fieldIndex(name) returns field index") { + sampleRow.fieldIndex("col1") shouldBe 0 + sampleRow.fieldIndex("col3") shouldBe 2 + } + + it("getAs[T] retrieves a value by fieldname") { + sampleRow.getAs[String]("col1") shouldBe "value1" + sampleRow.getAs[Int]("col3") shouldBe 1 + } + + it("Accessing non existent field throws an exception") { + intercept[IllegalArgumentException] { + sampleRow.getAs[String]("non_existent") + } + } + + it("getValuesMap() retrieves values of multiple fields as a Map(field -> value)") { + val expected = Map( + "col1" -> "value1", + "col2" -> "value2" + ) + sampleRow.getValuesMap(List("col1", "col2")) shouldBe expected + } + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala index a1341ea13d810..d797510f36685 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala @@ -56,6 +56,19 @@ class DataTypeSuite extends FunSuite { } } + test("extract field index from a StructType") { + val struct = StructType( + StructField("a", LongType) :: + StructField("b", FloatType) :: Nil) + + assert(struct.fieldIndex("a") === 0) + assert(struct.fieldIndex("b") === 1) + + intercept[IllegalArgumentException] { + struct.fieldIndex("non_existent") + } + } + def checkDataTypeJsonRepr(dataType: DataType): Unit = { test(s"JSON - $dataType") { assert(DataType.fromJson(dataType.json) === dataType) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala index bf6cf1321a056..fb3ba4bc1b908 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala @@ -62,4 +62,14 @@ class RowSuite extends FunSuite { val de = instance.deserialize(ser).asInstanceOf[Row] assert(de === row) } + + test("get values by field name on Row created via .toDF") { + val row = Seq((1, Seq(1))).toDF("a", "b").first() + assert(row.getAs[Int]("a") === 1) + assert(row.getAs[Seq[Int]]("b") === Seq(1)) + + intercept[IllegalArgumentException]{ + row.getAs[Int]("c") + } + } } From 03fd92167107f1d061c1a7ef216468b508546ac7 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 21 Apr 2015 14:48:02 -0700 Subject: [PATCH 009/110] [SQL][minor] make it more clear that we only need to re-throw GetField exception for UnresolvedAttribute For `GetField` outside `UnresolvedAttribute`, we will throw exception in `Analyzer`. 
Author: Wenchen Fan Closes #5588 from cloud-fan/tmp and squashes the following commits: 7ac74d2 [Wenchen Fan] small refactor --- .../spark/sql/catalyst/analysis/CheckAnalysis.scala | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 1155dac28fc78..a986dd5387c38 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -46,12 +46,11 @@ trait CheckAnalysis { operator transformExpressionsUp { case a: Attribute if !a.resolved => if (operator.childrenResolved) { - val nameParts = a match { - case UnresolvedAttribute(nameParts) => nameParts - case _ => Seq(a.name) + a match { + case UnresolvedAttribute(nameParts) => + // Throw errors for specific problems with get field. + operator.resolveChildren(nameParts, resolver, throwErrors = true) } - // Throw errors for specific problems with get field. - operator.resolveChildren(nameParts, resolver, throwErrors = true) } val from = operator.inputSet.map(_.name).mkString(", ") From 6265cba00f6141575b4be825735d77d4cea500ab Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Tue, 21 Apr 2015 14:48:42 -0700 Subject: [PATCH 010/110] [SPARK-6969][SQL] Refresh the cached table when REFRESH TABLE is used https://issues.apache.org/jira/browse/SPARK-6969 Author: Yin Huai Closes #5583 from yhuai/refreshTableRefreshDataCache and squashes the following commits: 1e5142b [Yin Huai] Add todo. 92b2498 [Yin Huai] Minor updates. 367df92 [Yin Huai] Recache data in the command of REFRESH TABLE. --- .../org/apache/spark/sql/sources/ddl.scala | 17 +++++++ .../spark/sql/hive/CachedTableSuite.scala | 50 ++++++++++++++++++- 2 files changed, 66 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala index 2e861b84b7133..78d494184e759 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala @@ -347,7 +347,24 @@ private[sql] case class RefreshTable(databaseName: String, tableName: String) extends RunnableCommand { override def run(sqlContext: SQLContext): Seq[Row] = { + // Refresh the given table's metadata first. sqlContext.catalog.refreshTable(databaseName, tableName) + + // If this table is cached as a InMemoryColumnarRelation, drop the original + // cached version and make the new version cached lazily. + val logicalPlan = sqlContext.catalog.lookupRelation(Seq(databaseName, tableName)) + // Use lookupCachedData directly since RefreshTable also takes databaseName. + val isCached = sqlContext.cacheManager.lookupCachedData(logicalPlan).nonEmpty + if (isCached) { + // Create a data frame to represent the table. + // TODO: Use uncacheTable once it supports database name. + val df = DataFrame(sqlContext, logicalPlan) + // Uncache the logicalPlan. + sqlContext.cacheManager.tryUncacheQuery(df, blocking = true) + // Cache it again. 
+ sqlContext.cacheManager.cacheQuery(df, Some(tableName)) + } + Seq.empty[Row] } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala index c188264072a84..fc6c3c35037b0 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala @@ -17,11 +17,14 @@ package org.apache.spark.sql.hive +import java.io.File + import org.apache.spark.sql.columnar.{InMemoryColumnarTableScan, InMemoryRelation} import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.hive.test.TestHive._ -import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest} +import org.apache.spark.sql.{SaveMode, AnalysisException, DataFrame, QueryTest} import org.apache.spark.storage.RDDBlockId +import org.apache.spark.util.Utils class CachedTableSuite extends QueryTest { @@ -155,4 +158,49 @@ class CachedTableSuite extends QueryTest { assertCached(table("udfTest")) uncacheTable("udfTest") } + + test("REFRESH TABLE also needs to recache the data (data source tables)") { + val tempPath: File = Utils.createTempDir() + tempPath.delete() + table("src").save(tempPath.toString, "parquet", SaveMode.Overwrite) + sql("DROP TABLE IF EXISTS refreshTable") + createExternalTable("refreshTable", tempPath.toString, "parquet") + checkAnswer( + table("refreshTable"), + table("src").collect()) + // Cache the table. + sql("CACHE TABLE refreshTable") + assertCached(table("refreshTable")) + // Append new data. + table("src").save(tempPath.toString, "parquet", SaveMode.Append) + // We are still using the old data. + assertCached(table("refreshTable")) + checkAnswer( + table("refreshTable"), + table("src").collect()) + // Refresh the table. + sql("REFRESH TABLE refreshTable") + // We are using the new data. + assertCached(table("refreshTable")) + checkAnswer( + table("refreshTable"), + table("src").unionAll(table("src")).collect()) + + // Drop the table and create it again. + sql("DROP TABLE refreshTable") + createExternalTable("refreshTable", tempPath.toString, "parquet") + // It is not cached. + assert(!isCached("refreshTable"), "refreshTable should not be cached.") + // Refresh the table. REFRESH TABLE command should not make a uncached + // table cached. + sql("REFRESH TABLE refreshTable") + checkAnswer( + table("refreshTable"), + table("src").unionAll(table("src")).collect()) + // It is not cached. + assert(!isCached("refreshTable"), "refreshTable should not be cached.") + + sql("DROP TABLE refreshTable") + Utils.deleteRecursively(tempPath) + } } From 2a24bf92e6d36e876bad6a8b4e0ff12c407ebb8a Mon Sep 17 00:00:00 2001 From: Punya Biswal Date: Tue, 21 Apr 2015 14:50:02 -0700 Subject: [PATCH 011/110] [SPARK-6996][SQL] Support map types in java beans liancheng mengxr this is similar to #5146. 
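As a reader's aid (not part of the patch itself; the class and variable names below are hypothetical), the user-visible effect is that a JavaBean-style class whose accessors return `java.util.Map` or `java.lang.Iterable` subtypes can now be turned into a DataFrame with `MapType`/`ArrayType` columns. A minimal Scala sketch, assuming a Spark 1.x `SQLContext`:

```scala
import java.util.{Arrays => JArrays, HashMap => JHashMap}
import scala.beans.BeanProperty

// A bean-style class: @BeanProperty generates getName/getScores/getTags,
// which is what the bean introspection behind getSchema() inspects.
class Record extends Serializable {
  @BeanProperty var name: String = "a"
  @BeanProperty var scores: java.util.List[Integer] =
    JArrays.asList(Integer.valueOf(1), Integer.valueOf(2))
  @BeanProperty var tags: java.util.Map[String, String] = {
    val m = new JHashMap[String, String]()
    m.put("k", "v")
    m
  }
}

// Usage sketch, assuming an existing SQLContext `sqlContext` and a JavaRDD[Record] `rdd`:
//   val df = sqlContext.createDataFrame(rdd, classOf[Record])
//   df.printSchema()
//   // name   -> StringType
//   // scores -> ArrayType(IntegerType)
//   // tags   -> MapType(StringType, StringType)
```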
Author: Punya Biswal Closes #5578 from punya/feature/SPARK-6996 and squashes the following commits: d56c3e0 [Punya Biswal] Fix imports c7e308b [Punya Biswal] Support java iterable types in POJOs 5e00685 [Punya Biswal] Support map types in java beans --- .../sql/catalyst/CatalystTypeConverters.scala | 20 ++++ .../apache/spark/sql/JavaTypeInference.scala | 110 ++++++++++++++++++ .../org/apache/spark/sql/SQLContext.scala | 52 +-------- .../apache/spark/sql/JavaDataFrameSuite.java | 57 +++++++-- 4 files changed, 180 insertions(+), 59 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/JavaTypeInference.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala index d4f9fdacda4fb..a13e2f36a1a1f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst +import java.lang.{Iterable => JavaIterable} import java.util.{Map => JavaMap} import scala.collection.mutable.HashMap @@ -49,6 +50,16 @@ object CatalystTypeConverters { case (s: Seq[_], arrayType: ArrayType) => s.map(convertToCatalyst(_, arrayType.elementType)) + case (jit: JavaIterable[_], arrayType: ArrayType) => { + val iter = jit.iterator + var listOfItems: List[Any] = List() + while (iter.hasNext) { + val item = iter.next() + listOfItems :+= convertToCatalyst(item, arrayType.elementType) + } + listOfItems + } + case (s: Array[_], arrayType: ArrayType) => s.toSeq.map(convertToCatalyst(_, arrayType.elementType)) @@ -124,6 +135,15 @@ object CatalystTypeConverters { extractOption(item) match { case a: Array[_] => a.toSeq.map(elementConverter) case s: Seq[_] => s.map(elementConverter) + case i: JavaIterable[_] => { + val iter = i.iterator + var convertedIterable: List[Any] = List() + while (iter.hasNext) { + val item = iter.next() + convertedIterable :+= elementConverter(item) + } + convertedIterable + } case null => null } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/JavaTypeInference.scala b/sql/core/src/main/scala/org/apache/spark/sql/JavaTypeInference.scala new file mode 100644 index 0000000000000..db484c5f50074 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/JavaTypeInference.scala @@ -0,0 +1,110 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql + +import java.beans.Introspector +import java.lang.{Iterable => JIterable} +import java.util.{Iterator => JIterator, Map => JMap} + +import com.google.common.reflect.TypeToken + +import org.apache.spark.sql.types._ + +import scala.language.existentials + +/** + * Type-inference utilities for POJOs and Java collections. + */ +private [sql] object JavaTypeInference { + + private val iterableType = TypeToken.of(classOf[JIterable[_]]) + private val mapType = TypeToken.of(classOf[JMap[_, _]]) + private val iteratorReturnType = classOf[JIterable[_]].getMethod("iterator").getGenericReturnType + private val nextReturnType = classOf[JIterator[_]].getMethod("next").getGenericReturnType + private val keySetReturnType = classOf[JMap[_, _]].getMethod("keySet").getGenericReturnType + private val valuesReturnType = classOf[JMap[_, _]].getMethod("values").getGenericReturnType + + /** + * Infers the corresponding SQL data type of a Java type. + * @param typeToken Java type + * @return (SQL data type, nullable) + */ + private [sql] def inferDataType(typeToken: TypeToken[_]): (DataType, Boolean) = { + // TODO: All of this could probably be moved to Catalyst as it is mostly not Spark specific. + typeToken.getRawType match { + case c: Class[_] if c.isAnnotationPresent(classOf[SQLUserDefinedType]) => + (c.getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance(), true) + + case c: Class[_] if c == classOf[java.lang.String] => (StringType, true) + case c: Class[_] if c == java.lang.Short.TYPE => (ShortType, false) + case c: Class[_] if c == java.lang.Integer.TYPE => (IntegerType, false) + case c: Class[_] if c == java.lang.Long.TYPE => (LongType, false) + case c: Class[_] if c == java.lang.Double.TYPE => (DoubleType, false) + case c: Class[_] if c == java.lang.Byte.TYPE => (ByteType, false) + case c: Class[_] if c == java.lang.Float.TYPE => (FloatType, false) + case c: Class[_] if c == java.lang.Boolean.TYPE => (BooleanType, false) + + case c: Class[_] if c == classOf[java.lang.Short] => (ShortType, true) + case c: Class[_] if c == classOf[java.lang.Integer] => (IntegerType, true) + case c: Class[_] if c == classOf[java.lang.Long] => (LongType, true) + case c: Class[_] if c == classOf[java.lang.Double] => (DoubleType, true) + case c: Class[_] if c == classOf[java.lang.Byte] => (ByteType, true) + case c: Class[_] if c == classOf[java.lang.Float] => (FloatType, true) + case c: Class[_] if c == classOf[java.lang.Boolean] => (BooleanType, true) + + case c: Class[_] if c == classOf[java.math.BigDecimal] => (DecimalType(), true) + case c: Class[_] if c == classOf[java.sql.Date] => (DateType, true) + case c: Class[_] if c == classOf[java.sql.Timestamp] => (TimestampType, true) + + case _ if typeToken.isArray => + val (dataType, nullable) = inferDataType(typeToken.getComponentType) + (ArrayType(dataType, nullable), true) + + case _ if iterableType.isAssignableFrom(typeToken) => + val (dataType, nullable) = inferDataType(elementType(typeToken)) + (ArrayType(dataType, nullable), true) + + case _ if mapType.isAssignableFrom(typeToken) => + val typeToken2 = typeToken.asInstanceOf[TypeToken[_ <: JMap[_, _]]] + val mapSupertype = typeToken2.getSupertype(classOf[JMap[_, _]]) + val keyType = elementType(mapSupertype.resolveType(keySetReturnType)) + val valueType = elementType(mapSupertype.resolveType(valuesReturnType)) + val (keyDataType, _) = inferDataType(keyType) + val (valueDataType, nullable) = inferDataType(valueType) + (MapType(keyDataType, valueDataType, nullable), true) + + 
case _ => + val beanInfo = Introspector.getBeanInfo(typeToken.getRawType) + val properties = beanInfo.getPropertyDescriptors.filterNot(_.getName == "class") + val fields = properties.map { property => + val returnType = typeToken.method(property.getReadMethod).getReturnType + val (dataType, nullable) = inferDataType(returnType) + new StructField(property.getName, dataType, nullable) + } + (new StructType(fields), true) + } + } + + private def elementType(typeToken: TypeToken[_]): TypeToken[_] = { + val typeToken2 = typeToken.asInstanceOf[TypeToken[_ <: JIterable[_]]] + val iterableSupertype = typeToken2.getSupertype(classOf[JIterable[_]]) + val iteratorType = iterableSupertype.resolveType(iteratorReturnType) + val itemType = iteratorType.resolveType(nextReturnType) + itemType + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index f9f3eb2e03817..bcd20c06c6dca 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -25,6 +25,8 @@ import scala.collection.immutable import scala.language.implicitConversions import scala.reflect.runtime.universe.TypeTag +import com.google.common.reflect.TypeToken + import org.apache.spark.annotation.{DeveloperApi, Experimental} import org.apache.spark.api.java.{JavaRDD, JavaSparkContext} import org.apache.spark.rdd.RDD @@ -1222,56 +1224,12 @@ class SQLContext(@transient val sparkContext: SparkContext) * Returns a Catalyst Schema for the given java bean class. */ protected def getSchema(beanClass: Class[_]): Seq[AttributeReference] = { - val (dataType, _) = inferDataType(beanClass) + val (dataType, _) = JavaTypeInference.inferDataType(TypeToken.of(beanClass)) dataType.asInstanceOf[StructType].fields.map { f => AttributeReference(f.name, f.dataType, f.nullable)() } } - /** - * Infers the corresponding SQL data type of a Java class. - * @param clazz Java class - * @return (SQL data type, nullable) - */ - private def inferDataType(clazz: Class[_]): (DataType, Boolean) = { - // TODO: All of this could probably be moved to Catalyst as it is mostly not Spark specific. 
- clazz match { - case c: Class[_] if c.isAnnotationPresent(classOf[SQLUserDefinedType]) => - (c.getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance(), true) - - case c: Class[_] if c == classOf[java.lang.String] => (StringType, true) - case c: Class[_] if c == java.lang.Short.TYPE => (ShortType, false) - case c: Class[_] if c == java.lang.Integer.TYPE => (IntegerType, false) - case c: Class[_] if c == java.lang.Long.TYPE => (LongType, false) - case c: Class[_] if c == java.lang.Double.TYPE => (DoubleType, false) - case c: Class[_] if c == java.lang.Byte.TYPE => (ByteType, false) - case c: Class[_] if c == java.lang.Float.TYPE => (FloatType, false) - case c: Class[_] if c == java.lang.Boolean.TYPE => (BooleanType, false) - - case c: Class[_] if c == classOf[java.lang.Short] => (ShortType, true) - case c: Class[_] if c == classOf[java.lang.Integer] => (IntegerType, true) - case c: Class[_] if c == classOf[java.lang.Long] => (LongType, true) - case c: Class[_] if c == classOf[java.lang.Double] => (DoubleType, true) - case c: Class[_] if c == classOf[java.lang.Byte] => (ByteType, true) - case c: Class[_] if c == classOf[java.lang.Float] => (FloatType, true) - case c: Class[_] if c == classOf[java.lang.Boolean] => (BooleanType, true) - - case c: Class[_] if c == classOf[java.math.BigDecimal] => (DecimalType(), true) - case c: Class[_] if c == classOf[java.sql.Date] => (DateType, true) - case c: Class[_] if c == classOf[java.sql.Timestamp] => (TimestampType, true) - - case c: Class[_] if c.isArray => - val (dataType, nullable) = inferDataType(c.getComponentType) - (ArrayType(dataType, nullable), true) - - case _ => - val beanInfo = Introspector.getBeanInfo(clazz) - val properties = beanInfo.getPropertyDescriptors.filterNot(_.getName == "class") - val fields = properties.map { property => - val (dataType, nullable) = inferDataType(property.getPropertyType) - new StructField(property.getName, dataType, nullable) - } - (new StructType(fields), true) - } - } } + + diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java index 6d0fbe83c2f36..fc3ed4a708d46 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java @@ -17,23 +17,28 @@ package test.org.apache.spark.sql; -import java.io.Serializable; -import java.util.Arrays; - -import scala.collection.Seq; - -import org.junit.After; -import org.junit.Assert; -import org.junit.Before; -import org.junit.Ignore; -import org.junit.Test; +import com.google.common.collect.ImmutableMap; +import com.google.common.primitives.Ints; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.sql.*; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.TestData$; import org.apache.spark.sql.test.TestSQLContext; import org.apache.spark.sql.test.TestSQLContext$; import org.apache.spark.sql.types.*; +import org.junit.*; + +import scala.collection.JavaConversions; +import scala.collection.Seq; +import scala.collection.mutable.Buffer; + +import java.io.Serializable; +import java.util.Arrays; +import java.util.List; +import java.util.Map; import static org.apache.spark.sql.functions.*; @@ -106,6 +111,8 @@ public void testShow() { public static class Bean implements Serializable { private double a = 0.0; 
private Integer[] b = new Integer[]{0, 1}; + private Map c = ImmutableMap.of("hello", new int[] { 1, 2 }); + private List d = Arrays.asList("floppy", "disk"); public double getA() { return a; @@ -114,6 +121,14 @@ public double getA() { public Integer[] getB() { return b; } + + public Map getC() { + return c; + } + + public List getD() { + return d; + } } @Test @@ -127,7 +142,15 @@ public void testCreateDataFrameFromJavaBeans() { Assert.assertEquals( new StructField("b", new ArrayType(IntegerType$.MODULE$, true), true, Metadata.empty()), schema.apply("b")); - Row first = df.select("a", "b").first(); + ArrayType valueType = new ArrayType(DataTypes.IntegerType, false); + MapType mapType = new MapType(DataTypes.StringType, valueType, true); + Assert.assertEquals( + new StructField("c", mapType, true, Metadata.empty()), + schema.apply("c")); + Assert.assertEquals( + new StructField("d", new ArrayType(DataTypes.StringType, true), true, Metadata.empty()), + schema.apply("d")); + Row first = df.select("a", "b", "c", "d").first(); Assert.assertEquals(bean.getA(), first.getDouble(0), 0.0); // Now Java lists and maps are converetd to Scala Seq's and Map's. Once we get a Seq below, // verify that it has the expected length, and contains expected elements. @@ -136,5 +159,15 @@ public void testCreateDataFrameFromJavaBeans() { for (int i = 0; i < result.length(); i++) { Assert.assertEquals(bean.getB()[i], result.apply(i)); } + Buffer outputBuffer = (Buffer) first.getJavaMap(2).get("hello"); + Assert.assertArrayEquals( + bean.getC().get("hello"), + Ints.toArray(JavaConversions.asJavaList(outputBuffer))); + Seq d = first.getAs(3); + Assert.assertEquals(bean.getD().size(), d.length()); + for (int i = 0; i < d.length(); i++) { + Assert.assertEquals(bean.getD().get(i), d.apply(i)); + } } + } From 7662ec23bb6c4d4fe4c857b6928eaed0a97d3c04 Mon Sep 17 00:00:00 2001 From: Cheng Hao Date: Tue, 21 Apr 2015 15:11:15 -0700 Subject: [PATCH 012/110] [SPARK-5817] [SQL] Fix bug of udtf with column names It's a bug while do query like: ```sql select d from (select explode(array(1,1)) d from src limit 1) t ``` And it will throws exception like: ``` org.apache.spark.sql.AnalysisException: cannot resolve 'd' given input columns _c0; line 1 pos 7 at org.apache.spark.sql.catalyst.analysis.package$AnalysisErrorAt.failAnalysis(package.scala:42) at org.apache.spark.sql.catalyst.analysis.CheckAnalysis$$anonfun$apply$3$$anonfun$apply$1.applyOrElse(CheckAnalysis.scala:48) at org.apache.spark.sql.catalyst.analysis.CheckAnalysis$$anonfun$apply$3$$anonfun$apply$1.applyOrElse(CheckAnalysis.scala:45) at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$transformUp$1.apply(TreeNode.scala:250) at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$transformUp$1.apply(TreeNode.scala:250) at org.apache.spark.sql.catalyst.trees.CurrentOrigin$.withOrigin(TreeNode.scala:50) at org.apache.spark.sql.catalyst.trees.TreeNode.transformUp(TreeNode.scala:249) at org.apache.spark.sql.catalyst.plans.QueryPlan.org$apache$spark$sql$catalyst$plans$QueryPlan$$transformExpressionUp$1(QueryPlan.scala:103) at org.apache.spark.sql.catalyst.plans.QueryPlan$$anonfun$2$$anonfun$apply$2.apply(QueryPlan.scala:117) at scala.collection.TraversableLike$$anonfun$map$1.apply(TraversableLike.scala:244) at scala.collection.TraversableLike$$anonfun$map$1.apply(TraversableLike.scala:244) at scala.collection.mutable.ResizableArray$class.foreach(ResizableArray.scala:59) at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:47) at 
scala.collection.TraversableLike$class.map(TraversableLike.scala:244) at scala.collection.AbstractTraversable.map(Traversable.scala:105) at org.apache.spark.sql.catalyst.plans.QueryPlan$$anonfun$2.apply(QueryPlan.scala:116) at scala.collection.Iterator$$anon$11.next(Iterator.scala:328) ``` To solve the bug, it requires code refactoring for UDTF The major changes are about: * Simplifying the UDTF development, UDTF will manage the output attribute names any more, instead, the `logical.Generate` will handle that properly. * UDTF will be asked for the output schema (data types) during the logical plan analyzing. Author: Cheng Hao Closes #4602 from chenghao-intel/explode_bug and squashes the following commits: c2a5132 [Cheng Hao] add back resolved for Alias 556e982 [Cheng Hao] revert the unncessary change 002c361 [Cheng Hao] change the rule of resolved for Generate 04ae500 [Cheng Hao] add qualifier only for generator output 5ee5d2c [Cheng Hao] prepend the new qualifier d2e8b43 [Cheng Hao] Update the code as feedback ca5e7f4 [Cheng Hao] shrink the commits --- .../sql/catalyst/analysis/Analyzer.scala | 57 ++++++++++++++++++- .../sql/catalyst/analysis/CheckAnalysis.scala | 12 ++++ .../spark/sql/catalyst/dsl/package.scala | 3 +- .../sql/catalyst/expressions/generators.scala | 49 ++++------------ .../expressions/namedExpressions.scala | 2 + .../sql/catalyst/optimizer/Optimizer.scala | 8 +-- .../plans/logical/basicOperators.scala | 37 +++++++----- .../sql/catalyst/analysis/AnalysisSuite.scala | 2 +- .../optimizer/FilterPushdownSuite.scala | 8 +-- .../org/apache/spark/sql/DataFrame.scala | 21 +++++-- .../apache/spark/sql/execution/Generate.scala | 22 ++----- .../spark/sql/execution/SparkStrategies.scala | 5 +- .../apache/spark/sql/hive/HiveContext.scala | 1 - .../org/apache/spark/sql/hive/HiveQl.scala | 37 ++++++------ .../org/apache/spark/sql/hive/hiveUdfs.scala | 38 +------------ ... 
output-0-d1f244bce64f22b34ad5bf9fd360b632 | 1 + ...mn name-0-7ac701cf43e73e9e416888e4df694348 | 0 ...mn name-1-5cdf9d51fc0e105e365d82e7611e37f3 | 0 ...mn name-2-f963396461294e06cb7cafe22a1419e4 | 3 + ...n names-0-46bdb27b3359dc81d8c246b9f69d4b82 | 0 ...n names-1-cdf6989f3b055257f1692c3bbd80dc73 | 0 ...n names-2-ab3954b69d7a991bc801a509c3166cc5 | 3 + ...mn name-0-7ac701cf43e73e9e416888e4df694348 | 0 ...mn name-1-26599718c322ff4f9740040c066d8292 | 0 ...mn name-2-f963396461294e06cb7cafe22a1419e4 | 3 + .../sql/hive/execution/HiveQuerySuite.scala | 40 ++++++++++++- 26 files changed, 207 insertions(+), 145 deletions(-) create mode 100644 sql/hive/src/test/resources/golden/Specify the udtf output-0-d1f244bce64f22b34ad5bf9fd360b632 create mode 100644 sql/hive/src/test/resources/golden/insert table with generator with column name-0-7ac701cf43e73e9e416888e4df694348 create mode 100644 sql/hive/src/test/resources/golden/insert table with generator with column name-1-5cdf9d51fc0e105e365d82e7611e37f3 create mode 100644 sql/hive/src/test/resources/golden/insert table with generator with column name-2-f963396461294e06cb7cafe22a1419e4 create mode 100644 sql/hive/src/test/resources/golden/insert table with generator with multiple column names-0-46bdb27b3359dc81d8c246b9f69d4b82 create mode 100644 sql/hive/src/test/resources/golden/insert table with generator with multiple column names-1-cdf6989f3b055257f1692c3bbd80dc73 create mode 100644 sql/hive/src/test/resources/golden/insert table with generator with multiple column names-2-ab3954b69d7a991bc801a509c3166cc5 create mode 100644 sql/hive/src/test/resources/golden/insert table with generator without column name-0-7ac701cf43e73e9e416888e4df694348 create mode 100644 sql/hive/src/test/resources/golden/insert table with generator without column name-1-26599718c322ff4f9740040c066d8292 create mode 100644 sql/hive/src/test/resources/golden/insert table with generator without column name-2-f963396461294e06cb7cafe22a1419e4 diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index cb49e5ad5586f..5e42b409dcc59 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -19,7 +19,6 @@ package org.apache.spark.sql.catalyst.analysis import org.apache.spark.util.collection.OpenHashSet import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ @@ -59,6 +58,7 @@ class Analyzer( ResolveReferences :: ResolveGroupingAnalytics :: ResolveSortReferences :: + ResolveGenerate :: ImplicitGenerate :: ResolveFunctions :: GlobalAggregates :: @@ -474,8 +474,59 @@ class Analyzer( */ object ImplicitGenerate extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { - case Project(Seq(Alias(g: Generator, _)), child) => - Generate(g, join = false, outer = false, None, child) + case Project(Seq(Alias(g: Generator, name)), child) => + Generate(g, join = false, outer = false, + qualifier = None, UnresolvedAttribute(name) :: Nil, child) + case Project(Seq(MultiAlias(g: Generator, names)), child) => + Generate(g, join = false, outer = false, + qualifier = None, names.map(UnresolvedAttribute(_)), child) + } + } + + /** + * 
Resolve the Generate, if the output names specified, we will take them, otherwise + * we will try to provide the default names, which follow the same rule with Hive. + */ + object ResolveGenerate extends Rule[LogicalPlan] { + // Construct the output attributes for the generator, + // The output attribute names can be either specified or + // auto generated. + private def makeGeneratorOutput( + generator: Generator, + generatorOutput: Seq[Attribute]): Seq[Attribute] = { + val elementTypes = generator.elementTypes + + if (generatorOutput.length == elementTypes.length) { + generatorOutput.zip(elementTypes).map { + case (a, (t, nullable)) if !a.resolved => + AttributeReference(a.name, t, nullable)() + case (a, _) => a + } + } else if (generatorOutput.length == 0) { + elementTypes.zipWithIndex.map { + // keep the default column names as Hive does _c0, _c1, _cN + case ((t, nullable), i) => AttributeReference(s"_c$i", t, nullable)() + } + } else { + throw new AnalysisException( + s""" + |The number of aliases supplied in the AS clause does not match + |the number of columns output by the UDTF expected + |${elementTypes.size} aliases but got ${generatorOutput.size} + """.stripMargin) + } + } + + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case p: Generate if !p.child.resolved || !p.generator.resolved => p + case p: Generate if p.resolved == false => + // if the generator output names are not specified, we will use the default ones. + Generate( + p.generator, + join = p.join, + outer = p.outer, + p.qualifier, + makeGeneratorOutput(p.generator, p.generatorOutput), p.child) } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index a986dd5387c38..2381689e17525 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -38,6 +38,12 @@ trait CheckAnalysis { throw new AnalysisException(msg) } + def containsMultipleGenerators(exprs: Seq[Expression]): Boolean = { + exprs.flatMap(_.collect { + case e: Generator => true + }).length >= 1 + } + def checkAnalysis(plan: LogicalPlan): Unit = { // We transform up and order the rules so as to catch the first possible failure instead // of the result of cascading resolution failures. @@ -110,6 +116,12 @@ trait CheckAnalysis { failAnalysis( s"unresolved operator ${operator.simpleString}") + case p @ Project(exprs, _) if containsMultipleGenerators(exprs) => + failAnalysis( + s"""Only a single table generating function is allowed in a SELECT clause, found: + | ${exprs.map(_.prettyString).mkString(",")}""".stripMargin) + + case _ => // Analysis successful! 
} } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index 21c15ad14fd19..4e5c64bb63c9f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -284,12 +284,13 @@ package object dsl { seed: Int = (math.random * 1000).toInt): LogicalPlan = Sample(fraction, withReplacement, seed, logicalPlan) + // TODO specify the output column names def generate( generator: Generator, join: Boolean = false, outer: Boolean = false, alias: Option[String] = None): LogicalPlan = - Generate(generator, join, outer, None, logicalPlan) + Generate(generator, join = join, outer = outer, alias, Nil, logicalPlan) def insertInto(tableName: String, overwrite: Boolean = false): LogicalPlan = InsertIntoTable( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala index 67caadb839ff9..9a6cb048af5ad 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala @@ -42,47 +42,30 @@ abstract class Generator extends Expression { override type EvaluatedType = TraversableOnce[Row] - override lazy val dataType = - ArrayType(StructType(output.map(a => StructField(a.name, a.dataType, a.nullable, a.metadata)))) + // TODO ideally we should return the type of ArrayType(StructType), + // however, we don't keep the output field names in the Generator. + override def dataType: DataType = throw new UnsupportedOperationException override def nullable: Boolean = false /** - * Should be overridden by specific generators. Called only once for each instance to ensure - * that rule application does not change the output schema of a generator. + * The output element data types in structure of Seq[(DataType, Nullable)] + * TODO we probably need to add more information like metadata etc. */ - protected def makeOutput(): Seq[Attribute] - - private var _output: Seq[Attribute] = null - - def output: Seq[Attribute] = { - if (_output == null) { - _output = makeOutput() - } - _output - } + def elementTypes: Seq[(DataType, Boolean)] /** Should be implemented by child classes to perform specific Generators. */ override def eval(input: Row): TraversableOnce[Row] - - /** Overridden `makeCopy` also copies the attributes that are produced by this generator. */ - override def makeCopy(newArgs: Array[AnyRef]): this.type = { - val copy = super.makeCopy(newArgs) - copy._output = _output - copy - } } /** * A generator that produces its output using the provided lambda function. */ case class UserDefinedGenerator( - schema: Seq[Attribute], + elementTypes: Seq[(DataType, Boolean)], function: Row => TraversableOnce[Row], children: Seq[Expression]) - extends Generator{ - - override protected def makeOutput(): Seq[Attribute] = schema + extends Generator { override def eval(input: Row): TraversableOnce[Row] = { // TODO(davies): improve this @@ -98,30 +81,18 @@ case class UserDefinedGenerator( /** * Given an input array produces a sequence of rows for each value in the array. 
*/ -case class Explode(attributeNames: Seq[String], child: Expression) +case class Explode(child: Expression) extends Generator with trees.UnaryNode[Expression] { override lazy val resolved = child.resolved && (child.dataType.isInstanceOf[ArrayType] || child.dataType.isInstanceOf[MapType]) - private lazy val elementTypes = child.dataType match { + override def elementTypes: Seq[(DataType, Boolean)] = child.dataType match { case ArrayType(et, containsNull) => (et, containsNull) :: Nil case MapType(kt, vt, valueContainsNull) => (kt, false) :: (vt, valueContainsNull) :: Nil } - // TODO: Move this pattern into Generator. - protected def makeOutput() = - if (attributeNames.size == elementTypes.size) { - attributeNames.zip(elementTypes).map { - case (n, (t, nullable)) => AttributeReference(n, t, nullable)() - } - } else { - elementTypes.zipWithIndex.map { - case ((t, nullable), i) => AttributeReference(s"c_$i", t, nullable)() - } - } - override def eval(input: Row): TraversableOnce[Row] = { child.dataType match { case ArrayType(_, _) => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala index bcbcbeb31c7b5..afcb2ce8b9cb4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala @@ -112,6 +112,8 @@ case class Alias(child: Expression, name: String)( extends NamedExpression with trees.UnaryNode[Expression] { override type EvaluatedType = Any + // Alias(Generator, xx) need to be transformed into Generate(generator, ...) + override lazy val resolved = childrenResolved && !child.isInstanceOf[Generator] override def eval(input: Row): Any = child.eval(input) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 7c80634d2c852..2d03fbfb0d311 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -482,16 +482,16 @@ object PushPredicateThroughProject extends Rule[LogicalPlan] { object PushPredicateThroughGenerate extends Rule[LogicalPlan] with PredicateHelper { def apply(plan: LogicalPlan): LogicalPlan = plan transform { - case filter @ Filter(condition, - generate @ Generate(generator, join, outer, alias, grandChild)) => + case filter @ Filter(condition, g: Generate) => // Predicates that reference attributes produced by the `Generate` operator cannot // be pushed below the operator. 
val (pushDown, stayUp) = splitConjunctivePredicates(condition).partition { - conjunct => conjunct.references subsetOf grandChild.outputSet + conjunct => conjunct.references subsetOf g.child.outputSet } if (pushDown.nonEmpty) { val pushDownPredicate = pushDown.reduce(And) - val withPushdown = generate.copy(child = Filter(pushDownPredicate, grandChild)) + val withPushdown = Generate(g.generator, join = g.join, outer = g.outer, + g.qualifier, g.generatorOutput, Filter(pushDownPredicate, g.child)) stayUp.reduceOption(And).map(Filter(_, withPushdown)).getOrElse(withPushdown) } else { filter diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index 17522976dc2c9..bbc94a7ab3398 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -40,34 +40,43 @@ case class Project(projectList: Seq[NamedExpression], child: LogicalPlan) extend * output of each into a new stream of rows. This operation is similar to a `flatMap` in functional * programming with one important additional feature, which allows the input rows to be joined with * their output. + * @param generator the generator expression * @param join when true, each output row is implicitly joined with the input tuple that produced * it. * @param outer when true, each input row will be output at least once, even if the output of the * given `generator` is empty. `outer` has no effect when `join` is false. - * @param alias when set, this string is applied to the schema of the output of the transformation - * as a qualifier. + * @param qualifier Qualifier for the attributes of generator(UDTF) + * @param generatorOutput The output schema of the Generator. + * @param child Children logical plan node */ case class Generate( generator: Generator, join: Boolean, outer: Boolean, - alias: Option[String], + qualifier: Option[String], + generatorOutput: Seq[Attribute], child: LogicalPlan) extends UnaryNode { - protected def generatorOutput: Seq[Attribute] = { - val output = alias - .map(a => generator.output.map(_.withQualifiers(a :: Nil))) - .getOrElse(generator.output) - if (join && outer) { - output.map(_.withNullability(true)) - } else { - output - } + override lazy val resolved: Boolean = { + generator.resolved && + childrenResolved && + generator.elementTypes.length == generatorOutput.length && + !generatorOutput.exists(!_.resolved) } - override def output: Seq[Attribute] = - if (join) child.output ++ generatorOutput else generatorOutput + // we don't want the gOutput to be taken as part of the expressions + // as that will cause exceptions like unresolved attributes etc. 
+ override def expressions: Seq[Expression] = generator :: Nil + + def output: Seq[Attribute] = { + val qualified = qualifier.map(q => + // prepend the new qualifier to the existed one + generatorOutput.map(a => a.withQualifiers(q +: a.qualifiers)) + ).getOrElse(generatorOutput) + + if (join) child.output ++ qualified else qualified + } } case class Filter(condition: Expression, child: LogicalPlan) extends UnaryNode { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index e10ddfdf5127c..7c249215bd6b6 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -90,7 +90,7 @@ class AnalysisSuite extends FunSuite with BeforeAndAfter { assert(!Project(Seq(UnresolvedAttribute("a")), testRelation).resolved) - val explode = Explode(Nil, AttributeReference("a", IntegerType, nullable = true)()) + val explode = Explode(AttributeReference("a", IntegerType, nullable = true)()) assert(!Project(Seq(Alias(explode, "explode")()), testRelation).resolved) assert(!Project(Seq(Alias(Count(Literal(1)), "count")()), testRelation).resolved) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala index 1448098c770aa..45cf695d20b01 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala @@ -454,21 +454,21 @@ class FilterPushdownSuite extends PlanTest { test("generate: predicate referenced no generated column") { val originalQuery = { testRelationWithArrayType - .generate(Explode(Seq("c"), 'c_arr), true, false, Some("arr")) + .generate(Explode('c_arr), true, false, Some("arr")) .where(('b >= 5) && ('a > 6)) } val optimized = Optimize(originalQuery.analyze) val correctAnswer = { testRelationWithArrayType .where(('b >= 5) && ('a > 6)) - .generate(Explode(Seq("c"), 'c_arr), true, false, Some("arr")).analyze + .generate(Explode('c_arr), true, false, Some("arr")).analyze } comparePlans(optimized, correctAnswer) } test("generate: part of conjuncts referenced generated column") { - val generator = Explode(Seq("c"), 'c_arr) + val generator = Explode('c_arr) val originalQuery = { testRelationWithArrayType .generate(generator, true, false, Some("arr")) @@ -499,7 +499,7 @@ class FilterPushdownSuite extends PlanTest { test("generate: all conjuncts referenced generated column") { val originalQuery = { testRelationWithArrayType - .generate(Explode(Seq("c"), 'c_arr), true, false, Some("arr")) + .generate(Explode('c_arr), true, false, Some("arr")) .where(('c > 6) || ('b > 5)).analyze } val optimized = Optimize(originalQuery) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 45f5da387692e..03d9834d1d131 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -34,7 +34,7 @@ import org.apache.spark.api.python.SerDeUtil import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel import org.apache.spark.sql.catalyst.{CatalystTypeConverters, ScalaReflection, SqlParser} -import 
org.apache.spark.sql.catalyst.analysis.{UnresolvedRelation, ResolvedStar} +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedRelation, ResolvedStar} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.{JoinType, Inner} import org.apache.spark.sql.catalyst.plans.logical._ @@ -711,12 +711,16 @@ class DataFrame private[sql]( */ def explode[A <: Product : TypeTag](input: Column*)(f: Row => TraversableOnce[A]): DataFrame = { val schema = ScalaReflection.schemaFor[A].dataType.asInstanceOf[StructType] - val attributes = schema.toAttributes + + val elementTypes = schema.toAttributes.map { attr => (attr.dataType, attr.nullable) } + val names = schema.toAttributes.map(_.name) + val rowFunction = f.andThen(_.map(CatalystTypeConverters.convertToCatalyst(_, schema).asInstanceOf[Row])) - val generator = UserDefinedGenerator(attributes, rowFunction, input.map(_.expr)) + val generator = UserDefinedGenerator(elementTypes, rowFunction, input.map(_.expr)) - Generate(generator, join = true, outer = false, None, logicalPlan) + Generate(generator, join = true, outer = false, + qualifier = None, names.map(UnresolvedAttribute(_)), logicalPlan) } /** @@ -733,12 +737,17 @@ class DataFrame private[sql]( : DataFrame = { val dataType = ScalaReflection.schemaFor[B].dataType val attributes = AttributeReference(outputColumn, dataType)() :: Nil + // TODO handle the metadata? + val elementTypes = attributes.map { attr => (attr.dataType, attr.nullable) } + val names = attributes.map(_.name) + def rowFunction(row: Row): TraversableOnce[Row] = { f(row(0).asInstanceOf[A]).map(o => Row(CatalystTypeConverters.convertToCatalyst(o, dataType))) } - val generator = UserDefinedGenerator(attributes, rowFunction, apply(inputColumn).expr :: Nil) + val generator = UserDefinedGenerator(elementTypes, rowFunction, apply(inputColumn).expr :: Nil) - Generate(generator, join = true, outer = false, None, logicalPlan) + Generate(generator, join = true, outer = false, + qualifier = None, names.map(UnresolvedAttribute(_)), logicalPlan) } ///////////////////////////////////////////////////////////////////////////// diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala index 12271048bb39c..5201e20a10565 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala @@ -27,44 +27,34 @@ import org.apache.spark.sql.catalyst.expressions._ * output of each into a new stream of rows. This operation is similar to a `flatMap` in functional * programming with one important additional feature, which allows the input rows to be joined with * their output. + * @param generator the generator expression * @param join when true, each output row is implicitly joined with the input tuple that produced * it. * @param outer when true, each input row will be output at least once, even if the output of the * given `generator` is empty. `outer` has no effect when `join` is false. + * @param output the output attributes of this node, which constructed in analysis phase, + * and we can not change it, as the parent node bound with it already. */ @DeveloperApi case class Generate( generator: Generator, join: Boolean, outer: Boolean, + output: Seq[Attribute], child: SparkPlan) extends UnaryNode { - // This must be a val since the generator output expr ids are not preserved by serialization. 
- protected val generatorOutput: Seq[Attribute] = { - if (join && outer) { - generator.output.map(_.withNullability(true)) - } else { - generator.output - } - } - - // This must be a val since the generator output expr ids are not preserved by serialization. - override val output = - if (join) child.output ++ generatorOutput else generatorOutput - val boundGenerator = BindReferences.bindReference(generator, child.output) override def execute(): RDD[Row] = { if (join) { child.execute().mapPartitions { iter => - val nullValues = Seq.fill(generator.output.size)(Literal(null)) + val nullValues = Seq.fill(generator.elementTypes.size)(Literal(null)) // Used to produce rows with no matches when outer = true. val outerProjection = newProjection(child.output ++ nullValues, child.output) - val joinProjection = - newProjection(child.output ++ generatorOutput, child.output ++ generatorOutput) + val joinProjection = newProjection(output, output) val joinedRow = new JoinedRow iter.flatMap {row => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index e687d01f57520..030ef118f75d4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -312,8 +312,9 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { execution.Except(planLater(left), planLater(right)) :: Nil case logical.Intersect(left, right) => execution.Intersect(planLater(left), planLater(right)) :: Nil - case logical.Generate(generator, join, outer, _, child) => - execution.Generate(generator, join = join, outer = outer, planLater(child)) :: Nil + case g @ logical.Generate(generator, join, outer, _, _, child) => + execution.Generate( + generator, join = join, outer = outer, g.output, planLater(child)) :: Nil case logical.OneRowRelation => execution.PhysicalRDD(Nil, singleRowRdd) :: Nil case logical.Repartition(expressions, child) => diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index 7c6a7df2bd01e..c4a73b3004076 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -249,7 +249,6 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { catalog.CreateTables :: catalog.PreInsertionCasts :: ExtractPythonUdfs :: - ResolveUdtfsAlias :: sources.PreInsertCastAndRename :: Nil } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index fd305eb480e63..85061f22772dd 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -725,12 +725,14 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C val alias = getClause("TOK_TABALIAS", clauses).getChildren.head.asInstanceOf[ASTNode].getText - Generate( - nodesToGenerator(clauses), - join = true, - outer = false, - Some(alias.toLowerCase), - withWhere) + val (generator, attributes) = nodesToGenerator(clauses) + Generate( + generator, + join = true, + outer = false, + Some(alias.toLowerCase), + attributes.map(UnresolvedAttribute(_)), + withWhere) }.getOrElse(withWhere) // The projection of the query can either be a normal projection, an 
aggregation @@ -833,12 +835,14 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C val alias = getClause("TOK_TABALIAS", clauses).getChildren.head.asInstanceOf[ASTNode].getText - Generate( - nodesToGenerator(clauses), - join = true, - outer = isOuter.nonEmpty, - Some(alias.toLowerCase), - nodeToRelation(relationClause)) + val (generator, attributes) = nodesToGenerator(clauses) + Generate( + generator, + join = true, + outer = isOuter.nonEmpty, + Some(alias.toLowerCase), + attributes.map(UnresolvedAttribute(_)), + nodeToRelation(relationClause)) /* All relations, possibly with aliases or sampling clauses. */ case Token("TOK_TABREF", clauses) => @@ -1311,7 +1315,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C val explode = "(?i)explode".r - def nodesToGenerator(nodes: Seq[Node]): Generator = { + def nodesToGenerator(nodes: Seq[Node]): (Generator, Seq[String]) = { val function = nodes.head val attributes = nodes.flatMap { @@ -1321,7 +1325,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C function match { case Token("TOK_FUNCTION", Token(explode(), Nil) :: child :: Nil) => - Explode(attributes, nodeToExpr(child)) + (Explode(nodeToExpr(child)), attributes) case Token("TOK_FUNCTION", Token(functionName, Nil) :: children) => val functionInfo: FunctionInfo = @@ -1329,10 +1333,9 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C sys.error(s"Couldn't find function $functionName")) val functionClassName = functionInfo.getFunctionClass.getName - HiveGenericUdtf( + (HiveGenericUdtf( new HiveFunctionWrapper(functionClassName), - attributes, - children.map(nodeToExpr)) + children.map(nodeToExpr)), attributes) case a: ASTNode => throw new NotImplementedError( diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala index 47305571e579e..4b6f0ad75f54f 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala @@ -66,7 +66,7 @@ private[hive] abstract class HiveFunctionRegistry } else if (classOf[UDAF].isAssignableFrom(functionInfo.getFunctionClass)) { HiveUdaf(new HiveFunctionWrapper(functionClassName), children) } else if (classOf[GenericUDTF].isAssignableFrom(functionInfo.getFunctionClass)) { - HiveGenericUdtf(new HiveFunctionWrapper(functionClassName), Nil, children) + HiveGenericUdtf(new HiveFunctionWrapper(functionClassName), children) } else { sys.error(s"No handler for udf ${functionInfo.getFunctionClass}") } @@ -266,7 +266,6 @@ private[hive] case class HiveUdaf( */ private[hive] case class HiveGenericUdtf( funcWrapper: HiveFunctionWrapper, - aliasNames: Seq[String], children: Seq[Expression]) extends Generator with HiveInspectors { @@ -282,23 +281,8 @@ private[hive] case class HiveGenericUdtf( @transient protected lazy val udtInput = new Array[AnyRef](children.length) - protected lazy val outputDataTypes = outputInspector.getAllStructFieldRefs.map { - field => inspectorToDataType(field.getFieldObjectInspector) - } - - override protected def makeOutput() = { - // Use column names when given, otherwise _c1, _c2, ... _cn. 
- if (aliasNames.size == outputDataTypes.size) { - aliasNames.zip(outputDataTypes).map { - case (attrName, attrDataType) => - AttributeReference(attrName, attrDataType, nullable = true)() - } - } else { - outputDataTypes.zipWithIndex.map { - case (attrDataType, i) => - AttributeReference(s"_c$i", attrDataType, nullable = true)() - } - } + lazy val elementTypes = outputInspector.getAllStructFieldRefs.map { + field => (inspectorToDataType(field.getFieldObjectInspector), true) } override def eval(input: Row): TraversableOnce[Row] = { @@ -333,22 +317,6 @@ private[hive] case class HiveGenericUdtf( } } -/** - * Resolve Udtfs Alias. - */ -private[spark] object ResolveUdtfsAlias extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan transform { - case p @ Project(projectList, _) - if projectList.exists(_.isInstanceOf[MultiAlias]) && projectList.size != 1 => - throw new TreeNodeException(p, "only single Generator supported for SELECT clause") - - case Project(Seq(Alias(udtf @ HiveGenericUdtf(_, _, _), name)), child) => - Generate(udtf.copy(aliasNames = Seq(name)), join = false, outer = false, None, child) - case Project(Seq(MultiAlias(udtf @ HiveGenericUdtf(_, _, _), names)), child) => - Generate(udtf.copy(aliasNames = names), join = false, outer = false, None, child) - } -} - private[hive] case class HiveUdafFunction( funcWrapper: HiveFunctionWrapper, exprs: Seq[Expression], diff --git a/sql/hive/src/test/resources/golden/Specify the udtf output-0-d1f244bce64f22b34ad5bf9fd360b632 b/sql/hive/src/test/resources/golden/Specify the udtf output-0-d1f244bce64f22b34ad5bf9fd360b632 new file mode 100644 index 0000000000000..d00491fd7e5bb --- /dev/null +++ b/sql/hive/src/test/resources/golden/Specify the udtf output-0-d1f244bce64f22b34ad5bf9fd360b632 @@ -0,0 +1 @@ +1 diff --git a/sql/hive/src/test/resources/golden/insert table with generator with column name-0-7ac701cf43e73e9e416888e4df694348 b/sql/hive/src/test/resources/golden/insert table with generator with column name-0-7ac701cf43e73e9e416888e4df694348 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/insert table with generator with column name-1-5cdf9d51fc0e105e365d82e7611e37f3 b/sql/hive/src/test/resources/golden/insert table with generator with column name-1-5cdf9d51fc0e105e365d82e7611e37f3 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/insert table with generator with column name-2-f963396461294e06cb7cafe22a1419e4 b/sql/hive/src/test/resources/golden/insert table with generator with column name-2-f963396461294e06cb7cafe22a1419e4 new file mode 100644 index 0000000000000..01e79c32a8c99 --- /dev/null +++ b/sql/hive/src/test/resources/golden/insert table with generator with column name-2-f963396461294e06cb7cafe22a1419e4 @@ -0,0 +1,3 @@ +1 +2 +3 diff --git a/sql/hive/src/test/resources/golden/insert table with generator with multiple column names-0-46bdb27b3359dc81d8c246b9f69d4b82 b/sql/hive/src/test/resources/golden/insert table with generator with multiple column names-0-46bdb27b3359dc81d8c246b9f69d4b82 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/insert table with generator with multiple column names-1-cdf6989f3b055257f1692c3bbd80dc73 b/sql/hive/src/test/resources/golden/insert table with generator with multiple column names-1-cdf6989f3b055257f1692c3bbd80dc73 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git 
a/sql/hive/src/test/resources/golden/insert table with generator with multiple column names-2-ab3954b69d7a991bc801a509c3166cc5 b/sql/hive/src/test/resources/golden/insert table with generator with multiple column names-2-ab3954b69d7a991bc801a509c3166cc5 new file mode 100644 index 0000000000000..0c7520f2090dd --- /dev/null +++ b/sql/hive/src/test/resources/golden/insert table with generator with multiple column names-2-ab3954b69d7a991bc801a509c3166cc5 @@ -0,0 +1,3 @@ +86 val_86 +238 val_238 +311 val_311 diff --git a/sql/hive/src/test/resources/golden/insert table with generator without column name-0-7ac701cf43e73e9e416888e4df694348 b/sql/hive/src/test/resources/golden/insert table with generator without column name-0-7ac701cf43e73e9e416888e4df694348 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/insert table with generator without column name-1-26599718c322ff4f9740040c066d8292 b/sql/hive/src/test/resources/golden/insert table with generator without column name-1-26599718c322ff4f9740040c066d8292 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/insert table with generator without column name-2-f963396461294e06cb7cafe22a1419e4 b/sql/hive/src/test/resources/golden/insert table with generator without column name-2-f963396461294e06cb7cafe22a1419e4 new file mode 100644 index 0000000000000..01e79c32a8c99 --- /dev/null +++ b/sql/hive/src/test/resources/golden/insert table with generator without column name-2-f963396461294e06cb7cafe22a1419e4 @@ -0,0 +1,3 @@ +1 +2 +3 diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index 300b1f7920473..ac10b173307d8 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -27,7 +27,7 @@ import scala.util.Try import org.apache.hadoop.hive.conf.HiveConf.ConfVars import org.apache.spark.{SparkFiles, SparkException} -import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.{AnalysisException, DataFrame, Row} import org.apache.spark.sql.catalyst.plans.logical.Project import org.apache.spark.sql.functions._ import org.apache.spark.sql.hive._ @@ -67,6 +67,40 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { } } + createQueryTest("insert table with generator with column name", + """ + | CREATE TABLE gen_tmp (key Int); + | INSERT OVERWRITE TABLE gen_tmp + | SELECT explode(array(1,2,3)) AS val FROM src LIMIT 3; + | SELECT key FROM gen_tmp ORDER BY key ASC; + """.stripMargin) + + createQueryTest("insert table with generator with multiple column names", + """ + | CREATE TABLE gen_tmp (key Int, value String); + | INSERT OVERWRITE TABLE gen_tmp + | SELECT explode(map(key, value)) as (k1, k2) FROM src LIMIT 3; + | SELECT key, value FROM gen_tmp ORDER BY key, value ASC; + """.stripMargin) + + createQueryTest("insert table with generator without column name", + """ + | CREATE TABLE gen_tmp (key Int); + | INSERT OVERWRITE TABLE gen_tmp + | SELECT explode(array(1,2,3)) FROM src LIMIT 3; + | SELECT key FROM gen_tmp ORDER BY key ASC; + """.stripMargin) + + test("multiple generator in projection") { + intercept[AnalysisException] { + sql("SELECT explode(map(key, value)), key FROM src").collect() + } + + intercept[AnalysisException] { + sql("SELECT explode(map(key, value)) as k1, k2, key 
FROM src").collect() + } + } + createQueryTest("! operator", """ |SELECT a FROM ( @@ -456,7 +490,6 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { createQueryTest("lateral view2", "SELECT * FROM src LATERAL VIEW explode(array(1,2)) tbl") - createQueryTest("lateral view3", "FROM src SELECT key, D.* lateral view explode(array(key+3, key+4)) D as CX") @@ -478,6 +511,9 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { createQueryTest("lateral view6", "SELECT * FROM src LATERAL VIEW explode(map(key+3,key+4)) D as k, v") + createQueryTest("Specify the udtf output", + "SELECT d FROM (SELECT explode(array(1,1)) d FROM src LIMIT 1) t") + test("sampling") { sql("SELECT * FROM src TABLESAMPLE(0.1 PERCENT) s") sql("SELECT * FROM src TABLESAMPLE(100 PERCENT) s") From f83c0f112d04173f4fc2c5eaf0f9cb11d9180077 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Tue, 21 Apr 2015 16:24:15 -0700 Subject: [PATCH 013/110] [SPARK-3386] Share and reuse SerializerInstances in shuffle paths This patch modifies several shuffle-related code paths to share and re-use SerializerInstances instead of creating new ones. Some serializers, such as KryoSerializer or SqlSerializer, can be fairly expensive to create or may consume moderate amounts of memory, so it's probably best to avoid unnecessary serializer creation in hot code paths. The key change in this patch is modifying `getDiskWriter()` / `DiskBlockObjectWriter` to accept `SerializerInstance`s instead of `Serializer`s (which are factories for instances). This allows the disk writer's creator to decide whether the serializer instance can be shared or re-used. The rest of the patch modifies several write and read paths to use shared serializers. One big win is in `ShuffleBlockFetcherIterator`, where we used to create a new serializer per received block. Similarly, the shuffle write path used to create a new serializer per file even though in many cases only a single thread would be writing to a file at a time. I made a small serializer reuse optimization in CoarseGrainedExecutorBackend as well, since it seemed like a small and obvious improvement. 
Author: Josh Rosen Closes #5606 from JoshRosen/SPARK-3386 and squashes the following commits: f661ce7 [Josh Rosen] Remove thread local; add comment instead 64f8398 [Josh Rosen] Use ThreadLocal for serializer instance in CoarseGrainedExecutorBackend aeb680e [Josh Rosen] [SPARK-3386] Reuse SerializerInstance in shuffle code paths --- .../executor/CoarseGrainedExecutorBackend.scala | 6 +++++- .../spark/shuffle/FileShuffleBlockManager.scala | 6 ++++-- .../org/apache/spark/storage/BlockManager.scala | 8 ++++---- .../apache/spark/storage/BlockObjectWriter.scala | 6 +++--- .../storage/ShuffleBlockFetcherIterator.scala | 6 ++++-- .../util/collection/ExternalAppendOnlyMap.scala | 6 ++---- .../spark/util/collection/ExternalSorter.scala | 14 +++++++++----- .../spark/storage/BlockObjectWriterSuite.scala | 6 +++--- 8 files changed, 34 insertions(+), 24 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala index 8300f9f2190b9..8af46f3327adb 100644 --- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala @@ -30,6 +30,7 @@ import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.deploy.worker.WorkerWatcher import org.apache.spark.scheduler.TaskDescription import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._ +import org.apache.spark.serializer.SerializerInstance import org.apache.spark.util.{SignalLogger, Utils} private[spark] class CoarseGrainedExecutorBackend( @@ -47,6 +48,10 @@ private[spark] class CoarseGrainedExecutorBackend( var executor: Executor = null @volatile var driver: Option[RpcEndpointRef] = None + // If this CoarseGrainedExecutorBackend is changed to support multiple threads, then this may need + // to be changed so that we don't share the serializer instance across threads + private[this] val ser: SerializerInstance = env.closureSerializer.newInstance() + override def onStart() { import scala.concurrent.ExecutionContext.Implicits.global logInfo("Connecting to driver: " + driverUrl) @@ -83,7 +88,6 @@ private[spark] class CoarseGrainedExecutorBackend( logError("Received LaunchTask command but executor was null") System.exit(1) } else { - val ser = env.closureSerializer.newInstance() val taskDesc = ser.deserialize[TaskDescription](data.value) logInfo("Got assigned task " + taskDesc.taskId) executor.launchTask(this, taskId = taskDesc.taskId, attemptNumber = taskDesc.attemptNumber, diff --git a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala index 5be3ed771e534..538e150ead05a 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala @@ -113,11 +113,12 @@ class FileShuffleBlockManager(conf: SparkConf) private var fileGroup: ShuffleFileGroup = null val openStartTime = System.nanoTime + val serializerInstance = serializer.newInstance() val writers: Array[BlockObjectWriter] = if (consolidateShuffleFiles) { fileGroup = getUnusedFileGroup() Array.tabulate[BlockObjectWriter](numBuckets) { bucketId => val blockId = ShuffleBlockId(shuffleId, mapId, bucketId) - blockManager.getDiskWriter(blockId, fileGroup(bucketId), serializer, bufferSize, + blockManager.getDiskWriter(blockId, fileGroup(bucketId), 
serializerInstance, bufferSize, writeMetrics) } } else { @@ -133,7 +134,8 @@ class FileShuffleBlockManager(conf: SparkConf) logWarning(s"Failed to remove existing shuffle file $blockFile") } } - blockManager.getDiskWriter(blockId, blockFile, serializer, bufferSize, writeMetrics) + blockManager.getDiskWriter(blockId, blockFile, serializerInstance, bufferSize, + writeMetrics) } } // Creating the file to write to and creating a disk writer both involve interacting with diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index 1aa0ef18de118..145a9c1ae3391 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -37,7 +37,7 @@ import org.apache.spark.network.netty.SparkTransportConf import org.apache.spark.network.shuffle.ExternalShuffleClient import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo import org.apache.spark.rpc.RpcEnv -import org.apache.spark.serializer.Serializer +import org.apache.spark.serializer.{SerializerInstance, Serializer} import org.apache.spark.shuffle.ShuffleManager import org.apache.spark.shuffle.hash.HashShuffleManager import org.apache.spark.util._ @@ -646,13 +646,13 @@ private[spark] class BlockManager( def getDiskWriter( blockId: BlockId, file: File, - serializer: Serializer, + serializerInstance: SerializerInstance, bufferSize: Int, writeMetrics: ShuffleWriteMetrics): BlockObjectWriter = { val compressStream: OutputStream => OutputStream = wrapForCompression(blockId, _) val syncWrites = conf.getBoolean("spark.shuffle.sync", false) - new DiskBlockObjectWriter(blockId, file, serializer, bufferSize, compressStream, syncWrites, - writeMetrics) + new DiskBlockObjectWriter(blockId, file, serializerInstance, bufferSize, compressStream, + syncWrites, writeMetrics) } /** diff --git a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala b/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala index 0dfc91dfaff85..14833791f7a4d 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala @@ -21,7 +21,7 @@ import java.io.{BufferedOutputStream, FileOutputStream, File, OutputStream} import java.nio.channels.FileChannel import org.apache.spark.Logging -import org.apache.spark.serializer.{SerializationStream, Serializer} +import org.apache.spark.serializer.{SerializerInstance, SerializationStream} import org.apache.spark.executor.ShuffleWriteMetrics import org.apache.spark.util.Utils @@ -71,7 +71,7 @@ private[spark] abstract class BlockObjectWriter(val blockId: BlockId) { private[spark] class DiskBlockObjectWriter( blockId: BlockId, file: File, - serializer: Serializer, + serializerInstance: SerializerInstance, bufferSize: Int, compressStream: OutputStream => OutputStream, syncWrites: Boolean, @@ -134,7 +134,7 @@ private[spark] class DiskBlockObjectWriter( ts = new TimeTrackingOutputStream(fos) channel = fos.getChannel() bs = compressStream(new BufferedOutputStream(ts, bufferSize)) - objOut = serializer.newInstance().serializeStream(bs) + objOut = serializerInstance.serializeStream(bs) initialized = true this } diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index 8f28ef49a8a6f..f3379521d55e2 100644 --- 
a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -27,7 +27,7 @@ import org.apache.spark.{Logging, TaskContext} import org.apache.spark.network.BlockTransferService import org.apache.spark.network.shuffle.{BlockFetchingListener, ShuffleClient} import org.apache.spark.network.buffer.ManagedBuffer -import org.apache.spark.serializer.Serializer +import org.apache.spark.serializer.{SerializerInstance, Serializer} import org.apache.spark.util.{CompletionIterator, Utils} /** @@ -106,6 +106,8 @@ final class ShuffleBlockFetcherIterator( private[this] val shuffleMetrics = context.taskMetrics.createShuffleReadMetricsForDependency() + private[this] val serializerInstance: SerializerInstance = serializer.newInstance() + /** * Whether the iterator is still active. If isZombie is true, the callback interface will no * longer place fetched blocks into [[results]]. @@ -299,7 +301,7 @@ final class ShuffleBlockFetcherIterator( // the scheduler gets a FetchFailedException. Try(buf.createInputStream()).map { is0 => val is = blockManager.wrapForCompression(blockId, is0) - val iter = serializer.newInstance().deserializeStream(is).asIterator + val iter = serializerInstance.deserializeStream(is).asIterator CompletionIterator[Any, Iterator[Any]](iter, { // Once the iterator is exhausted, release the buffer and set currentResult to null // so we don't release it again in cleanup. diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala index 9ff4744593d4d..30dd7f22e494f 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala @@ -151,8 +151,7 @@ class ExternalAppendOnlyMap[K, V, C]( override protected[this] def spill(collection: SizeTracker): Unit = { val (blockId, file) = diskBlockManager.createTempLocalBlock() curWriteMetrics = new ShuffleWriteMetrics() - var writer = blockManager.getDiskWriter(blockId, file, serializer, fileBufferSize, - curWriteMetrics) + var writer = blockManager.getDiskWriter(blockId, file, ser, fileBufferSize, curWriteMetrics) var objectsWritten = 0 // List of batch sizes (bytes) in the order they are written to disk @@ -179,8 +178,7 @@ class ExternalAppendOnlyMap[K, V, C]( if (objectsWritten == serializerBatchSize) { flush() curWriteMetrics = new ShuffleWriteMetrics() - writer = blockManager.getDiskWriter(blockId, file, serializer, fileBufferSize, - curWriteMetrics) + writer = blockManager.getDiskWriter(blockId, file, ser, fileBufferSize, curWriteMetrics) } } if (objectsWritten > 0) { diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala index 035f3767ff554..79a1a8a0dae38 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala @@ -272,7 +272,8 @@ private[spark] class ExternalSorter[K, V, C]( // createTempShuffleBlock here; see SPARK-3426 for more context. 
val (blockId, file) = diskBlockManager.createTempShuffleBlock() curWriteMetrics = new ShuffleWriteMetrics() - var writer = blockManager.getDiskWriter(blockId, file, ser, fileBufferSize, curWriteMetrics) + var writer = blockManager.getDiskWriter( + blockId, file, serInstance, fileBufferSize, curWriteMetrics) var objectsWritten = 0 // Objects written since the last flush // List of batch sizes (bytes) in the order they are written to disk @@ -308,7 +309,8 @@ private[spark] class ExternalSorter[K, V, C]( if (objectsWritten == serializerBatchSize) { flush() curWriteMetrics = new ShuffleWriteMetrics() - writer = blockManager.getDiskWriter(blockId, file, ser, fileBufferSize, curWriteMetrics) + writer = blockManager.getDiskWriter( + blockId, file, serInstance, fileBufferSize, curWriteMetrics) } } if (objectsWritten > 0) { @@ -358,7 +360,9 @@ private[spark] class ExternalSorter[K, V, C]( // spark.shuffle.compress instead of spark.shuffle.spill.compress, so we need to use // createTempShuffleBlock here; see SPARK-3426 for more context. val (blockId, file) = diskBlockManager.createTempShuffleBlock() - blockManager.getDiskWriter(blockId, file, ser, fileBufferSize, curWriteMetrics).open() + val writer = blockManager.getDiskWriter(blockId, file, serInstance, fileBufferSize, + curWriteMetrics) + writer.open() } // Creating the file to write to and creating a disk writer both involve interacting with // the disk, and can take a long time in aggregate when we open many files, so should be @@ -749,8 +753,8 @@ private[spark] class ExternalSorter[K, V, C]( // partition and just write everything directly. for ((id, elements) <- this.partitionedIterator) { if (elements.hasNext) { - val writer = blockManager.getDiskWriter( - blockId, outputFile, ser, fileBufferSize, context.taskMetrics.shuffleWriteMetrics.get) + val writer = blockManager.getDiskWriter(blockId, outputFile, serInstance, fileBufferSize, + context.taskMetrics.shuffleWriteMetrics.get) for (elem <- elements) { writer.write(elem) } diff --git a/core/src/test/scala/org/apache/spark/storage/BlockObjectWriterSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockObjectWriterSuite.scala index 78bbc4ec2c620..003a728cb84a0 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockObjectWriterSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockObjectWriterSuite.scala @@ -30,7 +30,7 @@ class BlockObjectWriterSuite extends FunSuite { val file = new File(Utils.createTempDir(), "somefile") val writeMetrics = new ShuffleWriteMetrics() val writer = new DiskBlockObjectWriter(new TestBlockId("0"), file, - new JavaSerializer(new SparkConf()), 1024, os => os, true, writeMetrics) + new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics) writer.write(Long.box(20)) // Record metrics update on every write @@ -52,7 +52,7 @@ class BlockObjectWriterSuite extends FunSuite { val file = new File(Utils.createTempDir(), "somefile") val writeMetrics = new ShuffleWriteMetrics() val writer = new DiskBlockObjectWriter(new TestBlockId("0"), file, - new JavaSerializer(new SparkConf()), 1024, os => os, true, writeMetrics) + new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics) writer.write(Long.box(20)) // Record metrics update on every write @@ -75,7 +75,7 @@ class BlockObjectWriterSuite extends FunSuite { val file = new File(Utils.createTempDir(), "somefile") val writeMetrics = new ShuffleWriteMetrics() val writer = new DiskBlockObjectWriter(new TestBlockId("0"), file, - new 
JavaSerializer(new SparkConf()), 1024, os => os, true, writeMetrics) + new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics) writer.open() writer.close() From a70e849c7f9e3df5e86113d45b8c4537597cfb29 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Tue, 21 Apr 2015 16:35:37 -0700 Subject: [PATCH 014/110] [minor] [build] Set java options when generating mima ignores. The default java options make the call to GenerateMIMAIgnore take forever to run since it's gc'ing all the time. Improve things by setting the perm gen size / max heap size to larger values. Author: Marcelo Vanzin Closes #5615 from vanzin/gen-mima-fix and squashes the following commits: f44e921 [Marcelo Vanzin] [minor] [build] Set java options when generating mima ignores. --- dev/mima | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/dev/mima b/dev/mima index bed5cd042634e..2952fa65d42ff 100755 --- a/dev/mima +++ b/dev/mima @@ -27,16 +27,21 @@ cd "$FWDIR" echo -e "q\n" | build/sbt oldDeps/update rm -f .generated-mima* +generate_mima_ignore() { + SPARK_JAVA_OPTS="-XX:MaxPermSize=1g -Xmx2g" \ + ./bin/spark-class org.apache.spark.tools.GenerateMIMAIgnore +} + # Generate Mima Ignore is called twice, first with latest built jars # on the classpath and then again with previous version jars on the classpath. # Because of a bug in GenerateMIMAIgnore that when old jars are ahead on classpath # it did not process the new classes (which are in assembly jar). -./bin/spark-class org.apache.spark.tools.GenerateMIMAIgnore +generate_mima_ignore export SPARK_CLASSPATH="`find lib_managed \( -name '*spark*jar' -a -type f \) | tr "\\n" ":"`" echo "SPARK_CLASSPATH=$SPARK_CLASSPATH" -./bin/spark-class org.apache.spark.tools.GenerateMIMAIgnore +generate_mima_ignore echo -e "q\n" | build/sbt mima-report-binary-issues | grep -v -e "info.*Resolving" ret_val=$? From 7fe6142cd3c39ec79899878c3deca9d5130d05b1 Mon Sep 17 00:00:00 2001 From: MechCoder Date: Tue, 21 Apr 2015 16:42:45 -0700 Subject: [PATCH 015/110] [SPARK-6065] [MLlib] Optimize word2vec.findSynonyms using blas calls 1. Use blas calls to find the dot product between two vectors. 2. Prevent re-computing the L2 norm of the given vector for each word in model. Author: MechCoder Closes #5467 from MechCoder/spark-6065 and squashes the following commits: dd0b0b2 [MechCoder] Preallocate wordVectors ffc9240 [MechCoder] Minor 6b74c81 [MechCoder] Switch back to native blas calls da1642d [MechCoder] Explicit types and indexing 64575b0 [MechCoder] Save indexedmap and a wordvecmat instead of matrix fbe0108 [MechCoder] Made the following changes 1. Calculate norms during initialization. 2. 
Use Blas calls from linalg.blas 1350cf3 [MechCoder] [SPARK-6065] Optimize word2vec.findSynonynms using blas calls --- .../apache/spark/mllib/feature/Word2Vec.scala | 57 +++++++++++++++++-- 1 file changed, 51 insertions(+), 6 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala index b2d9053f70145..98e83112f52ae 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala @@ -34,7 +34,7 @@ import org.apache.spark.SparkContext import org.apache.spark.SparkContext._ import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.JavaRDD -import org.apache.spark.mllib.linalg.{Vector, Vectors} +import org.apache.spark.mllib.linalg.{Vector, Vectors, DenseMatrix, BLAS, DenseVector} import org.apache.spark.mllib.util.{Loader, Saveable} import org.apache.spark.rdd._ import org.apache.spark.util.Utils @@ -429,7 +429,36 @@ class Word2Vec extends Serializable with Logging { */ @Experimental class Word2VecModel private[mllib] ( - private val model: Map[String, Array[Float]]) extends Serializable with Saveable { + model: Map[String, Array[Float]]) extends Serializable with Saveable { + + // wordList: Ordered list of words obtained from model. + private val wordList: Array[String] = model.keys.toArray + + // wordIndex: Maps each word to an index, which can retrieve the corresponding + // vector from wordVectors (see below). + private val wordIndex: Map[String, Int] = wordList.zip(0 until model.size).toMap + + // vectorSize: Dimension of each word's vector. + private val vectorSize = model.head._2.size + private val numWords = wordIndex.size + + // wordVectors: Array of length numWords * vectorSize, vector corresponding to the word + // mapped with index i can be retrieved by the slice + // (ind * vectorSize, ind * vectorSize + vectorSize) + // wordVecNorms: Array of length numWords, each value being the Euclidean norm + // of the wordVector. + private val (wordVectors: Array[Float], wordVecNorms: Array[Double]) = { + val wordVectors = new Array[Float](vectorSize * numWords) + val wordVecNorms = new Array[Double](numWords) + var i = 0 + while (i < numWords) { + val vec = model.get(wordList(i)).get + Array.copy(vec, 0, wordVectors, i * vectorSize, vectorSize) + wordVecNorms(i) = blas.snrm2(vectorSize, vec, 1) + i += 1 + } + (wordVectors, wordVecNorms) + } private def cosineSimilarity(v1: Array[Float], v2: Array[Float]): Double = { require(v1.length == v2.length, "Vectors should have the same length") @@ -443,7 +472,7 @@ class Word2VecModel private[mllib] ( override protected def formatVersion = "1.0" def save(sc: SparkContext, path: String): Unit = { - Word2VecModel.SaveLoadV1_0.save(sc, path, model) + Word2VecModel.SaveLoadV1_0.save(sc, path, getVectors) } /** @@ -479,9 +508,23 @@ class Word2VecModel private[mllib] ( */ def findSynonyms(vector: Vector, num: Int): Array[(String, Double)] = { require(num > 0, "Number of similar words should > 0") - // TODO: optimize top-k + val fVector = vector.toArray.map(_.toFloat) - model.mapValues(vec => cosineSimilarity(fVector, vec)) + val cosineVec = Array.fill[Float](numWords)(0) + val alpha: Float = 1 + val beta: Float = 0 + + blas.sgemv( + "T", vectorSize, numWords, alpha, wordVectors, vectorSize, fVector, 1, beta, cosineVec, 1) + + // Need not divide with the norm of the given vector since it is constant. 
+ val updatedCosines = new Array[Double](numWords) + var ind = 0 + while (ind < numWords) { + updatedCosines(ind) = cosineVec(ind) / wordVecNorms(ind) + ind += 1 + } + wordList.zip(updatedCosines) .toSeq .sortBy(- _._2) .take(num + 1) @@ -493,7 +536,9 @@ class Word2VecModel private[mllib] ( * Returns a map of words to their vector representations. */ def getVectors: Map[String, Array[Float]] = { - model + wordIndex.map { case (word, ind) => + (word, wordVectors.slice(vectorSize * ind, vectorSize * ind + vectorSize)) + } } } From 686dd742e11f6ad0078b7ff9b30b83a118fd8109 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Tue, 21 Apr 2015 16:44:52 -0700 Subject: [PATCH 016/110] [SPARK-7036][MLLIB] ALS.train should support DataFrames in PySpark SchemaRDD works with ALS.train in 1.2, so we should continue support DataFrames for compatibility. coderxiang Author: Xiangrui Meng Closes #5619 from mengxr/SPARK-7036 and squashes the following commits: dfcaf5a [Xiangrui Meng] ALS.train should support DataFrames in PySpark --- python/pyspark/mllib/recommendation.py | 36 +++++++++++++++++++------- 1 file changed, 26 insertions(+), 10 deletions(-) diff --git a/python/pyspark/mllib/recommendation.py b/python/pyspark/mllib/recommendation.py index 80e0a356bb78a..4b7d17d64e947 100644 --- a/python/pyspark/mllib/recommendation.py +++ b/python/pyspark/mllib/recommendation.py @@ -22,6 +22,7 @@ from pyspark.rdd import RDD from pyspark.mllib.common import JavaModelWrapper, callMLlibFunc, inherit_doc from pyspark.mllib.util import JavaLoader, JavaSaveable +from pyspark.sql import DataFrame __all__ = ['MatrixFactorizationModel', 'ALS', 'Rating'] @@ -78,18 +79,23 @@ class MatrixFactorizationModel(JavaModelWrapper, JavaSaveable, JavaLoader): True >>> model = ALS.train(ratings, 1, nonnegative=True, seed=10) - >>> model.predict(2,2) + >>> model.predict(2, 2) + 3.8... + + >>> df = sqlContext.createDataFrame([Rating(1, 1, 1.0), Rating(1, 2, 2.0), Rating(2, 1, 2.0)]) + >>> model = ALS.train(df, 1, nonnegative=True, seed=10) + >>> model.predict(2, 2) 3.8... >>> model = ALS.trainImplicit(ratings, 1, nonnegative=True, seed=10) - >>> model.predict(2,2) + >>> model.predict(2, 2) 0.4... >>> import os, tempfile >>> path = tempfile.mkdtemp() >>> model.save(sc, path) >>> sameModel = MatrixFactorizationModel.load(sc, path) - >>> sameModel.predict(2,2) + >>> sameModel.predict(2, 2) 0.4... >>> sameModel.predictAll(testset).collect() [Rating(... @@ -125,13 +131,20 @@ class ALS(object): @classmethod def _prepare(cls, ratings): - assert isinstance(ratings, RDD), "ratings should be RDD" + if isinstance(ratings, RDD): + pass + elif isinstance(ratings, DataFrame): + ratings = ratings.rdd + else: + raise TypeError("Ratings should be represented by either an RDD or a DataFrame, " + "but got %s." % type(ratings)) first = ratings.first() - if not isinstance(first, Rating): - if isinstance(first, (tuple, list)): - ratings = ratings.map(lambda x: Rating(*x)) - else: - raise ValueError("rating should be RDD of Rating or tuple/list") + if isinstance(first, Rating): + pass + elif isinstance(first, (tuple, list)): + ratings = ratings.map(lambda x: Rating(*x)) + else: + raise TypeError("Expect a Rating or a tuple/list, but got %s." 
% type(first)) return ratings @classmethod @@ -152,8 +165,11 @@ def trainImplicit(cls, ratings, rank, iterations=5, lambda_=0.01, blocks=-1, alp def _test(): import doctest import pyspark.mllib.recommendation + from pyspark.sql import SQLContext globs = pyspark.mllib.recommendation.__dict__.copy() - globs['sc'] = SparkContext('local[4]', 'PythonTest') + sc = SparkContext('local[4]', 'PythonTest') + globs['sc'] = sc + globs['sqlContext'] = SQLContext(sc) (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) globs['sc'].stop() if failure_count: From ae036d08170202074b266afd17ce34b689c70b0c Mon Sep 17 00:00:00 2001 From: Alain Date: Tue, 21 Apr 2015 16:46:17 -0700 Subject: [PATCH 017/110] [Minor][MLLIB] Fix a minor formatting bug in toString method in Node.scala add missing comma and space Author: Alain Closes #5621 from AiHe/tree-node-issue and squashes the following commits: 159a7bb [Alain] [Minor][MLLIB] Fix a minor formatting bug in toString methods in Node.scala (cherry picked from commit 4508f01890a723f80d631424ff8eda166a13a727) Signed-off-by: Xiangrui Meng --- .../src/main/scala/org/apache/spark/mllib/tree/model/Node.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala index 708ba04b567d3..86390a20cb5cc 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala @@ -52,7 +52,7 @@ class Node ( override def toString: String = { "id = " + id + ", isLeaf = " + isLeaf + ", predict = " + predict + ", " + - "impurity = " + impurity + "split = " + split + ", stats = " + stats + "impurity = " + impurity + ", split = " + split + ", stats = " + stats } /** From b063a61b9852cf9b9d2c905332d2ecb2fd716cc4 Mon Sep 17 00:00:00 2001 From: mweindel Date: Tue, 21 Apr 2015 20:19:33 -0400 Subject: [PATCH 018/110] Avoid warning message about invalid refuse_seconds value in Mesos >=0.21... Starting with version 0.21.0, Apache Mesos is very noisy if the filter parameter refuse_seconds is set to an invalid value like `-1`. I have seen systems with millions of log lines like ``` W0420 18:00:48.773059 32352 hierarchical_allocator_process.hpp:589] Using the default value of 'refuse_seconds' to create the refused resources filter because the input value is negative ``` in the Mesos master INFO and WARNING log files. Therefore the CoarseMesosSchedulerBackend should set the default value for refuse seconds (i.e. 5 seconds) directly. This is no problem for the fine-grained MesosSchedulerBackend, as it uses the value 1 second for this parameter. Author: mweindel Closes #5597 from MartinWeindel/master and squashes the following commits: 2f99ffd [mweindel] Avoid warning message about invalid refuse_seconds value in Mesos >=0.21. 
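A minimal sketch of the shape of the fix (illustrative only; `RefuseSecondsSketch` is not part of the patch): the Mesos `Filters` builder is given the 5-second default explicitly instead of the invalid `-1`, which keeps the master from warning on every declined offer while matching what Mesos was already doing internally.

```scala
import org.apache.mesos.Protos.Filters

object RefuseSecondsSketch {
  def main(args: Array[String]): Unit = {
    // Mesos >= 0.21 treats a negative refuse_seconds as invalid, logs a warning,
    // and falls back to its default; passing the default (5 seconds) explicitly
    // avoids the log noise without changing scheduling behaviour.
    val filters: Filters = Filters.newBuilder()
      .setRefuseSeconds(5)
      .build()
    println(s"refuse_seconds = ${filters.getRefuseSeconds}")
  }
}
```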
--- .../scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala index b037a4966ced0..82f652dae0378 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala @@ -207,7 +207,7 @@ private[spark] class CoarseMesosSchedulerBackend( */ override def resourceOffers(d: SchedulerDriver, offers: JList[Offer]) { synchronized { - val filters = Filters.newBuilder().setRefuseSeconds(-1).build() + val filters = Filters.newBuilder().setRefuseSeconds(5).build() for (offer <- offers) { val slaveId = offer.getSlaveId.toString From e72c16e30d85cdc394d318b5551698885cfda9b8 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Tue, 21 Apr 2015 20:33:57 -0400 Subject: [PATCH 019/110] [SPARK-6014] [core] Revamp Spark shutdown hooks, fix shutdown races. This change adds some new utility code to handle shutdown hooks in Spark. The main goal is to take advantage of Hadoop 2.x's API for shutdown hooks, which allows Spark to register a hook that will run before the one that cleans up HDFS clients, and thus avoids some races that would cause exceptions to show up and other issues such as failure to properly close event logs. Unfortunately, Hadoop 1.x does not have such APIs, so in that case correctness is still left to chance. Author: Marcelo Vanzin Closes #5560 from vanzin/SPARK-6014 and squashes the following commits: edfafb1 [Marcelo Vanzin] Better scaladoc. fcaeedd [Marcelo Vanzin] Merge branch 'master' into SPARK-6014 e7039dc [Marcelo Vanzin] [SPARK-6014] [core] Revamp Spark shutdown hooks, fix shutdown races. --- .../spark/deploy/history/HistoryServer.scala | 6 +- .../spark/deploy/worker/ExecutorRunner.scala | 12 +- .../spark/storage/DiskBlockManager.scala | 18 +-- .../spark/storage/TachyonBlockManager.scala | 24 ++-- .../scala/org/apache/spark/util/Utils.scala | 136 +++++++++++++++--- .../org/apache/spark/util/UtilsSuite.scala | 32 +++-- .../hive/thriftserver/HiveThriftServer2.scala | 9 +- .../hive/thriftserver/SparkSQLCLIDriver.scala | 9 +- .../spark/deploy/yarn/ApplicationMaster.scala | 63 ++++---- 9 files changed, 195 insertions(+), 114 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala index 72f6048239297..56bef57e55392 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala @@ -27,7 +27,7 @@ import org.apache.spark.{Logging, SecurityManager, SparkConf} import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.ui.{SparkUI, UIUtils, WebUI} import org.apache.spark.ui.JettyUtils._ -import org.apache.spark.util.SignalLogger +import org.apache.spark.util.{SignalLogger, Utils} /** * A web server that renders SparkUIs of completed applications. 
@@ -194,9 +194,7 @@ object HistoryServer extends Logging { val server = new HistoryServer(conf, provider, securityManager, port) server.bind() - Runtime.getRuntime().addShutdownHook(new Thread("HistoryServerStopper") { - override def run(): Unit = server.stop() - }) + Utils.addShutdownHook { () => server.stop() } // Wait until the end of the world... or if the HistoryServer process is manually stopped while(true) { Thread.sleep(Int.MaxValue) } diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala index 7d5acabb95a48..7aa85b732fc87 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala @@ -28,6 +28,7 @@ import com.google.common.io.Files import org.apache.spark.{SparkConf, Logging} import org.apache.spark.deploy.{ApplicationDescription, ExecutorState} import org.apache.spark.deploy.DeployMessages.ExecutorStateChanged +import org.apache.spark.util.Utils import org.apache.spark.util.logging.FileAppender /** @@ -61,7 +62,7 @@ private[deploy] class ExecutorRunner( // NOTE: This is now redundant with the automated shut-down enforced by the Executor. It might // make sense to remove this in the future. - private var shutdownHook: Thread = null + private var shutdownHook: AnyRef = null private[worker] def start() { workerThread = new Thread("ExecutorRunner for " + fullId) { @@ -69,12 +70,7 @@ private[deploy] class ExecutorRunner( } workerThread.start() // Shutdown hook that kills actors on shutdown. - shutdownHook = new Thread() { - override def run() { - killProcess(Some("Worker shutting down")) - } - } - Runtime.getRuntime.addShutdownHook(shutdownHook) + shutdownHook = Utils.addShutdownHook { () => killProcess(Some("Worker shutting down")) } } /** @@ -106,7 +102,7 @@ private[deploy] class ExecutorRunner( workerThread = null state = ExecutorState.KILLED try { - Runtime.getRuntime.removeShutdownHook(shutdownHook) + Utils.removeShutdownHook(shutdownHook) } catch { case e: IllegalStateException => None } diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala index 2883137872600..7ea5e54f9e1fe 100644 --- a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala @@ -138,25 +138,17 @@ private[spark] class DiskBlockManager(blockManager: BlockManager, conf: SparkCon } } - private def addShutdownHook(): Thread = { - val shutdownHook = new Thread("delete Spark local dirs") { - override def run(): Unit = Utils.logUncaughtExceptions { - logDebug("Shutdown hook called") - DiskBlockManager.this.doStop() - } + private def addShutdownHook(): AnyRef = { + Utils.addShutdownHook { () => + logDebug("Shutdown hook called") + DiskBlockManager.this.doStop() } - Runtime.getRuntime.addShutdownHook(shutdownHook) - shutdownHook } /** Cleanup local dirs and stop shuffle sender. */ private[spark] def stop() { // Remove the shutdown hook. It causes memory leaks if we leave it around. 
- try { - Runtime.getRuntime.removeShutdownHook(shutdownHook) - } catch { - case e: IllegalStateException => None - } + Utils.removeShutdownHook(shutdownHook) doStop() } diff --git a/core/src/main/scala/org/apache/spark/storage/TachyonBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/TachyonBlockManager.scala index af873034215a9..951897cead996 100644 --- a/core/src/main/scala/org/apache/spark/storage/TachyonBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/TachyonBlockManager.scala @@ -135,21 +135,19 @@ private[spark] class TachyonBlockManager( private def addShutdownHook() { tachyonDirs.foreach(tachyonDir => Utils.registerShutdownDeleteDir(tachyonDir)) - Runtime.getRuntime.addShutdownHook(new Thread("delete Spark tachyon dirs") { - override def run(): Unit = Utils.logUncaughtExceptions { - logDebug("Shutdown hook called") - tachyonDirs.foreach { tachyonDir => - try { - if (!Utils.hasRootAsShutdownDeleteDir(tachyonDir)) { - Utils.deleteRecursively(tachyonDir, client) - } - } catch { - case e: Exception => - logError("Exception while deleting tachyon spark dir: " + tachyonDir, e) + Utils.addShutdownHook { () => + logDebug("Shutdown hook called") + tachyonDirs.foreach { tachyonDir => + try { + if (!Utils.hasRootAsShutdownDeleteDir(tachyonDir)) { + Utils.deleteRecursively(tachyonDir, client) } + } catch { + case e: Exception => + logError("Exception while deleting tachyon spark dir: " + tachyonDir, e) } - client.close() } - }) + client.close() + } } } diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 1029b0f9fce1e..7b0de1ae55b78 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -21,7 +21,7 @@ import java.io._ import java.lang.management.ManagementFactory import java.net._ import java.nio.ByteBuffer -import java.util.{Properties, Locale, Random, UUID} +import java.util.{PriorityQueue, Properties, Locale, Random, UUID} import java.util.concurrent._ import javax.net.ssl.HttpsURLConnection @@ -30,7 +30,7 @@ import scala.collection.Map import scala.collection.mutable.ArrayBuffer import scala.io.Source import scala.reflect.ClassTag -import scala.util.Try +import scala.util.{Failure, Success, Try} import scala.util.control.{ControlThrowable, NonFatal} import com.google.common.io.{ByteStreams, Files} @@ -64,9 +64,15 @@ private[spark] object CallSite { private[spark] object Utils extends Logging { val random = new Random() + val DEFAULT_SHUTDOWN_PRIORITY = 100 + private val MAX_DIR_CREATION_ATTEMPTS: Int = 10 @volatile private var localRootDirs: Array[String] = null + + private val shutdownHooks = new SparkShutdownHookManager() + shutdownHooks.install() + /** Serialize an object using Java serialization */ def serialize[T](o: T): Array[Byte] = { val bos = new ByteArrayOutputStream() @@ -176,18 +182,16 @@ private[spark] object Utils extends Logging { private val shutdownDeleteTachyonPaths = new scala.collection.mutable.HashSet[String]() // Add a shutdown hook to delete the temp dirs when the JVM exits - Runtime.getRuntime.addShutdownHook(new Thread("delete Spark temp dirs") { - override def run(): Unit = Utils.logUncaughtExceptions { - logDebug("Shutdown hook called") - shutdownDeletePaths.foreach { dirPath => - try { - Utils.deleteRecursively(new File(dirPath)) - } catch { - case e: Exception => logError(s"Exception while deleting Spark temp dir: $dirPath", e) - } + addShutdownHook { () => + logDebug("Shutdown 
hook called") + shutdownDeletePaths.foreach { dirPath => + try { + Utils.deleteRecursively(new File(dirPath)) + } catch { + case e: Exception => logError(s"Exception while deleting Spark temp dir: $dirPath", e) } } - }) + } // Register the path to be deleted via shutdown hook def registerShutdownDeleteDir(file: File) { @@ -613,7 +617,7 @@ private[spark] object Utils extends Logging { } Utils.setupSecureURLConnection(uc, securityMgr) - val timeoutMs = + val timeoutMs = conf.getTimeAsSeconds("spark.files.fetchTimeout", "60s").toInt * 1000 uc.setConnectTimeout(timeoutMs) uc.setReadTimeout(timeoutMs) @@ -1172,7 +1176,7 @@ private[spark] object Utils extends Logging { /** * Execute a block of code that evaluates to Unit, forwarding any uncaught exceptions to the * default UncaughtExceptionHandler - * + * * NOTE: This method is to be called by the spark-started JVM process. */ def tryOrExit(block: => Unit) { @@ -1185,11 +1189,11 @@ private[spark] object Utils extends Logging { } /** - * Execute a block of code that evaluates to Unit, stop SparkContext is there is any uncaught + * Execute a block of code that evaluates to Unit, stop SparkContext is there is any uncaught * exception - * - * NOTE: This method is to be called by the driver-side components to avoid stopping the - * user-started JVM process completely; in contrast, tryOrExit is to be called in the + * + * NOTE: This method is to be called by the driver-side components to avoid stopping the + * user-started JVM process completely; in contrast, tryOrExit is to be called in the * spark-started JVM process . */ def tryOrStopSparkContext(sc: SparkContext)(block: => Unit) { @@ -2132,6 +2136,102 @@ private[spark] object Utils extends Logging { .getOrElse(UserGroupInformation.getCurrentUser().getShortUserName()) } + /** + * Adds a shutdown hook with default priority. + * + * @param hook The code to run during shutdown. + * @return A handle that can be used to unregister the shutdown hook. + */ + def addShutdownHook(hook: () => Unit): AnyRef = { + addShutdownHook(DEFAULT_SHUTDOWN_PRIORITY, hook) + } + + /** + * Adds a shutdown hook with the given priority. Hooks with lower priority values run + * first. + * + * @param hook The code to run during shutdown. + * @return A handle that can be used to unregister the shutdown hook. + */ + def addShutdownHook(priority: Int, hook: () => Unit): AnyRef = { + shutdownHooks.add(priority, hook) + } + + /** + * Remove a previously installed shutdown hook. + * + * @param ref A handle returned by `addShutdownHook`. + * @return Whether the hook was removed. + */ + def removeShutdownHook(ref: AnyRef): Boolean = { + shutdownHooks.remove(ref) + } + +} + +private [util] class SparkShutdownHookManager { + + private val hooks = new PriorityQueue[SparkShutdownHook]() + private var shuttingDown = false + + /** + * Install a hook to run at shutdown and run all registered hooks in order. Hadoop 1.x does not + * have `ShutdownHookManager`, so in that case we just use the JVM's `Runtime` object and hope for + * the best. 
+ */ + def install(): Unit = { + val hookTask = new Runnable() { + override def run(): Unit = runAll() + } + Try(Class.forName("org.apache.hadoop.util.ShutdownHookManager")) match { + case Success(shmClass) => + val fsPriority = classOf[FileSystem].getField("SHUTDOWN_HOOK_PRIORITY").get() + .asInstanceOf[Int] + val shm = shmClass.getMethod("get").invoke(null) + shm.getClass().getMethod("addShutdownHook", classOf[Runnable], classOf[Int]) + .invoke(shm, hookTask, Integer.valueOf(fsPriority + 30)) + + case Failure(_) => + Runtime.getRuntime.addShutdownHook(new Thread(hookTask, "Spark Shutdown Hook")); + } + } + + def runAll(): Unit = synchronized { + shuttingDown = true + while (!hooks.isEmpty()) { + Utils.logUncaughtExceptions(hooks.poll().run()) + } + } + + def add(priority: Int, hook: () => Unit): AnyRef = synchronized { + checkState() + val hookRef = new SparkShutdownHook(priority, hook) + hooks.add(hookRef) + hookRef + } + + def remove(ref: AnyRef): Boolean = synchronized { + checkState() + hooks.remove(ref) + } + + private def checkState(): Unit = { + if (shuttingDown) { + throw new IllegalStateException("Shutdown hooks cannot be modified during shutdown.") + } + } + +} + +private class SparkShutdownHook(private val priority: Int, hook: () => Unit) + extends Comparable[SparkShutdownHook] { + + override def compareTo(other: SparkShutdownHook): Int = { + other.priority - priority + } + + def run(): Unit = hook() + } /** diff --git a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala index fb97e650ff95c..1ba99803f5a0e 100644 --- a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala @@ -17,14 +17,16 @@ package org.apache.spark.util -import scala.util.Random - import java.io.{File, ByteArrayOutputStream, ByteArrayInputStream, FileOutputStream} import java.net.{BindException, ServerSocket, URI} import java.nio.{ByteBuffer, ByteOrder} import java.text.DecimalFormatSymbols import java.util.concurrent.TimeUnit import java.util.Locale +import java.util.PriorityQueue + +import scala.collection.mutable.ListBuffer +import scala.util.Random import com.google.common.base.Charsets.UTF_8 import com.google.common.io.Files @@ -36,14 +38,14 @@ import org.apache.hadoop.fs.Path import org.apache.spark.SparkConf class UtilsSuite extends FunSuite with ResetSystemProperties { - + test("timeConversion") { // Test -1 assert(Utils.timeStringAsSeconds("-1") === -1) - + // Test zero assert(Utils.timeStringAsSeconds("0") === 0) - + assert(Utils.timeStringAsSeconds("1") === 1) assert(Utils.timeStringAsSeconds("1s") === 1) assert(Utils.timeStringAsSeconds("1000ms") === 1) @@ -52,7 +54,7 @@ class UtilsSuite extends FunSuite with ResetSystemProperties { assert(Utils.timeStringAsSeconds("1min") === TimeUnit.MINUTES.toSeconds(1)) assert(Utils.timeStringAsSeconds("1h") === TimeUnit.HOURS.toSeconds(1)) assert(Utils.timeStringAsSeconds("1d") === TimeUnit.DAYS.toSeconds(1)) - + assert(Utils.timeStringAsMs("1") === 1) assert(Utils.timeStringAsMs("1ms") === 1) assert(Utils.timeStringAsMs("1000us") === 1) @@ -61,7 +63,7 @@ class UtilsSuite extends FunSuite with ResetSystemProperties { assert(Utils.timeStringAsMs("1min") === TimeUnit.MINUTES.toMillis(1)) assert(Utils.timeStringAsMs("1h") === TimeUnit.HOURS.toMillis(1)) assert(Utils.timeStringAsMs("1d") === TimeUnit.DAYS.toMillis(1)) - + // Test invalid strings intercept[NumberFormatException] { Utils.timeStringAsMs("This breaks 600s") @@ 
-79,7 +81,7 @@ class UtilsSuite extends FunSuite with ResetSystemProperties { Utils.timeStringAsMs("This 123s breaks") } } - + test("bytesToString") { assert(Utils.bytesToString(10) === "10.0 B") assert(Utils.bytesToString(1500) === "1500.0 B") @@ -466,4 +468,18 @@ class UtilsSuite extends FunSuite with ResetSystemProperties { val newFileName = new File(testFileDir, testFileName) assert(newFileName.isFile()) } + + test("shutdown hook manager") { + val manager = new SparkShutdownHookManager() + val output = new ListBuffer[Int]() + + val hook1 = manager.add(1, () => output += 1) + manager.add(3, () => output += 3) + manager.add(2, () => output += 2) + manager.add(4, () => output += 4) + manager.remove(hook1) + + manager.runAll() + assert(output.toList === List(4, 3, 2)) + } } diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala index c3a3f8c0f41df..832596fc8bee5 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala @@ -28,6 +28,7 @@ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.sql.hive.HiveContext import org.apache.spark.sql.hive.thriftserver.ReflectionUtils._ import org.apache.spark.scheduler.{SparkListenerApplicationEnd, SparkListener} +import org.apache.spark.util.Utils /** * The main entry point for the Spark SQL port of HiveServer2. Starts up a `SparkSQLContext` and a @@ -57,13 +58,7 @@ object HiveThriftServer2 extends Logging { logInfo("Starting SparkContext") SparkSQLEnv.init() - Runtime.getRuntime.addShutdownHook( - new Thread() { - override def run() { - SparkSQLEnv.stop() - } - } - ) + Utils.addShutdownHook { () => SparkSQLEnv.stop() } try { val server = new HiveThriftServer2(SparkSQLEnv.hiveContext) diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala index 85281c6d73a3b..7e307bb4ad1e8 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala @@ -40,6 +40,7 @@ import org.apache.thrift.transport.TSocket import org.apache.spark.Logging import org.apache.spark.sql.hive.HiveShim +import org.apache.spark.util.Utils private[hive] object SparkSQLCLIDriver { private var prompt = "spark-sql" @@ -101,13 +102,7 @@ private[hive] object SparkSQLCLIDriver { SessionState.start(sessionState) // Clean up after we exit - Runtime.getRuntime.addShutdownHook( - new Thread() { - override def run() { - SparkSQLEnv.stop() - } - } - ) + Utils.addShutdownHook { () => SparkSQLEnv.stop() } // "-h" option has been passed, so connect to Hive thrift server. 
if (sessionState.getHost != null) { diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index f7a84207e9da6..93ae45133ce24 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -25,7 +25,6 @@ import java.net.{Socket, URL} import java.util.concurrent.atomic.AtomicReference import org.apache.hadoop.fs.{FileSystem, Path} -import org.apache.hadoop.util.ShutdownHookManager import org.apache.hadoop.yarn.api._ import org.apache.hadoop.yarn.api.records._ import org.apache.hadoop.yarn.conf.YarnConfiguration @@ -95,44 +94,38 @@ private[spark] class ApplicationMaster( logInfo("ApplicationAttemptId: " + appAttemptId) val fs = FileSystem.get(yarnConf) - val cleanupHook = new Runnable { - override def run() { - // If the SparkContext is still registered, shut it down as a best case effort in case - // users do not call sc.stop or do System.exit(). - val sc = sparkContextRef.get() - if (sc != null) { - logInfo("Invoking sc stop from shutdown hook") - sc.stop() - } - val maxAppAttempts = client.getMaxRegAttempts(sparkConf, yarnConf) - val isLastAttempt = client.getAttemptId().getAttemptId() >= maxAppAttempts - - if (!finished) { - // This happens when the user application calls System.exit(). We have the choice - // of either failing or succeeding at this point. We report success to avoid - // retrying applications that have succeeded (System.exit(0)), which means that - // applications that explicitly exit with a non-zero status will also show up as - // succeeded in the RM UI. - finish(finalStatus, - ApplicationMaster.EXIT_SUCCESS, - "Shutdown hook called before final status was reported.") - } - if (!unregistered) { - // we only want to unregister if we don't want the RM to retry - if (finalStatus == FinalApplicationStatus.SUCCEEDED || isLastAttempt) { - unregister(finalStatus, finalMsg) - cleanupStagingDir(fs) - } + Utils.addShutdownHook { () => + // If the SparkContext is still registered, shut it down as a best case effort in case + // users do not call sc.stop or do System.exit(). + val sc = sparkContextRef.get() + if (sc != null) { + logInfo("Invoking sc stop from shutdown hook") + sc.stop() + } + val maxAppAttempts = client.getMaxRegAttempts(sparkConf, yarnConf) + val isLastAttempt = client.getAttemptId().getAttemptId() >= maxAppAttempts + + if (!finished) { + // This happens when the user application calls System.exit(). We have the choice + // of either failing or succeeding at this point. We report success to avoid + // retrying applications that have succeeded (System.exit(0)), which means that + // applications that explicitly exit with a non-zero status will also show up as + // succeeded in the RM UI. + finish(finalStatus, + ApplicationMaster.EXIT_SUCCESS, + "Shutdown hook called before final status was reported.") + } + + if (!unregistered) { + // we only want to unregister if we don't want the RM to retry + if (finalStatus == FinalApplicationStatus.SUCCEEDED || isLastAttempt) { + unregister(finalStatus, finalMsg) + cleanupStagingDir(fs) } } } - // Use higher priority than FileSystem. - assert(ApplicationMaster.SHUTDOWN_HOOK_PRIORITY > FileSystem.SHUTDOWN_HOOK_PRIORITY) - ShutdownHookManager - .get().addShutdownHook(cleanupHook, ApplicationMaster.SHUTDOWN_HOOK_PRIORITY) - // Call this to force generation of secret so it gets populated into the // Hadoop UGI. 
This has to happen before the startUserApplication which does a // doAs in order for the credentials to be passed on to the executor containers. @@ -546,8 +539,6 @@ private[spark] class ApplicationMaster( object ApplicationMaster extends Logging { - val SHUTDOWN_HOOK_PRIORITY: Int = 30 - // exit codes for different causes, no reason behind the values private val EXIT_SUCCESS = 0 private val EXIT_UNCAUGHT_EXCEPTION = 10 From 3134c3fe495862b7687b5aa00d3344d09cd5e08e Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Tue, 21 Apr 2015 17:49:55 -0700 Subject: [PATCH 020/110] [SPARK-6953] [PySpark] speed up python tests This PR try to speed up some python tests: ``` tests.py 144s -> 103s -41s mllib/classification.py 24s -> 17s -7s mllib/regression.py 27s -> 15s -12s mllib/tree.py 27s -> 13s -14s mllib/tests.py 64s -> 31s -33s streaming/tests.py 185s -> 84s -101s ``` Considering python3, the total saving will be 558s (almost 10 minutes) (core, and streaming run three times, mllib runs twice). During testing, it will show used time for each test file: ``` Run core tests ... Running test: pyspark/rdd.py ... ok (22s) Running test: pyspark/context.py ... ok (16s) Running test: pyspark/conf.py ... ok (4s) Running test: pyspark/broadcast.py ... ok (4s) Running test: pyspark/accumulators.py ... ok (4s) Running test: pyspark/serializers.py ... ok (6s) Running test: pyspark/profiler.py ... ok (5s) Running test: pyspark/shuffle.py ... ok (1s) Running test: pyspark/tests.py ... ok (103s) 144s ``` Author: Reynold Xin Author: Xiangrui Meng Closes #5605 from rxin/python-tests-speed and squashes the following commits: d08542d [Reynold Xin] Merge pull request #14 from mengxr/SPARK-6953 89321ee [Xiangrui Meng] fix seed in tests 3ad2387 [Reynold Xin] Merge pull request #5427 from davies/python_tests --- python/pyspark/mllib/classification.py | 17 ++--- python/pyspark/mllib/regression.py | 25 ++++--- python/pyspark/mllib/tests.py | 69 +++++++++--------- python/pyspark/mllib/tree.py | 15 ++-- python/pyspark/shuffle.py | 7 +- python/pyspark/sql/tests.py | 4 +- python/pyspark/streaming/tests.py | 63 ++++++++++------- python/pyspark/tests.py | 96 ++++++++++++++++---------- python/run-tests | 13 ++-- 9 files changed, 182 insertions(+), 127 deletions(-) diff --git a/python/pyspark/mllib/classification.py b/python/pyspark/mllib/classification.py index eda0b60f8b1e7..a70c664a71fdb 100644 --- a/python/pyspark/mllib/classification.py +++ b/python/pyspark/mllib/classification.py @@ -86,7 +86,7 @@ class LogisticRegressionModel(LinearClassificationModel): ... LabeledPoint(0.0, [0.0, 1.0]), ... LabeledPoint(1.0, [1.0, 0.0]), ... ] - >>> lrm = LogisticRegressionWithSGD.train(sc.parallelize(data)) + >>> lrm = LogisticRegressionWithSGD.train(sc.parallelize(data), iterations=10) >>> lrm.predict([1.0, 0.0]) 1 >>> lrm.predict([0.0, 1.0]) @@ -95,7 +95,7 @@ class LogisticRegressionModel(LinearClassificationModel): [1, 0] >>> lrm.clearThreshold() >>> lrm.predict([0.0, 1.0]) - 0.123... + 0.279... >>> sparse_data = [ ... LabeledPoint(0.0, SparseVector(2, {0: 0.0})), @@ -103,7 +103,7 @@ class LogisticRegressionModel(LinearClassificationModel): ... LabeledPoint(0.0, SparseVector(2, {0: 1.0})), ... LabeledPoint(1.0, SparseVector(2, {1: 2.0})) ... 
] - >>> lrm = LogisticRegressionWithSGD.train(sc.parallelize(sparse_data)) + >>> lrm = LogisticRegressionWithSGD.train(sc.parallelize(sparse_data), iterations=10) >>> lrm.predict(array([0.0, 1.0])) 1 >>> lrm.predict(array([1.0, 0.0])) @@ -129,7 +129,8 @@ class LogisticRegressionModel(LinearClassificationModel): ... LabeledPoint(1.0, [1.0, 0.0, 0.0]), ... LabeledPoint(2.0, [0.0, 0.0, 1.0]) ... ] - >>> mcm = LogisticRegressionWithLBFGS.train(data=sc.parallelize(multi_class_data), numClasses=3) + >>> data = sc.parallelize(multi_class_data) + >>> mcm = LogisticRegressionWithLBFGS.train(data, iterations=10, numClasses=3) >>> mcm.predict([0.0, 0.5, 0.0]) 0 >>> mcm.predict([0.8, 0.0, 0.0]) @@ -298,7 +299,7 @@ def train(cls, data, iterations=100, initialWeights=None, regParam=0.01, regType ... LabeledPoint(0.0, [0.0, 1.0]), ... LabeledPoint(1.0, [1.0, 0.0]), ... ] - >>> lrm = LogisticRegressionWithLBFGS.train(sc.parallelize(data)) + >>> lrm = LogisticRegressionWithLBFGS.train(sc.parallelize(data), iterations=10) >>> lrm.predict([1.0, 0.0]) 1 >>> lrm.predict([0.0, 1.0]) @@ -330,14 +331,14 @@ class SVMModel(LinearClassificationModel): ... LabeledPoint(1.0, [2.0]), ... LabeledPoint(1.0, [3.0]) ... ] - >>> svm = SVMWithSGD.train(sc.parallelize(data)) + >>> svm = SVMWithSGD.train(sc.parallelize(data), iterations=10) >>> svm.predict([1.0]) 1 >>> svm.predict(sc.parallelize([[1.0]])).collect() [1] >>> svm.clearThreshold() >>> svm.predict(array([1.0])) - 1.25... + 1.44... >>> sparse_data = [ ... LabeledPoint(0.0, SparseVector(2, {0: -1.0})), @@ -345,7 +346,7 @@ class SVMModel(LinearClassificationModel): ... LabeledPoint(0.0, SparseVector(2, {0: 0.0})), ... LabeledPoint(1.0, SparseVector(2, {1: 2.0})) ... ] - >>> svm = SVMWithSGD.train(sc.parallelize(sparse_data)) + >>> svm = SVMWithSGD.train(sc.parallelize(sparse_data), iterations=10) >>> svm.predict(SparseVector(2, {1: 1.0})) 1 >>> svm.predict(SparseVector(2, {0: -1.0})) diff --git a/python/pyspark/mllib/regression.py b/python/pyspark/mllib/regression.py index a0117c57133ae..4bc6351bdf02f 100644 --- a/python/pyspark/mllib/regression.py +++ b/python/pyspark/mllib/regression.py @@ -108,7 +108,8 @@ class LinearRegressionModel(LinearRegressionModelBase): ... LabeledPoint(3.0, [2.0]), ... LabeledPoint(2.0, [3.0]) ... ] - >>> lrm = LinearRegressionWithSGD.train(sc.parallelize(data), initialWeights=np.array([1.0])) + >>> lrm = LinearRegressionWithSGD.train(sc.parallelize(data), iterations=10, + ... initialWeights=np.array([1.0])) >>> abs(lrm.predict(np.array([0.0])) - 0) < 0.5 True >>> abs(lrm.predict(np.array([1.0])) - 1) < 0.5 @@ -135,12 +136,13 @@ class LinearRegressionModel(LinearRegressionModelBase): ... LabeledPoint(3.0, SparseVector(1, {0: 2.0})), ... LabeledPoint(2.0, SparseVector(1, {0: 3.0})) ... ] - >>> lrm = LinearRegressionWithSGD.train(sc.parallelize(data), initialWeights=array([1.0])) + >>> lrm = LinearRegressionWithSGD.train(sc.parallelize(data), iterations=10, + ... initialWeights=array([1.0])) >>> abs(lrm.predict(array([0.0])) - 0) < 0.5 True >>> abs(lrm.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5 True - >>> lrm = LinearRegressionWithSGD.train(sc.parallelize(data), iterations=100, step=1.0, + >>> lrm = LinearRegressionWithSGD.train(sc.parallelize(data), iterations=10, step=1.0, ... miniBatchFraction=1.0, initialWeights=array([1.0]), regParam=0.1, regType="l2", ... intercept=True, validateData=True) >>> abs(lrm.predict(array([0.0])) - 0) < 0.5 @@ -238,7 +240,7 @@ class LassoModel(LinearRegressionModelBase): ... LabeledPoint(3.0, [2.0]), ... 
LabeledPoint(2.0, [3.0]) ... ] - >>> lrm = LassoWithSGD.train(sc.parallelize(data), initialWeights=array([1.0])) + >>> lrm = LassoWithSGD.train(sc.parallelize(data), iterations=10, initialWeights=array([1.0])) >>> abs(lrm.predict(np.array([0.0])) - 0) < 0.5 True >>> abs(lrm.predict(np.array([1.0])) - 1) < 0.5 @@ -265,12 +267,13 @@ class LassoModel(LinearRegressionModelBase): ... LabeledPoint(3.0, SparseVector(1, {0: 2.0})), ... LabeledPoint(2.0, SparseVector(1, {0: 3.0})) ... ] - >>> lrm = LinearRegressionWithSGD.train(sc.parallelize(data), initialWeights=array([1.0])) + >>> lrm = LinearRegressionWithSGD.train(sc.parallelize(data), iterations=10, + ... initialWeights=array([1.0])) >>> abs(lrm.predict(np.array([0.0])) - 0) < 0.5 True >>> abs(lrm.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5 True - >>> lrm = LassoWithSGD.train(sc.parallelize(data), iterations=100, step=1.0, + >>> lrm = LassoWithSGD.train(sc.parallelize(data), iterations=10, step=1.0, ... regParam=0.01, miniBatchFraction=1.0, initialWeights=array([1.0]), intercept=True, ... validateData=True) >>> abs(lrm.predict(np.array([0.0])) - 0) < 0.5 @@ -321,7 +324,8 @@ class RidgeRegressionModel(LinearRegressionModelBase): ... LabeledPoint(3.0, [2.0]), ... LabeledPoint(2.0, [3.0]) ... ] - >>> lrm = RidgeRegressionWithSGD.train(sc.parallelize(data), initialWeights=array([1.0])) + >>> lrm = RidgeRegressionWithSGD.train(sc.parallelize(data), iterations=10, + ... initialWeights=array([1.0])) >>> abs(lrm.predict(np.array([0.0])) - 0) < 0.5 True >>> abs(lrm.predict(np.array([1.0])) - 1) < 0.5 @@ -348,12 +352,13 @@ class RidgeRegressionModel(LinearRegressionModelBase): ... LabeledPoint(3.0, SparseVector(1, {0: 2.0})), ... LabeledPoint(2.0, SparseVector(1, {0: 3.0})) ... ] - >>> lrm = LinearRegressionWithSGD.train(sc.parallelize(data), initialWeights=array([1.0])) + >>> lrm = LinearRegressionWithSGD.train(sc.parallelize(data), iterations=10, + ... initialWeights=array([1.0])) >>> abs(lrm.predict(np.array([0.0])) - 0) < 0.5 True >>> abs(lrm.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5 True - >>> lrm = RidgeRegressionWithSGD.train(sc.parallelize(data), iterations=100, step=1.0, + >>> lrm = RidgeRegressionWithSGD.train(sc.parallelize(data), iterations=10, step=1.0, ... regParam=0.01, miniBatchFraction=1.0, initialWeights=array([1.0]), intercept=True, ... 
validateData=True) >>> abs(lrm.predict(np.array([0.0])) - 0) < 0.5 @@ -396,7 +401,7 @@ def _test(): from pyspark import SparkContext import pyspark.mllib.regression globs = pyspark.mllib.regression.__dict__.copy() - globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2) + globs['sc'] = SparkContext('local[2]', 'PythonTest', batchSize=2) (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) globs['sc'].stop() if failure_count: diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py index 8f89e2cee0592..1b008b93bc137 100644 --- a/python/pyspark/mllib/tests.py +++ b/python/pyspark/mllib/tests.py @@ -36,6 +36,7 @@ else: import unittest +from pyspark import SparkContext from pyspark.mllib.common import _to_java_object_rdd from pyspark.mllib.linalg import Vector, SparseVector, DenseVector, VectorUDT, _convert_to_vector,\ DenseMatrix, SparseMatrix, Vectors, Matrices @@ -47,7 +48,6 @@ from pyspark.mllib.feature import StandardScaler from pyspark.serializers import PickleSerializer from pyspark.sql import SQLContext -from pyspark.tests import ReusedPySparkTestCase as PySparkTestCase _have_scipy = False try: @@ -58,6 +58,12 @@ pass ser = PickleSerializer() +sc = SparkContext('local[4]', "MLlib tests") + + +class MLlibTestCase(unittest.TestCase): + def setUp(self): + self.sc = sc def _squared_distance(a, b): @@ -67,7 +73,7 @@ def _squared_distance(a, b): return b.squared_distance(a) -class VectorTests(PySparkTestCase): +class VectorTests(MLlibTestCase): def _test_serialize(self, v): self.assertEqual(v, ser.loads(ser.dumps(v))) @@ -212,7 +218,7 @@ def test_dense_matrix_is_transposed(self): self.assertTrue(array_equal(sm.values, [1, 3, 4, 6, 9])) -class ListTests(PySparkTestCase): +class ListTests(MLlibTestCase): """ Test MLlib algorithms on plain lists, to make sure they're passed through @@ -255,7 +261,7 @@ def test_gmm(self): [-6, -7], ]) clusters = GaussianMixture.train(data, 2, convergenceTol=0.001, - maxIterations=100, seed=56) + maxIterations=10, seed=56) labels = clusters.predict(data).collect() self.assertEquals(labels[0], labels[1]) self.assertEquals(labels[2], labels[3]) @@ -266,9 +272,9 @@ def test_gmm_deterministic(self): y = range(0, 100, 10) data = self.sc.parallelize([[a, b] for a, b in zip(x, y)]) clusters1 = GaussianMixture.train(data, 5, convergenceTol=0.001, - maxIterations=100, seed=63) + maxIterations=10, seed=63) clusters2 = GaussianMixture.train(data, 5, convergenceTol=0.001, - maxIterations=100, seed=63) + maxIterations=10, seed=63) for c1, c2 in zip(clusters1.weights, clusters2.weights): self.assertEquals(round(c1, 7), round(c2, 7)) @@ -287,13 +293,13 @@ def test_classification(self): temp_dir = tempfile.mkdtemp() - lr_model = LogisticRegressionWithSGD.train(rdd) + lr_model = LogisticRegressionWithSGD.train(rdd, iterations=10) self.assertTrue(lr_model.predict(features[0]) <= 0) self.assertTrue(lr_model.predict(features[1]) > 0) self.assertTrue(lr_model.predict(features[2]) <= 0) self.assertTrue(lr_model.predict(features[3]) > 0) - svm_model = SVMWithSGD.train(rdd) + svm_model = SVMWithSGD.train(rdd, iterations=10) self.assertTrue(svm_model.predict(features[0]) <= 0) self.assertTrue(svm_model.predict(features[1]) > 0) self.assertTrue(svm_model.predict(features[2]) <= 0) @@ -307,7 +313,7 @@ def test_classification(self): categoricalFeaturesInfo = {0: 3} # feature 0 has 3 categories dt_model = DecisionTree.trainClassifier( - rdd, numClasses=2, categoricalFeaturesInfo=categoricalFeaturesInfo) + rdd, numClasses=2, 
categoricalFeaturesInfo=categoricalFeaturesInfo, maxBins=4) self.assertTrue(dt_model.predict(features[0]) <= 0) self.assertTrue(dt_model.predict(features[1]) > 0) self.assertTrue(dt_model.predict(features[2]) <= 0) @@ -319,7 +325,8 @@ def test_classification(self): self.assertEqual(same_dt_model.toDebugString(), dt_model.toDebugString()) rf_model = RandomForest.trainClassifier( - rdd, numClasses=2, categoricalFeaturesInfo=categoricalFeaturesInfo, numTrees=100) + rdd, numClasses=2, categoricalFeaturesInfo=categoricalFeaturesInfo, numTrees=10, + maxBins=4, seed=1) self.assertTrue(rf_model.predict(features[0]) <= 0) self.assertTrue(rf_model.predict(features[1]) > 0) self.assertTrue(rf_model.predict(features[2]) <= 0) @@ -331,7 +338,7 @@ def test_classification(self): self.assertEqual(same_rf_model.toDebugString(), rf_model.toDebugString()) gbt_model = GradientBoostedTrees.trainClassifier( - rdd, categoricalFeaturesInfo=categoricalFeaturesInfo) + rdd, categoricalFeaturesInfo=categoricalFeaturesInfo, numIterations=4) self.assertTrue(gbt_model.predict(features[0]) <= 0) self.assertTrue(gbt_model.predict(features[1]) > 0) self.assertTrue(gbt_model.predict(features[2]) <= 0) @@ -360,19 +367,19 @@ def test_regression(self): rdd = self.sc.parallelize(data) features = [p.features.tolist() for p in data] - lr_model = LinearRegressionWithSGD.train(rdd) + lr_model = LinearRegressionWithSGD.train(rdd, iterations=10) self.assertTrue(lr_model.predict(features[0]) <= 0) self.assertTrue(lr_model.predict(features[1]) > 0) self.assertTrue(lr_model.predict(features[2]) <= 0) self.assertTrue(lr_model.predict(features[3]) > 0) - lasso_model = LassoWithSGD.train(rdd) + lasso_model = LassoWithSGD.train(rdd, iterations=10) self.assertTrue(lasso_model.predict(features[0]) <= 0) self.assertTrue(lasso_model.predict(features[1]) > 0) self.assertTrue(lasso_model.predict(features[2]) <= 0) self.assertTrue(lasso_model.predict(features[3]) > 0) - rr_model = RidgeRegressionWithSGD.train(rdd) + rr_model = RidgeRegressionWithSGD.train(rdd, iterations=10) self.assertTrue(rr_model.predict(features[0]) <= 0) self.assertTrue(rr_model.predict(features[1]) > 0) self.assertTrue(rr_model.predict(features[2]) <= 0) @@ -380,35 +387,35 @@ def test_regression(self): categoricalFeaturesInfo = {0: 2} # feature 0 has 2 categories dt_model = DecisionTree.trainRegressor( - rdd, categoricalFeaturesInfo=categoricalFeaturesInfo) + rdd, categoricalFeaturesInfo=categoricalFeaturesInfo, maxBins=4) self.assertTrue(dt_model.predict(features[0]) <= 0) self.assertTrue(dt_model.predict(features[1]) > 0) self.assertTrue(dt_model.predict(features[2]) <= 0) self.assertTrue(dt_model.predict(features[3]) > 0) rf_model = RandomForest.trainRegressor( - rdd, categoricalFeaturesInfo=categoricalFeaturesInfo, numTrees=100, seed=1) + rdd, categoricalFeaturesInfo=categoricalFeaturesInfo, numTrees=10, maxBins=4, seed=1) self.assertTrue(rf_model.predict(features[0]) <= 0) self.assertTrue(rf_model.predict(features[1]) > 0) self.assertTrue(rf_model.predict(features[2]) <= 0) self.assertTrue(rf_model.predict(features[3]) > 0) gbt_model = GradientBoostedTrees.trainRegressor( - rdd, categoricalFeaturesInfo=categoricalFeaturesInfo) + rdd, categoricalFeaturesInfo=categoricalFeaturesInfo, numIterations=4) self.assertTrue(gbt_model.predict(features[0]) <= 0) self.assertTrue(gbt_model.predict(features[1]) > 0) self.assertTrue(gbt_model.predict(features[2]) <= 0) self.assertTrue(gbt_model.predict(features[3]) > 0) try: - LinearRegressionWithSGD.train(rdd, 
initialWeights=array([1.0, 1.0])) - LassoWithSGD.train(rdd, initialWeights=array([1.0, 1.0])) - RidgeRegressionWithSGD.train(rdd, initialWeights=array([1.0, 1.0])) + LinearRegressionWithSGD.train(rdd, initialWeights=array([1.0, 1.0]), iterations=10) + LassoWithSGD.train(rdd, initialWeights=array([1.0, 1.0]), iterations=10) + RidgeRegressionWithSGD.train(rdd, initialWeights=array([1.0, 1.0]), iterations=10) except ValueError: self.fail() -class StatTests(PySparkTestCase): +class StatTests(MLlibTestCase): # SPARK-4023 def test_col_with_different_rdds(self): # numpy @@ -438,7 +445,7 @@ def test_col_norms(self): self.assertTrue(math.fabs(summary2.normL2()[0] - expectedNormL2) < 1e-14) -class VectorUDTTests(PySparkTestCase): +class VectorUDTTests(MLlibTestCase): dv0 = DenseVector([]) dv1 = DenseVector([1.0, 2.0]) @@ -472,7 +479,7 @@ def test_infer_schema(self): @unittest.skipIf(not _have_scipy, "SciPy not installed") -class SciPyTests(PySparkTestCase): +class SciPyTests(MLlibTestCase): """ Test both vector operations and MLlib algorithms with SciPy sparse matrices, @@ -613,7 +620,7 @@ def test_regression(self): self.assertTrue(dt_model.predict(features[3]) > 0) -class ChiSqTestTests(PySparkTestCase): +class ChiSqTestTests(MLlibTestCase): def test_goodness_of_fit(self): from numpy import inf @@ -711,13 +718,13 @@ def test_right_number_of_results(self): self.assertIsNotNone(chi[1000]) -class SerDeTest(PySparkTestCase): +class SerDeTest(MLlibTestCase): def test_to_java_object_rdd(self): # SPARK-6660 data = RandomRDDs.uniformRDD(self.sc, 10, 5, seed=0) self.assertEqual(_to_java_object_rdd(data).count(), 10) -class FeatureTest(PySparkTestCase): +class FeatureTest(MLlibTestCase): def test_idf_model(self): data = [ Vectors.dense([1, 2, 6, 0, 2, 3, 1, 1, 0, 0, 3]), @@ -730,13 +737,8 @@ def test_idf_model(self): self.assertEqual(len(idf), 11) -class Word2VecTests(PySparkTestCase): +class Word2VecTests(MLlibTestCase): def test_word2vec_setters(self): - data = [ - ["I", "have", "a", "pen"], - ["I", "like", "soccer", "very", "much"], - ["I", "live", "in", "Tokyo"] - ] model = Word2Vec() \ .setVectorSize(2) \ .setLearningRate(0.01) \ @@ -765,7 +767,7 @@ def test_word2vec_get_vectors(self): self.assertEquals(len(model.getVectors()), 3) -class StandardScalerTests(PySparkTestCase): +class StandardScalerTests(MLlibTestCase): def test_model_setters(self): data = [ [1.0, 2.0, 3.0], @@ -793,3 +795,4 @@ def test_model_transform(self): unittest.main() if not _have_scipy: print("NOTE: SciPy tests were skipped as it does not seem to be installed") + sc.stop() diff --git a/python/pyspark/mllib/tree.py b/python/pyspark/mllib/tree.py index 0fe6e4fabe43a..cfcbea573fd22 100644 --- a/python/pyspark/mllib/tree.py +++ b/python/pyspark/mllib/tree.py @@ -482,13 +482,13 @@ def trainClassifier(cls, data, categoricalFeaturesInfo, ... LabeledPoint(1.0, [3.0]) ... ] >>> - >>> model = GradientBoostedTrees.trainClassifier(sc.parallelize(data), {}) + >>> model = GradientBoostedTrees.trainClassifier(sc.parallelize(data), {}, numIterations=10) >>> model.numTrees() - 100 + 10 >>> model.totalNumNodes() - 300 + 30 >>> print(model) # it already has newline - TreeEnsembleModel classifier with 100 trees + TreeEnsembleModel classifier with 10 trees >>> model.predict([2.0]) 1.0 @@ -541,11 +541,12 @@ def trainRegressor(cls, data, categoricalFeaturesInfo, ... LabeledPoint(1.0, SparseVector(2, {1: 2.0})) ... 
] >>> - >>> model = GradientBoostedTrees.trainRegressor(sc.parallelize(sparse_data), {}) + >>> data = sc.parallelize(sparse_data) + >>> model = GradientBoostedTrees.trainRegressor(data, {}, numIterations=10) >>> model.numTrees() - 100 + 10 >>> model.totalNumNodes() - 102 + 12 >>> model.predict(SparseVector(2, {1: 1.0})) 1.0 >>> model.predict(SparseVector(2, {0: 1.0})) diff --git a/python/pyspark/shuffle.py b/python/pyspark/shuffle.py index b54baa57ec28a..1d0b16cade8bb 100644 --- a/python/pyspark/shuffle.py +++ b/python/pyspark/shuffle.py @@ -486,7 +486,7 @@ def sorted(self, iterator, key=None, reverse=False): goes above the limit. """ global MemoryBytesSpilled, DiskBytesSpilled - batch, limit = 100, self._next_limit() + batch, limit = 100, self.memory_limit chunks, current_chunk = [], [] iterator = iter(iterator) while True: @@ -497,7 +497,7 @@ def sorted(self, iterator, key=None, reverse=False): break used_memory = get_used_memory() - if used_memory > self.memory_limit: + if used_memory > limit: # sort them inplace will save memory current_chunk.sort(key=key, reverse=reverse) path = self._get_path(len(chunks)) @@ -513,13 +513,14 @@ def load(f): chunks.append(load(open(path, 'rb'))) current_chunk = [] gc.collect() + batch //= 2 limit = self._next_limit() MemoryBytesSpilled += (used_memory - get_used_memory()) << 20 DiskBytesSpilled += os.path.getsize(path) os.unlink(path) # data will be deleted after close elif not chunks: - batch = min(batch * 2, 10000) + batch = min(int(batch * 1.5), 10000) current_chunk.sort(key=key, reverse=reverse) if not chunks: diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 23e84283679e1..fe43c374f1cb1 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -109,7 +109,7 @@ def setUpClass(cls): os.unlink(cls.tempdir.name) cls.sqlCtx = SQLContext(cls.sc) cls.testData = [Row(key=i, value=str(i)) for i in range(100)] - rdd = cls.sc.parallelize(cls.testData) + rdd = cls.sc.parallelize(cls.testData, 2) cls.df = rdd.toDF() @classmethod @@ -303,7 +303,7 @@ def test_apply_schema(self): abstract = "byte1 short1 float1 time1 map1{} struct1(b) list1[]" schema = _parse_schema_abstract(abstract) typedSchema = _infer_schema_type(rdd.first(), schema) - df = self.sqlCtx.applySchema(rdd, typedSchema) + df = self.sqlCtx.createDataFrame(rdd, typedSchema) r = (127, -32768, 1.0, datetime(2010, 1, 1, 1, 1, 1), {"a": 1}, Row(b=2), [1, 2, 3]) self.assertEqual(r, tuple(df.first())) diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py index 33f958a601f3a..5fa1e5ef081ab 100644 --- a/python/pyspark/streaming/tests.py +++ b/python/pyspark/streaming/tests.py @@ -16,14 +16,23 @@ # import os +import sys from itertools import chain import time import operator -import unittest import tempfile import struct from functools import reduce +if sys.version_info[:2] <= (2, 6): + try: + import unittest2 as unittest + except ImportError: + sys.stderr.write('Please install unittest2 to test with Python 2.6 or earlier') + sys.exit(1) +else: + import unittest + from pyspark.context import SparkConf, SparkContext, RDD from pyspark.streaming.context import StreamingContext from pyspark.streaming.kafka import KafkaUtils @@ -31,19 +40,25 @@ class PySparkStreamingTestCase(unittest.TestCase): - timeout = 20 # seconds - duration = 1 + timeout = 4 # seconds + duration = .2 - def setUp(self): - class_name = self.__class__.__name__ + @classmethod + def setUpClass(cls): + class_name = cls.__name__ conf = 
SparkConf().set("spark.default.parallelism", 1) - self.sc = SparkContext(appName=class_name, conf=conf) - self.sc.setCheckpointDir("/tmp") - # TODO: decrease duration to speed up tests + cls.sc = SparkContext(appName=class_name, conf=conf) + cls.sc.setCheckpointDir("/tmp") + + @classmethod + def tearDownClass(cls): + cls.sc.stop() + + def setUp(self): self.ssc = StreamingContext(self.sc, self.duration) def tearDown(self): - self.ssc.stop() + self.ssc.stop(False) def wait_for(self, result, n): start_time = time.time() @@ -363,13 +378,13 @@ def func(dstream): class WindowFunctionTests(PySparkStreamingTestCase): - timeout = 20 + timeout = 5 def test_window(self): input = [range(1), range(2), range(3), range(4), range(5)] def func(dstream): - return dstream.window(3, 1).count() + return dstream.window(.6, .2).count() expected = [[1], [3], [6], [9], [12], [9], [5]] self._test_func(input, func, expected) @@ -378,7 +393,7 @@ def test_count_by_window(self): input = [range(1), range(2), range(3), range(4), range(5)] def func(dstream): - return dstream.countByWindow(3, 1) + return dstream.countByWindow(.6, .2) expected = [[1], [3], [6], [9], [12], [9], [5]] self._test_func(input, func, expected) @@ -387,7 +402,7 @@ def test_count_by_window_large(self): input = [range(1), range(2), range(3), range(4), range(5), range(6)] def func(dstream): - return dstream.countByWindow(5, 1) + return dstream.countByWindow(1, .2) expected = [[1], [3], [6], [10], [15], [20], [18], [15], [11], [6]] self._test_func(input, func, expected) @@ -396,7 +411,7 @@ def test_count_by_value_and_window(self): input = [range(1), range(2), range(3), range(4), range(5), range(6)] def func(dstream): - return dstream.countByValueAndWindow(5, 1) + return dstream.countByValueAndWindow(1, .2) expected = [[1], [2], [3], [4], [5], [6], [6], [6], [6], [6]] self._test_func(input, func, expected) @@ -405,7 +420,7 @@ def test_group_by_key_and_window(self): input = [[('a', i)] for i in range(5)] def func(dstream): - return dstream.groupByKeyAndWindow(3, 1).mapValues(list) + return dstream.groupByKeyAndWindow(.6, .2).mapValues(list) expected = [[('a', [0])], [('a', [0, 1])], [('a', [0, 1, 2])], [('a', [1, 2, 3])], [('a', [2, 3, 4])], [('a', [3, 4])], [('a', [4])]] @@ -436,8 +451,8 @@ def test_stop_only_streaming_context(self): def test_stop_multiple_times(self): self._add_input_stream() self.ssc.start() - self.ssc.stop() - self.ssc.stop() + self.ssc.stop(False) + self.ssc.stop(False) def test_queue_stream(self): input = [list(range(i + 1)) for i in range(3)] @@ -495,10 +510,7 @@ def func(rdds): self.assertEqual([2, 3, 1], self._take(dstream, 3)) -class CheckpointTests(PySparkStreamingTestCase): - - def setUp(self): - pass +class CheckpointTests(unittest.TestCase): def test_get_or_create(self): inputd = tempfile.mkdtemp() @@ -518,12 +530,12 @@ def setup(): return ssc cpd = tempfile.mkdtemp("test_streaming_cps") - self.ssc = ssc = StreamingContext.getOrCreate(cpd, setup) + ssc = StreamingContext.getOrCreate(cpd, setup) ssc.start() def check_output(n): while not os.listdir(outputd): - time.sleep(0.1) + time.sleep(0.01) time.sleep(1) # make sure mtime is larger than the previous one with open(os.path.join(inputd, str(n)), 'w') as f: f.writelines(["%d\n" % i for i in range(10)]) @@ -553,12 +565,15 @@ def check_output(n): ssc.stop(True, True) time.sleep(1) - self.ssc = ssc = StreamingContext.getOrCreate(cpd, setup) + ssc = StreamingContext.getOrCreate(cpd, setup) ssc.start() check_output(3) + ssc.stop(True, True) class 
KafkaStreamTests(PySparkStreamingTestCase): + timeout = 20 # seconds + duration = 1 def setUp(self): super(KafkaStreamTests, self).setUp() diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 75f39d9e75f38..ea63a396da5b8 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -31,7 +31,6 @@ import time import zipfile import random -import itertools import threading import hashlib @@ -49,6 +48,11 @@ xrange = range basestring = str +if sys.version >= "3": + from io import StringIO +else: + from StringIO import StringIO + from pyspark.conf import SparkConf from pyspark.context import SparkContext @@ -196,7 +200,7 @@ def test_external_sort_in_rdd(self): sc = SparkContext(conf=conf) l = list(range(10240)) random.shuffle(l) - rdd = sc.parallelize(l, 2) + rdd = sc.parallelize(l, 4) self.assertEqual(sorted(l), rdd.sortBy(lambda x: x).collect()) sc.stop() @@ -300,6 +304,18 @@ def test_hash_serializer(self): hash(FlattenedValuesSerializer(PickleSerializer())) +class QuietTest(object): + def __init__(self, sc): + self.log4j = sc._jvm.org.apache.log4j + + def __enter__(self): + self.old_level = self.log4j.LogManager.getRootLogger().getLevel() + self.log4j.LogManager.getRootLogger().setLevel(self.log4j.Level.FATAL) + + def __exit__(self, exc_type, exc_val, exc_tb): + self.log4j.LogManager.getRootLogger().setLevel(self.old_level) + + class PySparkTestCase(unittest.TestCase): def setUp(self): @@ -371,15 +387,11 @@ def test_add_py_file(self): # To ensure that we're actually testing addPyFile's effects, check that # this job fails due to `userlibrary` not being on the Python path: # disable logging in log4j temporarily - log4j = self.sc._jvm.org.apache.log4j - old_level = log4j.LogManager.getRootLogger().getLevel() - log4j.LogManager.getRootLogger().setLevel(log4j.Level.FATAL) - def func(x): from userlibrary import UserClass return UserClass().hello() - self.assertRaises(Exception, self.sc.parallelize(range(2)).map(func).first) - log4j.LogManager.getRootLogger().setLevel(old_level) + with QuietTest(self.sc): + self.assertRaises(Exception, self.sc.parallelize(range(2)).map(func).first) # Add the file, so the job should now succeed: path = os.path.join(SPARK_HOME, "python/test_support/userlibrary.py") @@ -496,7 +508,8 @@ def test_deleting_input_files(self): filtered_data = data.filter(lambda x: True) self.assertEqual(1, filtered_data.count()) os.unlink(tempFile.name) - self.assertRaises(Exception, lambda: filtered_data.count()) + with QuietTest(self.sc): + self.assertRaises(Exception, lambda: filtered_data.count()) def test_sampling_default_seed(self): # Test for SPARK-3995 (default seed setting) @@ -536,9 +549,9 @@ def test_namedtuple_in_rdd(self): self.assertEqual([jon, jane], theDoes.collect()) def test_large_broadcast(self): - N = 100000 + N = 10000 data = [[float(i) for i in range(300)] for i in range(N)] - bdata = self.sc.broadcast(data) # 270MB + bdata = self.sc.broadcast(data) # 27MB m = self.sc.parallelize(range(1), 1).map(lambda x: len(bdata.value)).sum() self.assertEqual(N, m) @@ -569,7 +582,7 @@ def test_multiple_broadcasts(self): self.assertEqual(checksum, csum) def test_large_closure(self): - N = 1000000 + N = 200000 data = [float(i) for i in xrange(N)] rdd = self.sc.parallelize(range(1), 1).map(lambda x: len(data)) self.assertEqual(N, rdd.first()) @@ -604,17 +617,18 @@ def test_zip_with_different_number_of_items(self): # different number of partitions b = self.sc.parallelize(range(100, 106), 3) self.assertRaises(ValueError, lambda: a.zip(b)) - # different number 
of batched items in JVM - b = self.sc.parallelize(range(100, 104), 2) - self.assertRaises(Exception, lambda: a.zip(b).count()) - # different number of items in one pair - b = self.sc.parallelize(range(100, 106), 2) - self.assertRaises(Exception, lambda: a.zip(b).count()) - # same total number of items, but different distributions - a = self.sc.parallelize([2, 3], 2).flatMap(range) - b = self.sc.parallelize([3, 2], 2).flatMap(range) - self.assertEqual(a.count(), b.count()) - self.assertRaises(Exception, lambda: a.zip(b).count()) + with QuietTest(self.sc): + # different number of batched items in JVM + b = self.sc.parallelize(range(100, 104), 2) + self.assertRaises(Exception, lambda: a.zip(b).count()) + # different number of items in one pair + b = self.sc.parallelize(range(100, 106), 2) + self.assertRaises(Exception, lambda: a.zip(b).count()) + # same total number of items, but different distributions + a = self.sc.parallelize([2, 3], 2).flatMap(range) + b = self.sc.parallelize([3, 2], 2).flatMap(range) + self.assertEqual(a.count(), b.count()) + self.assertRaises(Exception, lambda: a.zip(b).count()) def test_count_approx_distinct(self): rdd = self.sc.parallelize(range(1000)) @@ -877,7 +891,12 @@ def test_profiler(self): func_names = [func_name for fname, n, func_name in stat_list] self.assertTrue("heavy_foo" in func_names) + old_stdout = sys.stdout + sys.stdout = io = StringIO() self.sc.show_profiles() + self.assertTrue("heavy_foo" in io.getvalue()) + sys.stdout = old_stdout + d = tempfile.gettempdir() self.sc.dump_profiles(d) self.assertTrue("rdd_%d.pstats" % id in os.listdir(d)) @@ -901,7 +920,7 @@ def show(self, id): def do_computation(self): def heavy_foo(x): - for i in range(1 << 20): + for i in range(1 << 18): x = 1 rdd = self.sc.parallelize(range(100)) @@ -1417,7 +1436,7 @@ def test_termination_sigterm(self): self.do_termination_test(lambda daemon: os.kill(daemon.pid, SIGTERM)) -class WorkerTests(PySparkTestCase): +class WorkerTests(ReusedPySparkTestCase): def test_cancel_task(self): temp = tempfile.NamedTemporaryFile(delete=True) temp.close() @@ -1432,7 +1451,10 @@ def sleep(x): # start job in background thread def run(): - self.sc.parallelize(range(1), 1).foreach(sleep) + try: + self.sc.parallelize(range(1), 1).foreach(sleep) + except Exception: + pass import threading t = threading.Thread(target=run) t.daemon = True @@ -1473,7 +1495,8 @@ def test_after_exception(self): def raise_exception(_): raise Exception() rdd = self.sc.parallelize(range(100), 1) - self.assertRaises(Exception, lambda: rdd.foreach(raise_exception)) + with QuietTest(self.sc): + self.assertRaises(Exception, lambda: rdd.foreach(raise_exception)) self.assertEqual(100, rdd.map(str).count()) def test_after_jvm_exception(self): @@ -1484,7 +1507,8 @@ def test_after_jvm_exception(self): filtered_data = data.filter(lambda x: True) self.assertEqual(1, filtered_data.count()) os.unlink(tempFile.name) - self.assertRaises(Exception, lambda: filtered_data.count()) + with QuietTest(self.sc): + self.assertRaises(Exception, lambda: filtered_data.count()) rdd = self.sc.parallelize(range(100), 1) self.assertEqual(100, rdd.map(str).count()) @@ -1522,14 +1546,11 @@ def test_with_different_versions_of_python(self): rdd.count() version = sys.version_info sys.version_info = (2, 0, 0) - log4j = self.sc._jvm.org.apache.log4j - old_level = log4j.LogManager.getRootLogger().getLevel() - log4j.LogManager.getRootLogger().setLevel(log4j.Level.FATAL) try: - self.assertRaises(Py4JJavaError, lambda: rdd.count()) + with QuietTest(self.sc): + 
self.assertRaises(Py4JJavaError, lambda: rdd.count()) finally: sys.version_info = version - log4j.LogManager.getRootLogger().setLevel(old_level) class SparkSubmitTests(unittest.TestCase): @@ -1751,9 +1772,14 @@ def test_with_stop(self): def test_progress_api(self): with SparkContext() as sc: sc.setJobGroup('test_progress_api', '', True) - rdd = sc.parallelize(range(10)).map(lambda x: time.sleep(100)) - t = threading.Thread(target=rdd.collect) + + def run(): + try: + rdd.count() + except Exception: + pass + t = threading.Thread(target=run) t.daemon = True t.start() # wait for scheduler to start diff --git a/python/run-tests b/python/run-tests index ed3e819ef30c1..88b63b84fdc27 100755 --- a/python/run-tests +++ b/python/run-tests @@ -28,6 +28,7 @@ cd "$FWDIR/python" FAILED=0 LOG_FILE=unit-tests.log +START=$(date +"%s") rm -f $LOG_FILE @@ -35,8 +36,8 @@ rm -f $LOG_FILE rm -rf metastore warehouse function run_test() { - echo "Running test: $1" | tee -a $LOG_FILE - + echo -en "Running test: $1 ... " | tee -a $LOG_FILE + start=$(date +"%s") SPARK_TESTING=1 time "$FWDIR"/bin/pyspark $1 > $LOG_FILE 2>&1 FAILED=$((PIPESTATUS[0]||$FAILED)) @@ -48,6 +49,9 @@ function run_test() { echo "Had test failures; see logs." echo -en "\033[0m" # No color exit -1 + else + now=$(date +"%s") + echo "ok ($(($now - $start))s)" fi } @@ -161,9 +165,8 @@ if [ $(which pypy) ]; then fi if [[ $FAILED == 0 ]]; then - echo -en "\033[32m" # Green - echo "Tests passed." - echo -en "\033[0m" # No color + now=$(date +"%s") + echo -e "\033[32mTests passed \033[0min $(($now - $START)) seconds" fi # TODO: in the long-run, it would be nice to use a test runner like `nose`. From 41ef78a94105bb995bb14d15d47cbb6ca1638f62 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Tue, 21 Apr 2015 17:52:52 -0700 Subject: [PATCH 021/110] Closes #5427 From a0761ec7063f984dcadc8d154f83dd9cfd1c5e0b Mon Sep 17 00:00:00 2001 From: texasmichelle Date: Tue, 21 Apr 2015 18:08:29 -0700 Subject: [PATCH 022/110] [SPARK-1684] [PROJECT INFRA] Merge script should standardize SPARK-XXX prefix Cleans up the pull request title in the merge script to follow conventions outlined in the wiki under Contributing Code. 
https://cwiki.apache.org/confluence/display/SPARK/Contributing+to+Spark#ContributingtoSpark-ContributingCode [MODULE] SPARK-XXXX: Description Author: texasmichelle Closes #5149 from texasmichelle/master and squashes the following commits: 9b6b0a7 [texasmichelle] resolved variable scope issue 7d5fa20 [texasmichelle] only prompt if title has been modified 8c195bb [texasmichelle] removed erroneous line 4f1ed46 [texasmichelle] Deque removal, logic simplifications, & prompt user to pick a title (orig or modified) df73f6a [texasmichelle] reworked regex's to enforce brackets around JIRA ref 43b5aed [texasmichelle] Merge remote-tracking branch 'apache/master' 25229c6 [texasmichelle] Merge remote-tracking branch 'apache/master' aa20a6e [texasmichelle] Move code into main() and add doctest for new text parsing method 48520ba [texasmichelle] SPARK-1684: Corrected import statement 042099d [texasmichelle] SPARK-1684 Merge script should standardize SPARK-XXX prefix 8f4a7d1 [texasmichelle] SPARK-1684 Merge script should standardize SPARK-XXX prefix --- dev/merge_spark_pr.py | 199 +++++++++++++++++++++++++++++------------- 1 file changed, 140 insertions(+), 59 deletions(-) diff --git a/dev/merge_spark_pr.py b/dev/merge_spark_pr.py index 3062e9c3c6651..b69cd15f99f63 100755 --- a/dev/merge_spark_pr.py +++ b/dev/merge_spark_pr.py @@ -55,8 +55,6 @@ # Prefix added to temporary branches BRANCH_PREFIX = "PR_TOOL" -os.chdir(SPARK_HOME) - def get_json(url): try: @@ -85,10 +83,6 @@ def continue_maybe(prompt): if result.lower() != "y": fail("Okay, exiting") - -original_head = run_cmd("git rev-parse HEAD")[:8] - - def clean_up(): print "Restoring head pointer to %s" % original_head run_cmd("git checkout %s" % original_head) @@ -101,7 +95,7 @@ def clean_up(): # merge the requested PR and return the merge hash -def merge_pr(pr_num, target_ref): +def merge_pr(pr_num, target_ref, title, body, pr_repo_desc): pr_branch_name = "%s_MERGE_PR_%s" % (BRANCH_PREFIX, pr_num) target_branch_name = "%s_MERGE_PR_%s_%s" % (BRANCH_PREFIX, pr_num, target_ref.upper()) run_cmd("git fetch %s pull/%s/head:%s" % (PR_REMOTE_NAME, pr_num, pr_branch_name)) @@ -274,7 +268,7 @@ def get_version_json(version_str): asf_jira.transition_issue( jira_id, resolve["id"], fixVersions=jira_fix_versions, comment=comment) - print "Succesfully resolved %s with fixVersions=%s!" % (jira_id, fix_versions) + print "Successfully resolved %s with fixVersions=%s!" % (jira_id, fix_versions) def resolve_jira_issues(title, merge_branches, comment): @@ -286,68 +280,155 @@ def resolve_jira_issues(title, merge_branches, comment): resolve_jira_issue(merge_branches, comment, jira_id) -branches = get_json("%s/branches" % GITHUB_API_BASE) -branch_names = filter(lambda x: x.startswith("branch-"), [x['name'] for x in branches]) -# Assumes branch names can be sorted lexicographically -latest_branch = sorted(branch_names, reverse=True)[0] - -pr_num = raw_input("Which pull request would you like to merge? (e.g. 34): ") -pr = get_json("%s/pulls/%s" % (GITHUB_API_BASE, pr_num)) -pr_events = get_json("%s/issues/%s/events" % (GITHUB_API_BASE, pr_num)) +def standardize_jira_ref(text): + """ + Standardize the [SPARK-XXXXX] [MODULE] prefix + Converts "[SPARK-XXX][mllib] Issue", "[MLLib] SPARK-XXX. 
Issue" or "SPARK XXX [MLLIB]: Issue" to "[SPARK-XXX] [MLLIB] Issue" + + >>> standardize_jira_ref("[SPARK-5821] [SQL] ParquetRelation2 CTAS should check if delete is successful") + '[SPARK-5821] [SQL] ParquetRelation2 CTAS should check if delete is successful' + >>> standardize_jira_ref("[SPARK-4123][Project Infra][WIP]: Show new dependencies added in pull requests") + '[SPARK-4123] [PROJECT INFRA] [WIP] Show new dependencies added in pull requests' + >>> standardize_jira_ref("[MLlib] Spark 5954: Top by key") + '[SPARK-5954] [MLLIB] Top by key' + >>> standardize_jira_ref("[SPARK-979] a LRU scheduler for load balancing in TaskSchedulerImpl") + '[SPARK-979] a LRU scheduler for load balancing in TaskSchedulerImpl' + >>> standardize_jira_ref("SPARK-1094 Support MiMa for reporting binary compatibility accross versions.") + '[SPARK-1094] Support MiMa for reporting binary compatibility accross versions.' + >>> standardize_jira_ref("[WIP] [SPARK-1146] Vagrant support for Spark") + '[SPARK-1146] [WIP] Vagrant support for Spark' + >>> standardize_jira_ref("SPARK-1032. If Yarn app fails before registering, app master stays aroun...") + '[SPARK-1032] If Yarn app fails before registering, app master stays aroun...' + >>> standardize_jira_ref("[SPARK-6250][SPARK-6146][SPARK-5911][SQL] Types are now reserved words in DDL parser.") + '[SPARK-6250] [SPARK-6146] [SPARK-5911] [SQL] Types are now reserved words in DDL parser.' + >>> standardize_jira_ref("Additional information for users building from source code") + 'Additional information for users building from source code' + """ + jira_refs = [] + components = [] + + # If the string is compliant, no need to process any further + if (re.search(r'^\[SPARK-[0-9]{3,6}\] (\[[A-Z0-9_\s,]+\] )+\S+', text)): + return text + + # Extract JIRA ref(s): + pattern = re.compile(r'(SPARK[-\s]*[0-9]{3,6})+', re.IGNORECASE) + for ref in pattern.findall(text): + # Add brackets, replace spaces with a dash, & convert to uppercase + jira_refs.append('[' + re.sub(r'\s+', '-', ref.upper()) + ']') + text = text.replace(ref, '') + + # Extract spark component(s): + # Look for alphanumeric chars, spaces, dashes, periods, and/or commas + pattern = re.compile(r'(\[[\w\s,-\.]+\])', re.IGNORECASE) + for component in pattern.findall(text): + components.append(component.upper()) + text = text.replace(component, '') + + # Cleanup any remaining symbols: + pattern = re.compile(r'^\W+(.*)', re.IGNORECASE) + if (pattern.search(text) is not None): + text = pattern.search(text).groups()[0] + + # Assemble full text (JIRA ref(s), module(s), remaining text) + clean_text = ' '.join(jira_refs).strip() + " " + ' '.join(components).strip() + " " + text.strip() + + # Replace multiple spaces with a single space, e.g. if no jira refs and/or components were included + clean_text = re.sub(r'\s+', ' ', clean_text.strip()) + + return clean_text + +def main(): + global original_head + + os.chdir(SPARK_HOME) + original_head = run_cmd("git rev-parse HEAD")[:8] + + branches = get_json("%s/branches" % GITHUB_API_BASE) + branch_names = filter(lambda x: x.startswith("branch-"), [x['name'] for x in branches]) + # Assumes branch names can be sorted lexicographically + latest_branch = sorted(branch_names, reverse=True)[0] + + pr_num = raw_input("Which pull request would you like to merge? (e.g. 
34): ") + pr = get_json("%s/pulls/%s" % (GITHUB_API_BASE, pr_num)) + pr_events = get_json("%s/issues/%s/events" % (GITHUB_API_BASE, pr_num)) + + url = pr["url"] + + # Decide whether to use the modified title or not + modified_title = standardize_jira_ref(pr["title"]) + if modified_title != pr["title"]: + print "I've re-written the title as follows to match the standard format:" + print "Original: %s" % pr["title"] + print "Modified: %s" % modified_title + result = raw_input("Would you like to use the modified title? (y/n): ") + if result.lower() == "y": + title = modified_title + print "Using modified title:" + else: + title = pr["title"] + print "Using original title:" + print title + else: + title = pr["title"] -url = pr["url"] -title = pr["title"] -body = pr["body"] -target_ref = pr["base"]["ref"] -user_login = pr["user"]["login"] -base_ref = pr["head"]["ref"] -pr_repo_desc = "%s/%s" % (user_login, base_ref) + body = pr["body"] + target_ref = pr["base"]["ref"] + user_login = pr["user"]["login"] + base_ref = pr["head"]["ref"] + pr_repo_desc = "%s/%s" % (user_login, base_ref) -# Merged pull requests don't appear as merged in the GitHub API; -# Instead, they're closed by asfgit. -merge_commits = \ - [e for e in pr_events if e["actor"]["login"] == "asfgit" and e["event"] == "closed"] + # Merged pull requests don't appear as merged in the GitHub API; + # Instead, they're closed by asfgit. + merge_commits = \ + [e for e in pr_events if e["actor"]["login"] == "asfgit" and e["event"] == "closed"] -if merge_commits: - merge_hash = merge_commits[0]["commit_id"] - message = get_json("%s/commits/%s" % (GITHUB_API_BASE, merge_hash))["commit"]["message"] + if merge_commits: + merge_hash = merge_commits[0]["commit_id"] + message = get_json("%s/commits/%s" % (GITHUB_API_BASE, merge_hash))["commit"]["message"] - print "Pull request %s has already been merged, assuming you want to backport" % pr_num - commit_is_downloaded = run_cmd(['git', 'rev-parse', '--quiet', '--verify', + print "Pull request %s has already been merged, assuming you want to backport" % pr_num + commit_is_downloaded = run_cmd(['git', 'rev-parse', '--quiet', '--verify', "%s^{commit}" % merge_hash]).strip() != "" - if not commit_is_downloaded: - fail("Couldn't find any merge commit for #%s, you may need to update HEAD." % pr_num) + if not commit_is_downloaded: + fail("Couldn't find any merge commit for #%s, you may need to update HEAD." % pr_num) - print "Found commit %s:\n%s" % (merge_hash, message) - cherry_pick(pr_num, merge_hash, latest_branch) - sys.exit(0) + print "Found commit %s:\n%s" % (merge_hash, message) + cherry_pick(pr_num, merge_hash, latest_branch) + sys.exit(0) -if not bool(pr["mergeable"]): - msg = "Pull request %s is not mergeable in its current form.\n" % pr_num + \ - "Continue? (experts only!)" - continue_maybe(msg) + if not bool(pr["mergeable"]): + msg = "Pull request %s is not mergeable in its current form.\n" % pr_num + \ + "Continue? (experts only!)" + continue_maybe(msg) -print ("\n=== Pull Request #%s ===" % pr_num) -print ("title\t%s\nsource\t%s\ntarget\t%s\nurl\t%s" % ( - title, pr_repo_desc, target_ref, url)) -continue_maybe("Proceed with merging pull request #%s?" % pr_num) + print ("\n=== Pull Request #%s ===" % pr_num) + print ("title\t%s\nsource\t%s\ntarget\t%s\nurl\t%s" % ( + title, pr_repo_desc, target_ref, url)) + continue_maybe("Proceed with merging pull request #%s?" 
% pr_num) -merged_refs = [target_ref] + merged_refs = [target_ref] -merge_hash = merge_pr(pr_num, target_ref) + merge_hash = merge_pr(pr_num, target_ref, title, body, pr_repo_desc) -pick_prompt = "Would you like to pick %s into another branch?" % merge_hash -while raw_input("\n%s (y/n): " % pick_prompt).lower() == "y": - merged_refs = merged_refs + [cherry_pick(pr_num, merge_hash, latest_branch)] + pick_prompt = "Would you like to pick %s into another branch?" % merge_hash + while raw_input("\n%s (y/n): " % pick_prompt).lower() == "y": + merged_refs = merged_refs + [cherry_pick(pr_num, merge_hash, latest_branch)] -if JIRA_IMPORTED: - if JIRA_USERNAME and JIRA_PASSWORD: - continue_maybe("Would you like to update an associated JIRA?") - jira_comment = "Issue resolved by pull request %s\n[%s/%s]" % (pr_num, GITHUB_BASE, pr_num) - resolve_jira_issues(title, merged_refs, jira_comment) + if JIRA_IMPORTED: + if JIRA_USERNAME and JIRA_PASSWORD: + continue_maybe("Would you like to update an associated JIRA?") + jira_comment = "Issue resolved by pull request %s\n[%s/%s]" % (pr_num, GITHUB_BASE, pr_num) + resolve_jira_issues(title, merged_refs, jira_comment) + else: + print "JIRA_USERNAME and JIRA_PASSWORD not set" + print "Exiting without trying to close the associated JIRA." else: - print "JIRA_USERNAME and JIRA_PASSWORD not set" + print "Could not find jira-python library. Run 'sudo pip install jira-python' to install." print "Exiting without trying to close the associated JIRA." -else: - print "Could not find jira-python library. Run 'sudo pip install jira-python' to install." - print "Exiting without trying to close the associated JIRA." + +if __name__ == "__main__": + import doctest + doctest.testmod() + + main() From 3a3f7100f4ead9b7ac50e9711ac50b603ebf6bea Mon Sep 17 00:00:00 2001 From: zsxwing Date: Tue, 21 Apr 2015 18:37:53 -0700 Subject: [PATCH 023/110] [SPARK-6490][Docs] Add docs for rpc configurations Added docs for rpc configurations and also fixed two places that should have been fixed in #5595. Author: zsxwing Closes #5607 from zsxwing/SPARK-6490-docs and squashes the following commits: 25a6736 [zsxwing] Increase the default timeout to 120s 6e37c30 [zsxwing] Update docs 5577540 [zsxwing] Use spark.network.timeout as the default timeout if it presents 4f07174 [zsxwing] Fix unit tests 1c2cf26 [zsxwing] Add docs for rpc configurations --- .../org/apache/spark/util/RpcUtils.scala | 6 ++-- .../org/apache/spark/SparkConfSuite.scala | 2 +- .../org/apache/spark/rpc/RpcEnvSuite.scala | 2 +- docs/configuration.md | 34 +++++++++++++++++-- 4 files changed, 38 insertions(+), 6 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/RpcUtils.scala b/core/src/main/scala/org/apache/spark/util/RpcUtils.scala index 5ae793e0e87a3..f16cc8e7e42c6 100644 --- a/core/src/main/scala/org/apache/spark/util/RpcUtils.scala +++ b/core/src/main/scala/org/apache/spark/util/RpcUtils.scala @@ -48,11 +48,13 @@ object RpcUtils { /** Returns the default Spark timeout to use for RPC ask operations. */ def askTimeout(conf: SparkConf): FiniteDuration = { - conf.getTimeAsSeconds("spark.rpc.askTimeout", "30s") seconds + conf.getTimeAsSeconds("spark.rpc.askTimeout", + conf.get("spark.network.timeout", "120s")) seconds } /** Returns the default Spark timeout to use for RPC remote endpoint lookup. 
*/ def lookupTimeout(conf: SparkConf): FiniteDuration = { - conf.getTimeAsSeconds("spark.rpc.lookupTimeout", "30s") seconds + conf.getTimeAsSeconds("spark.rpc.lookupTimeout", + conf.get("spark.network.timeout", "120s")) seconds } } diff --git a/core/src/test/scala/org/apache/spark/SparkConfSuite.scala b/core/src/test/scala/org/apache/spark/SparkConfSuite.scala index d7d8014a20498..272e6af0514e4 100644 --- a/core/src/test/scala/org/apache/spark/SparkConfSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkConfSuite.scala @@ -227,7 +227,7 @@ class SparkConfSuite extends FunSuite with LocalSparkContext with ResetSystemPro test("akka deprecated configs") { val conf = new SparkConf() - assert(!conf.contains("spark.rpc.num.retries")) + assert(!conf.contains("spark.rpc.numRetries")) assert(!conf.contains("spark.rpc.retry.wait")) assert(!conf.contains("spark.rpc.askTimeout")) assert(!conf.contains("spark.rpc.lookupTimeout")) diff --git a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala index 5fbda37c7cb88..44c88b00c442a 100644 --- a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala +++ b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala @@ -156,7 +156,7 @@ abstract class RpcEnvSuite extends FunSuite with BeforeAndAfterAll { val conf = new SparkConf() conf.set("spark.rpc.retry.wait", "0") - conf.set("spark.rpc.num.retries", "1") + conf.set("spark.rpc.numRetries", "1") val anotherEnv = createRpcEnv(conf, "remote", 13345) // Use anotherEnv to find out the RpcEndpointRef val rpcEndpointRef = anotherEnv.setupEndpointRef("local", env.address, "ask-timeout") diff --git a/docs/configuration.md b/docs/configuration.md index d9e9e67026cbb..d587b91124cb8 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -963,8 +963,9 @@ Apart from these, the following properties are also available, and may be useful Default timeout for all network interactions. This config will be used in place of spark.core.connection.ack.wait.timeout, spark.akka.timeout, - spark.storage.blockManagerSlaveTimeoutMs or - spark.shuffle.io.connectionTimeout, if they are not configured. + spark.storage.blockManagerSlaveTimeoutMs, + spark.shuffle.io.connectionTimeout, spark.rpc.askTimeout or + spark.rpc.lookupTimeout if they are not configured. @@ -982,6 +983,35 @@ Apart from these, the following properties are also available, and may be useful This is only relevant for the Spark shell. + + spark.rpc.numRetries + 3 + Number of times to retry before an RPC task gives up. + An RPC task will run at most times of this number. + + + + + spark.rpc.retry.wait + 3s + + Duration for an RPC ask operation to wait before retrying. + + + + spark.rpc.askTimeout + 120s + + Duration for an RPC ask operation to wait before timing out. + + + + spark.rpc.lookupTimeout + 120s + Duration for an RPC remote endpoint lookup operation to wait before timing out. + + + #### Scheduling From 70f9f8ff38560967f2c84de77263a5455c45c495 Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Tue, 21 Apr 2015 21:04:04 -0700 Subject: [PATCH 024/110] [MINOR] Comment improvements in ExternalSorter. 1. Clearly specifies the contract/interactions for users of this class. 2. Minor fix in one doc to avoid ambiguity. Author: Patrick Wendell Closes #5620 from pwendell/cleanup and squashes the following commits: 8d8f44f [Patrick Wendell] [Minor] Comment improvements in ExternalSorter. 
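As the diff below spells out, the documented contract for users of this class is: instantiate an ExternalSorter, call insertAll() with the records, then either traverse iterator() or call writePartitionedFile(), and finally call stop() to delete the intermediate spill files. A minimal Scala sketch of that flow follows; ExternalSorter is private[spark], so it only compiles inside Spark itself, and the Option-based constructor arguments are an assumption for illustration rather than something taken from this patch.

```scala
import org.apache.spark.HashPartitioner
import org.apache.spark.util.collection.ExternalSorter

// Sketch only: the constructor arguments (aggregator/partitioner/ordering/serializer
// as Options) are assumed for illustration; insertAll/iterator/stop are the calls
// named in the class comment updated by this patch.
val sorter = new ExternalSorter[String, Int, Int](
  None,                          // 1. instantiate; no map-side combining in this example
  Some(new HashPartitioner(4)),  //    partition the output four ways
  Some(Ordering[String]),        //    sort keys within each partition
  None)                          //    default serializer

// 2. Feed it a set of records.
sorter.insertAll(Iterator(("a", 1), ("b", 2), ("a", 3)))

// 3. Traverse the sorted output (writePartitionedFile() is the alternative used by
//    sort-based shuffle to produce a single partitioned file instead).
sorter.iterator.foreach(println)

// 4. Always call stop() so the intermediate spill files get deleted.
sorter.stop()
```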
--- .../util/collection/ExternalSorter.scala | 27 +++++++++++++------ 1 file changed, 19 insertions(+), 8 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala index 79a1a8a0dae38..79a695fb62086 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala @@ -53,7 +53,18 @@ import org.apache.spark.storage.{BlockObjectWriter, BlockId} * probably want to pass None as the ordering to avoid extra sorting. On the other hand, if you do * want to do combining, having an Ordering is more efficient than not having it. * - * At a high level, this class works as follows: + * Users interact with this class in the following way: + * + * 1. Instantiate an ExternalSorter. + * + * 2. Call insertAll() with a set of records. + * + * 3. Request an iterator() back to traverse sorted/aggregated records. + * - or - + * Invoke writePartitionedFile() to create a file containing sorted/aggregated outputs + * that can be used in Spark's sort shuffle. + * + * At a high level, this class works internally as follows: * * - We repeatedly fill up buffers of in-memory data, using either a SizeTrackingAppendOnlyMap if * we want to combine by key, or an simple SizeTrackingBuffer if we don't. Inside these buffers, @@ -65,11 +76,11 @@ import org.apache.spark.storage.{BlockObjectWriter, BlockId} * aggregation. For each file, we track how many objects were in each partition in memory, so we * don't have to write out the partition ID for every element. * - * - When the user requests an iterator, the spilled files are merged, along with any remaining - * in-memory data, using the same sort order defined above (unless both sorting and aggregation - * are disabled). If we need to aggregate by key, we either use a total ordering from the - * ordering parameter, or read the keys with the same hash code and compare them with each other - * for equality to merge values. + * - When the user requests an iterator or file output, the spilled files are merged, along with + * any remaining in-memory data, using the same sort order defined above (unless both sorting + * and aggregation are disabled). If we need to aggregate by key, we either use a total ordering + * from the ordering parameter, or read the keys with the same hash code and compare them with + * each other for equality to merge values. * * - Users are expected to call stop() at the end to delete all the intermediate files. * @@ -259,8 +270,8 @@ private[spark] class ExternalSorter[K, V, C]( * Spill our in-memory collection to a sorted file that we can merge later (normal code path). * We add this file into spilledFiles to find it later. * - * Alternatively, if bypassMergeSort is true, we spill to separate files for each partition. - * See spillToPartitionedFiles() for that code path. + * This should not be invoked if bypassMergeSort is true. In that case, spillToPartitionedFiles() + * is used to write files for each partition. * * @param collection whichever collection we're using (map or buffer) */ From 607eff0edfc10a1473fa9713a0500bf09f105c13 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Tue, 21 Apr 2015 21:44:44 -0700 Subject: [PATCH 025/110] [SPARK-6113] [ML] Small cleanups after original tree API PR This does a few clean-ups. With this PR, all spark.ml tree components have ```private[ml]``` constructors. CC: mengxr Author: Joseph K. 
Bradley Closes #5567 from jkbradley/dt-api-dt2 and squashes the following commits: 2263b5b [Joseph K. Bradley] Added note about tree example issue. bb9f610 [Joseph K. Bradley] Small cleanups after original tree API PR --- .../examples/ml/DecisionTreeExample.scala | 25 ++++++++++++++----- .../spark/ml/impl/tree/treeParams.scala | 4 +-- .../org/apache/spark/ml/tree/Split.scala | 7 +++--- 3 files changed, 25 insertions(+), 11 deletions(-) diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala index 921b396e799e7..2cd515c89d3d2 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala @@ -44,6 +44,13 @@ import org.apache.spark.sql.{SQLContext, DataFrame} * {{{ * ./bin/run-example ml.DecisionTreeExample [options] * }}} + * Note that Decision Trees can take a large amount of memory. If the run-example command above + * fails, try running via spark-submit and specifying the amount of memory as at least 1g. + * For local mode, run + * {{{ + * ./bin/spark-submit --class org.apache.spark.examples.ml.DecisionTreeExample --driver-memory 1g + * [examples JAR path] [options] + * }}} * If you use it as a template to create your own app, please use `spark-submit` to submit your app. */ object DecisionTreeExample { @@ -70,7 +77,7 @@ object DecisionTreeExample { val parser = new OptionParser[Params]("DecisionTreeExample") { head("DecisionTreeExample: an example decision tree app.") opt[String]("algo") - .text(s"algorithm (Classification, Regression), default: ${defaultParams.algo}") + .text(s"algorithm (classification, regression), default: ${defaultParams.algo}") .action((x, c) => c.copy(algo = x)) opt[Int]("maxDepth") .text(s"max depth of the tree, default: ${defaultParams.maxDepth}") @@ -222,18 +229,23 @@ object DecisionTreeExample { // (1) For classification, re-index classes. val labelColName = if (algo == "classification") "indexedLabel" else "label" if (algo == "classification") { - val labelIndexer = new StringIndexer().setInputCol("labelString").setOutputCol(labelColName) + val labelIndexer = new StringIndexer() + .setInputCol("labelString") + .setOutputCol(labelColName) stages += labelIndexer } // (2) Identify categorical features using VectorIndexer. // Features with more than maxCategories values will be treated as continuous. 
- val featuresIndexer = new VectorIndexer().setInputCol("features") - .setOutputCol("indexedFeatures").setMaxCategories(10) + val featuresIndexer = new VectorIndexer() + .setInputCol("features") + .setOutputCol("indexedFeatures") + .setMaxCategories(10) stages += featuresIndexer // (3) Learn DecisionTree val dt = algo match { case "classification" => - new DecisionTreeClassifier().setFeaturesCol("indexedFeatures") + new DecisionTreeClassifier() + .setFeaturesCol("indexedFeatures") .setLabelCol(labelColName) .setMaxDepth(params.maxDepth) .setMaxBins(params.maxBins) @@ -242,7 +254,8 @@ object DecisionTreeExample { .setCacheNodeIds(params.cacheNodeIds) .setCheckpointInterval(params.checkpointInterval) case "regression" => - new DecisionTreeRegressor().setFeaturesCol("indexedFeatures") + new DecisionTreeRegressor() + .setFeaturesCol("indexedFeatures") .setLabelCol(labelColName) .setMaxDepth(params.maxDepth) .setMaxBins(params.maxBins) diff --git a/mllib/src/main/scala/org/apache/spark/ml/impl/tree/treeParams.scala b/mllib/src/main/scala/org/apache/spark/ml/impl/tree/treeParams.scala index 6f4509f03d033..eb2609faef05a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/impl/tree/treeParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/impl/tree/treeParams.scala @@ -117,7 +117,7 @@ private[ml] trait DecisionTreeParams extends PredictorParams { def setMaxDepth(value: Int): this.type = { require(value >= 0, s"maxDepth parameter must be >= 0. Given bad value: $value") set(maxDepth, value) - this.asInstanceOf[this.type] + this } /** @group getParam */ @@ -283,7 +283,7 @@ private[ml] trait TreeRegressorParams extends Params { def getImpurity: String = getOrDefault(impurity) /** Convert new impurity to old impurity. */ - protected def getOldImpurity: OldImpurity = { + private[ml] def getOldImpurity: OldImpurity = { getImpurity match { case "variance" => OldVariance case _ => diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala index cb940f62990ed..708c769087dd0 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala @@ -38,7 +38,7 @@ sealed trait Split extends Serializable { private[tree] def toOld: OldSplit } -private[ml] object Split { +private[tree] object Split { def fromOld(oldSplit: OldSplit, categoricalFeatures: Map[Int, Int]): Split = { oldSplit.featureType match { @@ -58,7 +58,7 @@ private[ml] object Split { * left. Otherwise, it goes right. * @param numCategories Number of categories for this feature. */ -final class CategoricalSplit( +final class CategoricalSplit private[ml] ( override val featureIndex: Int, leftCategories: Array[Double], private val numCategories: Int) @@ -130,7 +130,8 @@ final class CategoricalSplit( * @param threshold If the feature value is <= this threshold, then the split goes left. * Otherwise, it goes right. 
*/ -final class ContinuousSplit(override val featureIndex: Int, val threshold: Double) extends Split { +final class ContinuousSplit private[ml] (override val featureIndex: Int, val threshold: Double) + extends Split { override private[ml] def shouldGoLeft(features: Vector): Boolean = { features(featureIndex) <= threshold From bdc5c16e76c5d0bc147408353b2ba4faa8e914fc Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Tue, 21 Apr 2015 22:34:31 -0700 Subject: [PATCH 026/110] [SPARK-6889] [DOCS] CONTRIBUTING.md updates to accompany contribution doc updates Part of the SPARK-6889 doc updates, to accompany wiki updates at https://cwiki.apache.org/confluence/display/SPARK/Contributing+to+Spark See draft text at https://docs.google.com/document/d/1tB9-f9lmxhC32QlOo4E8Z7eGDwHx1_Q3O8uCmRXQTo8/edit# Author: Sean Owen Closes #5623 from srowen/SPARK-6889 and squashes the following commits: 03773b1 [Sean Owen] Part of the SPARK-6889 doc updates, to accompany wiki updates at https://cwiki.apache.org/confluence/display/SPARK/Contributing+to+Spark --- CONTRIBUTING.md | 22 +++++++++++++--------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index b6c6b050fa331..f10d7e277eea3 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -1,12 +1,16 @@ ## Contributing to Spark -Contributions via GitHub pull requests are gladly accepted from their original -author. Along with any pull requests, please state that the contribution is -your original work and that you license the work to the project under the -project's open source license. Whether or not you state this explicitly, by -submitting any copyrighted material via pull request, email, or other means -you agree to license the material under the project's open source license and -warrant that you have the legal authority to do so. +*Before opening a pull request*, review the +[Contributing to Spark wiki](https://cwiki.apache.org/confluence/display/SPARK/Contributing+to+Spark). +It lists steps that are required before creating a PR. In particular, consider: + +- Is the change important and ready enough to ask the community to spend time reviewing? +- Have you searched for existing, related JIRAs and pull requests? +- Is this a new feature that can stand alone as a package on http://spark-packages.org ? +- Is the change being proposed clearly explained and motivated? -Please see the [Contributing to Spark wiki page](https://cwiki.apache.org/SPARK/Contributing+to+Spark) -for more information. +When you contribute code, you affirm that the contribution is your original work and that you +license the work to the project under the project's open source license. Whether or not you +state this explicitly, by submitting any copyrighted material via pull request, email, or +other means you agree to license the material under the project's open source license and +warrant that you have the legal authority to do so. From 33b85620f910c404873d362d27cca1223084913a Mon Sep 17 00:00:00 2001 From: zsxwing Date: Wed, 22 Apr 2015 11:08:59 -0700 Subject: [PATCH 027/110] [SPARK-7052][Core] Add ThreadUtils and move thread methods from Utils to ThreadUtils As per rxin 's suggestion in https://github.com/apache/spark/pull/5392/files#r28757176 What's more, there is a race condition in the global shared `daemonThreadFactoryBuilder`. `daemonThreadFactoryBuilder` may be modified by multiple threads. This PR removed the global `daemonThreadFactoryBuilder` and created a new `ThreadFactoryBuilder` every time. 
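For illustration only (not part of this patch): a minimal Scala sketch of the race the description refers to and of the per-call construction the PR switches to. The object name `Example` and the method names are hypothetical stand-ins; the fixed variant mirrors the shape of the `ThreadUtils.namedThreadFactory` added in the diff below.

    // Illustrative sketch only: why sharing one mutable ThreadFactoryBuilder
    // across callers is racy, and the per-call construction adopted instead.
    import java.util.concurrent.ThreadFactory
    import com.google.common.util.concurrent.ThreadFactoryBuilder

    object Example {
      // Racy (old) shape: the shared builder's name format is mutated before each
      // build(), so two concurrent callers can observe each other's prefix.
      private val sharedBuilder = new ThreadFactoryBuilder().setDaemon(true)
      def racyNamedThreadFactory(prefix: String): ThreadFactory =
        sharedBuilder.setNameFormat(prefix + "-%d").build()

      // Safe (new) shape: build a fresh ThreadFactoryBuilder on every call, so no
      // mutable state is shared between threads.
      def namedThreadFactory(prefix: String): ThreadFactory =
        new ThreadFactoryBuilder().setDaemon(true).setNameFormat(prefix + "-%d").build()
    }
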
Author: zsxwing Closes #5631 from zsxwing/thread-utils and squashes the following commits: 9fe5b0e [zsxwing] Add ThreadUtils and move thread methods from Utils to ThreadUtils --- .../spark/ExecutorAllocationManager.scala | 8 +-- .../org/apache/spark/HeartbeatReceiver.scala | 11 ++- .../deploy/history/FsHistoryProvider.scala | 4 +- .../org/apache/spark/executor/Executor.scala | 7 +- .../spark/network/nio/ConnectionManager.scala | 12 ++-- .../apache/spark/scheduler/DAGScheduler.scala | 4 +- .../spark/scheduler/TaskResultGetter.scala | 4 +- .../CoarseGrainedSchedulerBackend.scala | 6 +- .../cluster/YarnSchedulerBackend.scala | 4 +- .../spark/scheduler/local/LocalBackend.scala | 8 +-- .../storage/BlockManagerMasterEndpoint.scala | 4 +- .../storage/BlockManagerSlaveEndpoint.scala | 4 +- .../org/apache/spark/util/ThreadUtils.scala | 67 +++++++++++++++++++ .../scala/org/apache/spark/util/Utils.scala | 29 -------- .../apache/spark/util/ThreadUtilsSuite.scala | 57 ++++++++++++++++ .../streaming/kafka/KafkaInputDStream.scala | 5 +- .../kafka/ReliableKafkaReceiver.scala | 4 +- .../receiver/ReceivedBlockHandler.scala | 4 +- .../streaming/util/WriteAheadLogManager.scala | 4 +- 19 files changed, 170 insertions(+), 76 deletions(-) create mode 100644 core/src/main/scala/org/apache/spark/util/ThreadUtils.scala create mode 100644 core/src/test/scala/org/apache/spark/util/ThreadUtilsSuite.scala diff --git a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala index 4e7bf51fc0622..b986fa87dc2f4 100644 --- a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala +++ b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala @@ -17,12 +17,12 @@ package org.apache.spark -import java.util.concurrent.{Executors, TimeUnit} +import java.util.concurrent.TimeUnit import scala.collection.mutable import org.apache.spark.scheduler._ -import org.apache.spark.util.{Clock, SystemClock, Utils} +import org.apache.spark.util.{ThreadUtils, Clock, SystemClock, Utils} /** * An agent that dynamically allocates and removes executors based on the workload. @@ -132,8 +132,8 @@ private[spark] class ExecutorAllocationManager( private val listener = new ExecutorAllocationListener // Executor that handles the scheduling task. - private val executor = Executors.newSingleThreadScheduledExecutor( - Utils.namedThreadFactory("spark-dynamic-executor-allocation")) + private val executor = + ThreadUtils.newDaemonSingleThreadScheduledExecutor("spark-dynamic-executor-allocation") /** * Verify that the settings specified through the config are valid. diff --git a/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala b/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala index e3bd16f1cbf24..68d05d5b02537 100644 --- a/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala +++ b/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala @@ -17,7 +17,7 @@ package org.apache.spark -import java.util.concurrent.{ScheduledFuture, TimeUnit, Executors} +import java.util.concurrent.{ScheduledFuture, TimeUnit} import scala.collection.mutable @@ -25,7 +25,7 @@ import org.apache.spark.executor.TaskMetrics import org.apache.spark.rpc.{ThreadSafeRpcEndpoint, RpcEnv, RpcCallContext} import org.apache.spark.storage.BlockManagerId import org.apache.spark.scheduler.{SlaveLost, TaskScheduler} -import org.apache.spark.util.Utils +import org.apache.spark.util.{ThreadUtils, Utils} /** * A heartbeat from executors to the driver. 
This is a shared message used by several internal @@ -76,11 +76,10 @@ private[spark] class HeartbeatReceiver(sc: SparkContext) private var timeoutCheckingTask: ScheduledFuture[_] = null - private val timeoutCheckingThread = Executors.newSingleThreadScheduledExecutor( - Utils.namedThreadFactory("heartbeat-timeout-checking-thread")) + private val timeoutCheckingThread = + ThreadUtils.newDaemonSingleThreadScheduledExecutor("heartbeat-timeout-checking-thread") - private val killExecutorThread = Executors.newSingleThreadExecutor( - Utils.namedThreadFactory("kill-executor-thread")) + private val killExecutorThread = ThreadUtils.newDaemonSingleThreadExecutor("kill-executor-thread") override def onStart(): Unit = { timeoutCheckingTask = timeoutCheckingThread.scheduleAtFixedRate(new Runnable { diff --git a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala index 47bdd7749ec3d..9847d5944a390 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala @@ -32,7 +32,7 @@ import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.io.CompressionCodec import org.apache.spark.scheduler._ import org.apache.spark.ui.SparkUI -import org.apache.spark.util.Utils +import org.apache.spark.util.{ThreadUtils, Utils} import org.apache.spark.{Logging, SecurityManager, SparkConf} @@ -99,7 +99,7 @@ private[history] class FsHistoryProvider(conf: SparkConf) extends ApplicationHis */ private val replayExecutor: ExecutorService = { if (!conf.contains("spark.testing")) { - Executors.newSingleThreadExecutor(Utils.namedThreadFactory("log-replay-executor")) + ThreadUtils.newDaemonSingleThreadExecutor("log-replay-executor") } else { MoreExecutors.sameThreadExecutor() } diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index 327d155b38c22..5fc04df5d6a40 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -21,7 +21,7 @@ import java.io.File import java.lang.management.ManagementFactory import java.net.URL import java.nio.ByteBuffer -import java.util.concurrent.{ConcurrentHashMap, Executors, TimeUnit} +import java.util.concurrent.{ConcurrentHashMap, TimeUnit} import scala.collection.JavaConversions._ import scala.collection.mutable.{ArrayBuffer, HashMap} @@ -76,7 +76,7 @@ private[spark] class Executor( } // Start worker thread pool - private val threadPool = Utils.newDaemonCachedThreadPool("Executor task launch worker") + private val threadPool = ThreadUtils.newDaemonCachedThreadPool("Executor task launch worker") private val executorSource = new ExecutorSource(threadPool, executorId) if (!isLocal) { @@ -110,8 +110,7 @@ private[spark] class Executor( private val runningTasks = new ConcurrentHashMap[Long, TaskRunner] // Executor for the heartbeat task. 
- private val heartbeater = Executors.newSingleThreadScheduledExecutor( - Utils.namedThreadFactory("driver-heartbeater")) + private val heartbeater = ThreadUtils.newDaemonSingleThreadScheduledExecutor("driver-heartbeater") startDriverHeartbeater() diff --git a/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala b/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala index 1a68e621eaee7..16e905982cf64 100644 --- a/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala @@ -36,7 +36,7 @@ import io.netty.util.{Timeout, TimerTask, HashedWheelTimer} import org.apache.spark._ import org.apache.spark.network.sasl.{SparkSaslClient, SparkSaslServer} -import org.apache.spark.util.Utils +import org.apache.spark.util.{ThreadUtils, Utils} import scala.util.Try import scala.util.control.NonFatal @@ -79,7 +79,7 @@ private[nio] class ConnectionManager( private val selector = SelectorProvider.provider.openSelector() private val ackTimeoutMonitor = - new HashedWheelTimer(Utils.namedThreadFactory("AckTimeoutMonitor")) + new HashedWheelTimer(ThreadUtils.namedThreadFactory("AckTimeoutMonitor")) private val ackTimeout = conf.getTimeAsSeconds("spark.core.connection.ack.wait.timeout", @@ -102,7 +102,7 @@ private[nio] class ConnectionManager( handlerThreadCount, conf.getInt("spark.core.connection.handler.threads.keepalive", 60), TimeUnit.SECONDS, new LinkedBlockingDeque[Runnable](), - Utils.namedThreadFactory("handle-message-executor")) { + ThreadUtils.namedThreadFactory("handle-message-executor")) { override def afterExecute(r: Runnable, t: Throwable): Unit = { super.afterExecute(r, t) @@ -117,7 +117,7 @@ private[nio] class ConnectionManager( ioThreadCount, conf.getInt("spark.core.connection.io.threads.keepalive", 60), TimeUnit.SECONDS, new LinkedBlockingDeque[Runnable](), - Utils.namedThreadFactory("handle-read-write-executor")) { + ThreadUtils.namedThreadFactory("handle-read-write-executor")) { override def afterExecute(r: Runnable, t: Throwable): Unit = { super.afterExecute(r, t) @@ -134,7 +134,7 @@ private[nio] class ConnectionManager( connectThreadCount, conf.getInt("spark.core.connection.connect.threads.keepalive", 60), TimeUnit.SECONDS, new LinkedBlockingDeque[Runnable](), - Utils.namedThreadFactory("handle-connect-executor")) { + ThreadUtils.namedThreadFactory("handle-connect-executor")) { override def afterExecute(r: Runnable, t: Throwable): Unit = { super.afterExecute(r, t) @@ -160,7 +160,7 @@ private[nio] class ConnectionManager( private val registerRequests = new SynchronizedQueue[SendingConnection] implicit val futureExecContext = ExecutionContext.fromExecutor( - Utils.newDaemonCachedThreadPool("Connection manager future execution context")) + ThreadUtils.newDaemonCachedThreadPool("Connection manager future execution context")) @volatile private var onReceiveCallback: (BufferMessage, ConnectionManagerId) => Option[Message] = null diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 4a32f8936fb0e..8c4bff4e83afc 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -19,7 +19,7 @@ package org.apache.spark.scheduler import java.io.NotSerializableException import java.util.Properties -import java.util.concurrent.{TimeUnit, Executors} +import java.util.concurrent.TimeUnit 
import java.util.concurrent.atomic.AtomicInteger import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Map, Stack} @@ -129,7 +129,7 @@ class DAGScheduler( private val disallowStageRetryForTest = sc.getConf.getBoolean("spark.test.noStageRetry", false) private val messageScheduler = - Executors.newScheduledThreadPool(1, Utils.namedThreadFactory("dag-scheduler-message")) + ThreadUtils.newDaemonSingleThreadScheduledExecutor("dag-scheduler-message") private[scheduler] val eventProcessLoop = new DAGSchedulerEventProcessLoop(this) taskScheduler.setDAGScheduler(this) diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala index 3938580aeea59..391827c1d2156 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala @@ -26,7 +26,7 @@ import scala.util.control.NonFatal import org.apache.spark._ import org.apache.spark.TaskState.TaskState import org.apache.spark.serializer.SerializerInstance -import org.apache.spark.util.Utils +import org.apache.spark.util.{ThreadUtils, Utils} /** * Runs a thread pool that deserializes and remotely fetches (if necessary) task results. @@ -35,7 +35,7 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedul extends Logging { private val THREADS = sparkEnv.conf.getInt("spark.resultGetter.threads", 4) - private val getTaskResultExecutor = Utils.newDaemonFixedThreadPool( + private val getTaskResultExecutor = ThreadUtils.newDaemonFixedThreadPool( THREADS, "task-result-getter") protected val serializer = new ThreadLocal[SerializerInstance] { diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index 63987dfb32695..9656fb76858ea 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -17,7 +17,7 @@ package org.apache.spark.scheduler.cluster -import java.util.concurrent.{TimeUnit, Executors} +import java.util.concurrent.TimeUnit import java.util.concurrent.atomic.AtomicInteger import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} @@ -26,7 +26,7 @@ import org.apache.spark.rpc._ import org.apache.spark.{ExecutorAllocationClient, Logging, SparkEnv, SparkException, TaskState} import org.apache.spark.scheduler._ import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._ -import org.apache.spark.util.{SerializableBuffer, AkkaUtils, Utils} +import org.apache.spark.util.{ThreadUtils, SerializableBuffer, AkkaUtils, Utils} /** * A scheduler backend that waits for coarse grained executors to connect to it through Akka. 
@@ -73,7 +73,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp private val addressToExecutorId = new HashMap[RpcAddress, String] private val reviveThread = - Executors.newSingleThreadScheduledExecutor(Utils.namedThreadFactory("driver-revive-thread")) + ThreadUtils.newDaemonSingleThreadScheduledExecutor("driver-revive-thread") override def onStart() { // Periodically revive offers to allow delay scheduling to work diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala index 1406a36a669c5..d987c7d563579 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala @@ -24,7 +24,7 @@ import org.apache.spark.rpc._ import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._ import org.apache.spark.scheduler.TaskSchedulerImpl import org.apache.spark.ui.JettyUtils -import org.apache.spark.util.{RpcUtils, Utils} +import org.apache.spark.util.{ThreadUtils, RpcUtils} import scala.util.control.NonFatal @@ -97,7 +97,7 @@ private[spark] abstract class YarnSchedulerBackend( private var amEndpoint: Option[RpcEndpointRef] = None private val askAmThreadPool = - Utils.newDaemonCachedThreadPool("yarn-scheduler-ask-am-thread-pool") + ThreadUtils.newDaemonCachedThreadPool("yarn-scheduler-ask-am-thread-pool") implicit val askAmExecutor = ExecutionContext.fromExecutor(askAmThreadPool) override def receive: PartialFunction[Any, Unit] = { diff --git a/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala index 50ba0b9d5a612..ac5b524517818 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala @@ -18,14 +18,14 @@ package org.apache.spark.scheduler.local import java.nio.ByteBuffer -import java.util.concurrent.{Executors, TimeUnit} +import java.util.concurrent.TimeUnit import org.apache.spark.{Logging, SparkConf, SparkContext, SparkEnv, TaskState} import org.apache.spark.TaskState.TaskState import org.apache.spark.executor.{Executor, ExecutorBackend} import org.apache.spark.rpc.{ThreadSafeRpcEndpoint, RpcCallContext, RpcEndpointRef, RpcEnv} import org.apache.spark.scheduler.{SchedulerBackend, TaskSchedulerImpl, WorkerOffer} -import org.apache.spark.util.Utils +import org.apache.spark.util.{ThreadUtils, Utils} private case class ReviveOffers() @@ -47,8 +47,8 @@ private[spark] class LocalEndpoint( private val totalCores: Int) extends ThreadSafeRpcEndpoint with Logging { - private val reviveThread = Executors.newSingleThreadScheduledExecutor( - Utils.namedThreadFactory("local-revive-thread")) + private val reviveThread = + ThreadUtils.newDaemonSingleThreadScheduledExecutor("local-revive-thread") private var freeCores = totalCores diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala index 28c73a7d543ff..4682167912ff0 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala @@ -28,7 +28,7 @@ import org.apache.spark.{Logging, SparkConf} import org.apache.spark.annotation.DeveloperApi import 
org.apache.spark.scheduler._ import org.apache.spark.storage.BlockManagerMessages._ -import org.apache.spark.util.Utils +import org.apache.spark.util.{ThreadUtils, Utils} /** * BlockManagerMasterEndpoint is an [[ThreadSafeRpcEndpoint]] on the master node to track statuses @@ -51,7 +51,7 @@ class BlockManagerMasterEndpoint( // Mapping from block id to the set of block managers that have the block. private val blockLocations = new JHashMap[BlockId, mutable.HashSet[BlockManagerId]] - private val askThreadPool = Utils.newDaemonCachedThreadPool("block-manager-ask-thread-pool") + private val askThreadPool = ThreadUtils.newDaemonCachedThreadPool("block-manager-ask-thread-pool") private implicit val askExecutionContext = ExecutionContext.fromExecutorService(askThreadPool) override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveEndpoint.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveEndpoint.scala index 8980fa8eb70e2..543df4e1350dd 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveEndpoint.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveEndpoint.scala @@ -20,7 +20,7 @@ package org.apache.spark.storage import scala.concurrent.{ExecutionContext, Future} import org.apache.spark.rpc.{RpcEnv, RpcCallContext, RpcEndpoint} -import org.apache.spark.util.Utils +import org.apache.spark.util.ThreadUtils import org.apache.spark.{Logging, MapOutputTracker, SparkEnv} import org.apache.spark.storage.BlockManagerMessages._ @@ -36,7 +36,7 @@ class BlockManagerSlaveEndpoint( extends RpcEndpoint with Logging { private val asyncThreadPool = - Utils.newDaemonCachedThreadPool("block-manager-slave-async-thread-pool") + ThreadUtils.newDaemonCachedThreadPool("block-manager-slave-async-thread-pool") private implicit val asyncExecutionContext = ExecutionContext.fromExecutorService(asyncThreadPool) // Operations that involve removing blocks may be slow and should be done asynchronously diff --git a/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala b/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala new file mode 100644 index 0000000000000..098a4b79496b2 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +package org.apache.spark.util + +import java.util.concurrent._ + +import com.google.common.util.concurrent.ThreadFactoryBuilder + +private[spark] object ThreadUtils { + + /** + * Create a thread factory that names threads with a prefix and also sets the threads to daemon. 
+ */ + def namedThreadFactory(prefix: String): ThreadFactory = { + new ThreadFactoryBuilder().setDaemon(true).setNameFormat(prefix + "-%d").build() + } + + /** + * Wrapper over newCachedThreadPool. Thread names are formatted as prefix-ID, where ID is a + * unique, sequentially assigned integer. + */ + def newDaemonCachedThreadPool(prefix: String): ThreadPoolExecutor = { + val threadFactory = namedThreadFactory(prefix) + Executors.newCachedThreadPool(threadFactory).asInstanceOf[ThreadPoolExecutor] + } + + /** + * Wrapper over newFixedThreadPool. Thread names are formatted as prefix-ID, where ID is a + * unique, sequentially assigned integer. + */ + def newDaemonFixedThreadPool(nThreads: Int, prefix: String): ThreadPoolExecutor = { + val threadFactory = namedThreadFactory(prefix) + Executors.newFixedThreadPool(nThreads, threadFactory).asInstanceOf[ThreadPoolExecutor] + } + + /** + * Wrapper over newSingleThreadExecutor. + */ + def newDaemonSingleThreadExecutor(threadName: String): ExecutorService = { + val threadFactory = new ThreadFactoryBuilder().setDaemon(true).setNameFormat(threadName).build() + Executors.newSingleThreadExecutor(threadFactory) + } + + /** + * Wrapper over newSingleThreadScheduledExecutor. + */ + def newDaemonSingleThreadScheduledExecutor(threadName: String): ScheduledExecutorService = { + val threadFactory = new ThreadFactoryBuilder().setDaemon(true).setNameFormat(threadName).build() + Executors.newSingleThreadScheduledExecutor(threadFactory) + } +} diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 7b0de1ae55b78..2feb7341b159b 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -35,7 +35,6 @@ import scala.util.control.{ControlThrowable, NonFatal} import com.google.common.io.{ByteStreams, Files} import com.google.common.net.InetAddresses -import com.google.common.util.concurrent.ThreadFactoryBuilder import org.apache.commons.lang3.SystemUtils import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileSystem, FileUtil, Path} @@ -897,34 +896,6 @@ private[spark] object Utils extends Logging { hostPortParseResults.get(hostPort) } - private val daemonThreadFactoryBuilder: ThreadFactoryBuilder = - new ThreadFactoryBuilder().setDaemon(true) - - /** - * Create a thread factory that names threads with a prefix and also sets the threads to daemon. - */ - def namedThreadFactory(prefix: String): ThreadFactory = { - daemonThreadFactoryBuilder.setNameFormat(prefix + "-%d").build() - } - - /** - * Wrapper over newCachedThreadPool. Thread names are formatted as prefix-ID, where ID is a - * unique, sequentially assigned integer. - */ - def newDaemonCachedThreadPool(prefix: String): ThreadPoolExecutor = { - val threadFactory = namedThreadFactory(prefix) - Executors.newCachedThreadPool(threadFactory).asInstanceOf[ThreadPoolExecutor] - } - - /** - * Wrapper over newFixedThreadPool. Thread names are formatted as prefix-ID, where ID is a - * unique, sequentially assigned integer. - */ - def newDaemonFixedThreadPool(nThreads: Int, prefix: String): ThreadPoolExecutor = { - val threadFactory = namedThreadFactory(prefix) - Executors.newFixedThreadPool(nThreads, threadFactory).asInstanceOf[ThreadPoolExecutor] - } - /** * Return the string to tell how long has passed in milliseconds. 
*/ diff --git a/core/src/test/scala/org/apache/spark/util/ThreadUtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/ThreadUtilsSuite.scala new file mode 100644 index 0000000000000..a3aa3e953fbec --- /dev/null +++ b/core/src/test/scala/org/apache/spark/util/ThreadUtilsSuite.scala @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +package org.apache.spark.util + +import java.util.concurrent.{CountDownLatch, TimeUnit} + +import org.scalatest.FunSuite + +class ThreadUtilsSuite extends FunSuite { + + test("newDaemonSingleThreadExecutor") { + val executor = ThreadUtils.newDaemonSingleThreadExecutor("this-is-a-thread-name") + @volatile var threadName = "" + executor.submit(new Runnable { + override def run(): Unit = { + threadName = Thread.currentThread().getName() + } + }) + executor.shutdown() + executor.awaitTermination(10, TimeUnit.SECONDS) + assert(threadName === "this-is-a-thread-name") + } + + test("newDaemonSingleThreadScheduledExecutor") { + val executor = ThreadUtils.newDaemonSingleThreadScheduledExecutor("this-is-a-thread-name") + try { + val latch = new CountDownLatch(1) + @volatile var threadName = "" + executor.schedule(new Runnable { + override def run(): Unit = { + threadName = Thread.currentThread().getName() + latch.countDown() + } + }, 1, TimeUnit.MILLISECONDS) + latch.await(10, TimeUnit.SECONDS) + assert(threadName === "this-is-a-thread-name") + } finally { + executor.shutdownNow() + } + } +} diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaInputDStream.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaInputDStream.scala index 4d26b640e8d74..cca0fac0234e1 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaInputDStream.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaInputDStream.scala @@ -31,7 +31,7 @@ import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.StreamingContext import org.apache.spark.streaming.dstream._ import org.apache.spark.streaming.receiver.Receiver -import org.apache.spark.util.Utils +import org.apache.spark.util.ThreadUtils /** * Input stream that pulls messages from a Kafka Broker. 
@@ -111,7 +111,8 @@ class KafkaReceiver[ val topicMessageStreams = consumerConnector.createMessageStreams( topics, keyDecoder, valueDecoder) - val executorPool = Utils.newDaemonFixedThreadPool(topics.values.sum, "KafkaMessageHandler") + val executorPool = + ThreadUtils.newDaemonFixedThreadPool(topics.values.sum, "KafkaMessageHandler") try { // Start the messages handler for each partition topicMessageStreams.values.foreach { streams => diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/ReliableKafkaReceiver.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/ReliableKafkaReceiver.scala index c4a44c1822c39..ea87e960379f1 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/ReliableKafkaReceiver.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/ReliableKafkaReceiver.scala @@ -33,7 +33,7 @@ import org.I0Itec.zkclient.ZkClient import org.apache.spark.{Logging, SparkEnv} import org.apache.spark.storage.{StorageLevel, StreamBlockId} import org.apache.spark.streaming.receiver.{BlockGenerator, BlockGeneratorListener, Receiver} -import org.apache.spark.util.Utils +import org.apache.spark.util.ThreadUtils /** * ReliableKafkaReceiver offers the ability to reliably store data into BlockManager without loss. @@ -121,7 +121,7 @@ class ReliableKafkaReceiver[ zkClient = new ZkClient(consumerConfig.zkConnect, consumerConfig.zkSessionTimeoutMs, consumerConfig.zkConnectionTimeoutMs, ZKStringSerializer) - messageHandlerThreadPool = Utils.newDaemonFixedThreadPool( + messageHandlerThreadPool = ThreadUtils.newDaemonFixedThreadPool( topics.values.sum, "KafkaMessageHandler") blockGenerator.start() diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala index dcdc27d29c270..297bf04c0c25e 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala @@ -28,7 +28,7 @@ import org.apache.hadoop.fs.Path import org.apache.spark.{Logging, SparkConf, SparkException} import org.apache.spark.storage._ import org.apache.spark.streaming.util.{WriteAheadLogFileSegment, WriteAheadLogManager} -import org.apache.spark.util.{Clock, SystemClock, Utils} +import org.apache.spark.util.{ThreadUtils, Clock, SystemClock} /** Trait that represents the metadata related to storage of blocks */ private[streaming] trait ReceivedBlockStoreResult { @@ -150,7 +150,7 @@ private[streaming] class WriteAheadLogBasedBlockHandler( // For processing futures used in parallel block storing into block manager and write ahead log // # threads = 2, so that both writing to BM and WAL can proceed in parallel implicit private val executionContext = ExecutionContext.fromExecutorService( - Utils.newDaemonFixedThreadPool(2, this.getClass.getSimpleName)) + ThreadUtils.newDaemonFixedThreadPool(2, this.getClass.getSimpleName)) /** * This implementation stores the block into the block manager as well as a write ahead log. 
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/WriteAheadLogManager.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/WriteAheadLogManager.scala index 6bdfe45dc7f83..38a93cc3c9a1f 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/WriteAheadLogManager.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/WriteAheadLogManager.scala @@ -25,7 +25,7 @@ import scala.language.postfixOps import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path import org.apache.spark.Logging -import org.apache.spark.util.{Clock, SystemClock, Utils} +import org.apache.spark.util.{ThreadUtils, Clock, SystemClock} import WriteAheadLogManager._ /** @@ -60,7 +60,7 @@ private[streaming] class WriteAheadLogManager( if (callerName.nonEmpty) s" for $callerName" else "" private val threadpoolName = s"WriteAheadLogManager $callerNameTag" implicit private val executionContext = ExecutionContext.fromExecutorService( - Utils.newDaemonFixedThreadPool(1, threadpoolName)) + ThreadUtils.newDaemonSingleThreadExecutor(threadpoolName)) override protected val logName = s"WriteAheadLogManager $callerNameTag" private var currentLogPath: Option[String] = None From cdf0328684f70ddcd49b23c23c1532aeb9caa44e Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 22 Apr 2015 11:18:01 -0700 Subject: [PATCH 028/110] [SQL] Rename some apply functions. I was looking at the code gen code and got confused by a few of use cases of apply, in particular apply on objects. So I went ahead and changed a few of them. Hopefully slightly more clear with a proper verb. Author: Reynold Xin Closes #5624 from rxin/apply-rename and squashes the following commits: ee45034 [Reynold Xin] [SQL] Rename some apply functions. --- .../sql/catalyst/AbstractSparkSQLParser.scala | 2 +- .../spark/sql/catalyst/dsl/package.scala | 2 +- .../expressions/codegen/CodeGenerator.scala | 6 +-- .../codegen/GenerateMutableProjection.scala | 2 +- .../codegen/GenerateOrdering.scala | 2 +- .../codegen/GeneratePredicate.scala | 2 +- .../codegen/GenerateProjection.scala | 2 +- .../sql/catalyst/expressions/predicates.scala | 6 +-- .../sql/catalyst/rules/RuleExecutor.scala | 2 +- .../spark/sql/catalyst/SqlParserSuite.scala | 9 ++-- .../sql/catalyst/analysis/AnalysisSuite.scala | 22 ++++---- .../analysis/DecimalPrecisionSuite.scala | 6 +-- .../GeneratedEvaluationSuite.scala | 10 ++-- .../GeneratedMutableEvaluationSuite.scala | 8 +-- .../BooleanSimplificationSuite.scala | 2 +- .../optimizer/CombiningLimitsSuite.scala | 4 +- .../optimizer/ConstantFoldingSuite.scala | 14 ++--- .../ConvertToLocalRelationSuite.scala | 2 +- .../ExpressionOptimizationSuite.scala | 2 +- .../optimizer/FilterPushdownSuite.scala | 52 +++++++++---------- .../optimizer/LikeSimplificationSuite.scala | 8 +-- .../catalyst/optimizer/OptimizeInSuite.scala | 4 +- ...mplifyCaseConversionExpressionsSuite.scala | 8 +-- .../optimizer/UnionPushdownSuite.scala | 7 ++- .../catalyst/trees/RuleExecutorSuite.scala | 6 +-- .../org/apache/spark/sql/SQLContext.scala | 12 ++--- .../spark/sql/execution/SparkPlan.scala | 10 ++-- .../joins/BroadcastNestedLoopJoin.scala | 2 +- .../sql/execution/joins/LeftSemiJoinBNL.scala | 2 +- .../apache/spark/sql/parquet/newParquet.scala | 2 +- .../org/apache/spark/sql/sources/ddl.scala | 4 +- .../apache/spark/sql/hive/HiveContext.scala | 2 +- .../org/apache/spark/sql/hive/HiveQl.scala | 4 +- .../spark/sql/hive/HiveStrategies.scala | 4 +- .../apache/spark/sql/hive/test/TestHive.scala | 2 +- 35 files changed, 117 
insertions(+), 117 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/AbstractSparkSQLParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/AbstractSparkSQLParser.scala index 3823584287741..1f3c02478bd68 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/AbstractSparkSQLParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/AbstractSparkSQLParser.scala @@ -32,7 +32,7 @@ private[sql] object KeywordNormalizer { private[sql] abstract class AbstractSparkSQLParser extends StandardTokenParsers with PackratParsers { - def apply(input: String): LogicalPlan = { + def parse(input: String): LogicalPlan = { // Initialize the Keywords. lexical.initialize(reservedWords) phrase(start)(new lexical.Scanner(input)) match { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index 4e5c64bb63c9f..5d5aba9644ff7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -296,7 +296,7 @@ package object dsl { InsertIntoTable( analysis.UnresolvedRelation(Seq(tableName)), Map.empty, logicalPlan, overwrite, false) - def analyze: LogicalPlan = EliminateSubQueries(analysis.SimpleAnalyzer(logicalPlan)) + def analyze: LogicalPlan = EliminateSubQueries(analysis.SimpleAnalyzer.execute(logicalPlan)) } object plans { // scalastyle:ignore diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index be2c101d63a63..eeffedb558c1b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -98,11 +98,11 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin }) /** Generates the requested evaluator binding the given expression(s) to the inputSchema. */ - def apply(expressions: InType, inputSchema: Seq[Attribute]): OutType = - apply(bind(expressions, inputSchema)) + def generate(expressions: InType, inputSchema: Seq[Attribute]): OutType = + generate(bind(expressions, inputSchema)) /** Generates the requested evaluator given already bound expression(s). */ - def apply(expressions: InType): OutType = cache.get(canonicalize(expressions)) + def generate(expressions: InType): OutType = cache.get(canonicalize(expressions)) /** * Returns a term name that is unique within this instance of a `CodeGenerator`. 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala index a419fd7ecb39b..840260703ab74 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala @@ -30,7 +30,7 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu val mutableRowName = newTermName("mutableRow") protected def canonicalize(in: Seq[Expression]): Seq[Expression] = - in.map(ExpressionCanonicalizer(_)) + in.map(ExpressionCanonicalizer.execute) protected def bind(in: Seq[Expression], inputSchema: Seq[Attribute]): Seq[Expression] = in.map(BindReferences.bindReference(_, inputSchema)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala index fc2a2b60703e4..b129c0d898bb7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala @@ -30,7 +30,7 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[Row]] wit import scala.reflect.runtime.universe._ protected def canonicalize(in: Seq[SortOrder]): Seq[SortOrder] = - in.map(ExpressionCanonicalizer(_).asInstanceOf[SortOrder]) + in.map(ExpressionCanonicalizer.execute(_).asInstanceOf[SortOrder]) protected def bind(in: Seq[SortOrder], inputSchema: Seq[Attribute]): Seq[SortOrder] = in.map(BindReferences.bindReference(_, inputSchema)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala index 2a0935c790cf3..40e163024360e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala @@ -26,7 +26,7 @@ object GeneratePredicate extends CodeGenerator[Expression, (Row) => Boolean] { import scala.reflect.runtime.{universe => ru} import scala.reflect.runtime.universe._ - protected def canonicalize(in: Expression): Expression = ExpressionCanonicalizer(in) + protected def canonicalize(in: Expression): Expression = ExpressionCanonicalizer.execute(in) protected def bind(in: Expression, inputSchema: Seq[Attribute]): Expression = BindReferences.bindReference(in, inputSchema) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala index 6f572ff959fb4..d491babc2bff0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala @@ -31,7 +31,7 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { import scala.reflect.runtime.universe._ protected def canonicalize(in: Seq[Expression]): Seq[Expression] = - 
in.map(ExpressionCanonicalizer(_)) + in.map(ExpressionCanonicalizer.execute) protected def bind(in: Seq[Expression], inputSchema: Seq[Attribute]): Seq[Expression] = in.map(BindReferences.bindReference(_, inputSchema)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index fcd6352079b4d..46522eb9c1264 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -23,10 +23,10 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.types.{DataType, BinaryType, BooleanType, NativeType} object InterpretedPredicate { - def apply(expression: Expression, inputSchema: Seq[Attribute]): (Row => Boolean) = - apply(BindReferences.bindReference(expression, inputSchema)) + def create(expression: Expression, inputSchema: Seq[Attribute]): (Row => Boolean) = + create(BindReferences.bindReference(expression, inputSchema)) - def apply(expression: Expression): (Row => Boolean) = { + def create(expression: Expression): (Row => Boolean) = { (r: Row) => expression.eval(r).asInstanceOf[Boolean] } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala index c441f0bf24d85..3f9858b0c4a43 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala @@ -45,7 +45,7 @@ abstract class RuleExecutor[TreeType <: TreeNode[_]] extends Logging { * Executes the batches of rules defined by the subclass. The batches are executed serially * using the defined execution strategy. Within each batch, rules are also executed serially. 
*/ - def apply(plan: TreeType): TreeType = { + def execute(plan: TreeType): TreeType = { var curPlan = plan batches.foreach { batch => diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/SqlParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/SqlParserSuite.scala index 1a0a0e6154ad2..a652c70560990 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/SqlParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/SqlParserSuite.scala @@ -49,13 +49,14 @@ class SqlParserSuite extends FunSuite { test("test long keyword") { val parser = new SuperLongKeywordTestParser - assert(TestCommand("NotRealCommand") === parser("ThisIsASuperLongKeyWordTest NotRealCommand")) + assert(TestCommand("NotRealCommand") === + parser.parse("ThisIsASuperLongKeyWordTest NotRealCommand")) } test("test case insensitive") { val parser = new CaseInsensitiveTestParser - assert(TestCommand("NotRealCommand") === parser("EXECUTE NotRealCommand")) - assert(TestCommand("NotRealCommand") === parser("execute NotRealCommand")) - assert(TestCommand("NotRealCommand") === parser("exEcute NotRealCommand")) + assert(TestCommand("NotRealCommand") === parser.parse("EXECUTE NotRealCommand")) + assert(TestCommand("NotRealCommand") === parser.parse("execute NotRealCommand")) + assert(TestCommand("NotRealCommand") === parser.parse("exEcute NotRealCommand")) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index 7c249215bd6b6..971e1ff5ec2b8 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -42,10 +42,10 @@ class AnalysisSuite extends FunSuite with BeforeAndAfter { def caseSensitiveAnalyze(plan: LogicalPlan): Unit = - caseSensitiveAnalyzer.checkAnalysis(caseSensitiveAnalyzer(plan)) + caseSensitiveAnalyzer.checkAnalysis(caseSensitiveAnalyzer.execute(plan)) def caseInsensitiveAnalyze(plan: LogicalPlan): Unit = - caseInsensitiveAnalyzer.checkAnalysis(caseInsensitiveAnalyzer(plan)) + caseInsensitiveAnalyzer.checkAnalysis(caseInsensitiveAnalyzer.execute(plan)) val testRelation = LocalRelation(AttributeReference("a", IntegerType, nullable = true)()) val testRelation2 = LocalRelation( @@ -82,7 +82,7 @@ class AnalysisSuite extends FunSuite with BeforeAndAfter { a.select(UnresolvedStar(None)).select('a).unionAll(b.select(UnresolvedStar(None))) } - assert(caseInsensitiveAnalyzer(plan).resolved) + assert(caseInsensitiveAnalyzer.execute(plan).resolved) } test("check project's resolved") { @@ -98,11 +98,11 @@ class AnalysisSuite extends FunSuite with BeforeAndAfter { test("analyze project") { assert( - caseSensitiveAnalyzer(Project(Seq(UnresolvedAttribute("a")), testRelation)) === + caseSensitiveAnalyzer.execute(Project(Seq(UnresolvedAttribute("a")), testRelation)) === Project(testRelation.output, testRelation)) assert( - caseSensitiveAnalyzer( + caseSensitiveAnalyzer.execute( Project(Seq(UnresolvedAttribute("TbL.a")), UnresolvedRelation(Seq("TaBlE"), Some("TbL")))) === Project(testRelation.output, testRelation)) @@ -115,13 +115,13 @@ class AnalysisSuite extends FunSuite with BeforeAndAfter { assert(e.getMessage().toLowerCase.contains("cannot resolve")) assert( - caseInsensitiveAnalyzer( + caseInsensitiveAnalyzer.execute( Project(Seq(UnresolvedAttribute("TbL.a")), 
UnresolvedRelation(Seq("TaBlE"), Some("TbL")))) === Project(testRelation.output, testRelation)) assert( - caseInsensitiveAnalyzer( + caseInsensitiveAnalyzer.execute( Project(Seq(UnresolvedAttribute("tBl.a")), UnresolvedRelation(Seq("TaBlE"), Some("TbL")))) === Project(testRelation.output, testRelation)) @@ -134,13 +134,13 @@ class AnalysisSuite extends FunSuite with BeforeAndAfter { assert(e.getMessage == "Table Not Found: tAbLe") assert( - caseSensitiveAnalyzer(UnresolvedRelation(Seq("TaBlE"), None)) === testRelation) + caseSensitiveAnalyzer.execute(UnresolvedRelation(Seq("TaBlE"), None)) === testRelation) assert( - caseInsensitiveAnalyzer(UnresolvedRelation(Seq("tAbLe"), None)) === testRelation) + caseInsensitiveAnalyzer.execute(UnresolvedRelation(Seq("tAbLe"), None)) === testRelation) assert( - caseInsensitiveAnalyzer(UnresolvedRelation(Seq("TaBlE"), None)) === testRelation) + caseInsensitiveAnalyzer.execute(UnresolvedRelation(Seq("TaBlE"), None)) === testRelation) } def errorTest( @@ -219,7 +219,7 @@ class AnalysisSuite extends FunSuite with BeforeAndAfter { AttributeReference("d", DecimalType.Unlimited)(), AttributeReference("e", ShortType)()) - val plan = caseInsensitiveAnalyzer( + val plan = caseInsensitiveAnalyzer.execute( testRelation2.select( 'a / Literal(2) as 'div1, 'a / 'b as 'div2, diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala index 67bec999dfbd1..36b03d1c65e28 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala @@ -48,12 +48,12 @@ class DecimalPrecisionSuite extends FunSuite with BeforeAndAfter { private def checkType(expression: Expression, expectedType: DataType): Unit = { val plan = Project(Seq(Alias(expression, "c")()), relation) - assert(analyzer(plan).schema.fields(0).dataType === expectedType) + assert(analyzer.execute(plan).schema.fields(0).dataType === expectedType) } private def checkComparison(expression: Expression, expectedType: DataType): Unit = { val plan = Project(Alias(expression, "c")() :: Nil, relation) - val comparison = analyzer(plan).collect { + val comparison = analyzer.execute(plan).collect { case Project(Alias(e: BinaryComparison, _) :: Nil, _) => e }.head assert(comparison.left.dataType === expectedType) @@ -64,7 +64,7 @@ class DecimalPrecisionSuite extends FunSuite with BeforeAndAfter { val plan = Union(Project(Seq(Alias(left, "l")()), relation), Project(Seq(Alias(right, "r")()), relation)) - val (l, r) = analyzer(plan).collect { + val (l, r) = analyzer.execute(plan).collect { case Union(left, right) => (left.output.head, right.output.head) }.head assert(l.dataType === expectedType) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedEvaluationSuite.scala index ef3114fd4dbab..b5ebe4b38e337 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedEvaluationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedEvaluationSuite.scala @@ -29,7 +29,7 @@ class GeneratedEvaluationSuite extends ExpressionEvaluationSuite { expected: Any, inputRow: Row = EmptyRow): Unit = { val plan = try { - 
GenerateMutableProjection(Alias(expression, s"Optimized($expression)")() :: Nil)() + GenerateMutableProjection.generate(Alias(expression, s"Optimized($expression)")() :: Nil)() } catch { case e: Throwable => val evaluated = GenerateProjection.expressionEvaluator(expression) @@ -56,10 +56,10 @@ class GeneratedEvaluationSuite extends ExpressionEvaluationSuite { val futures = (1 to 20).map { _ => future { - GeneratePredicate(EqualTo(Literal(1), Literal(1))) - GenerateProjection(EqualTo(Literal(1), Literal(1)) :: Nil) - GenerateMutableProjection(EqualTo(Literal(1), Literal(1)) :: Nil) - GenerateOrdering(Add(Literal(1), Literal(1)).asc :: Nil) + GeneratePredicate.generate(EqualTo(Literal(1), Literal(1))) + GenerateProjection.generate(EqualTo(Literal(1), Literal(1)) :: Nil) + GenerateMutableProjection.generate(EqualTo(Literal(1), Literal(1)) :: Nil) + GenerateOrdering.generate(Add(Literal(1), Literal(1)).asc :: Nil) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedMutableEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedMutableEvaluationSuite.scala index bcc0c404d2cfb..97af2e0fd0502 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedMutableEvaluationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedMutableEvaluationSuite.scala @@ -25,13 +25,13 @@ import org.apache.spark.sql.catalyst.expressions.codegen._ */ class GeneratedMutableEvaluationSuite extends ExpressionEvaluationSuite { override def checkEvaluation( - expression: Expression, - expected: Any, - inputRow: Row = EmptyRow): Unit = { + expression: Expression, + expected: Any, + inputRow: Row = EmptyRow): Unit = { lazy val evaluated = GenerateProjection.expressionEvaluator(expression) val plan = try { - GenerateProjection(Alias(expression, s"Optimized($expression)")() :: Nil) + GenerateProjection.generate(Alias(expression, s"Optimized($expression)")() :: Nil) } catch { case e: Throwable => fail( diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala index 72f06e26e05f1..6255578d7fa57 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala @@ -61,7 +61,7 @@ class BooleanSimplificationSuite extends PlanTest with PredicateHelper { def checkCondition(input: Expression, expected: Expression): Unit = { val plan = testRelation.where(input).analyze - val actual = Optimize(plan).expressions.head + val actual = Optimize.execute(plan).expressions.head compareConditions(actual, expected) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala index e2ae0d25db1a5..2d16d668fd522 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala @@ -44,7 +44,7 @@ class CombiningLimitsSuite extends PlanTest { .limit(10) .limit(5) - val optimized = Optimize(originalQuery.analyze) + val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = 
testRelation .select('a) @@ -61,7 +61,7 @@ class CombiningLimitsSuite extends PlanTest { .limit(7) .limit(5) - val optimized = Optimize(originalQuery.analyze) + val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = testRelation .select('a) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala index 4396bd0dda9a9..14b28e8402610 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala @@ -47,7 +47,7 @@ class ConstantFoldingSuite extends PlanTest { .subquery('y) .select('a) - val optimized = Optimize(originalQuery.analyze) + val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = testRelation .select('a.attr) @@ -74,7 +74,7 @@ class ConstantFoldingSuite extends PlanTest { Literal(2) * Literal(3) - Literal(6) / (Literal(4) - Literal(2)) )(Literal(9) / Literal(3) as Symbol("9/3")) - val optimized = Optimize(originalQuery.analyze) + val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = testRelation @@ -99,7 +99,7 @@ class ConstantFoldingSuite extends PlanTest { Literal(2) * 'a + Literal(4) as Symbol("c3"), 'a * (Literal(3) + Literal(4)) as Symbol("c4")) - val optimized = Optimize(originalQuery.analyze) + val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = testRelation @@ -127,7 +127,7 @@ class ConstantFoldingSuite extends PlanTest { (Literal(1) === Literal(1) || 'b > 1) && (Literal(1) === Literal(2) || 'b < 10))) - val optimized = Optimize(originalQuery.analyze) + val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = testRelation @@ -144,7 +144,7 @@ class ConstantFoldingSuite extends PlanTest { Cast(Literal("2"), IntegerType) + Literal(3) + 'a as Symbol("c1"), Coalesce(Seq(Cast(Literal("abc"), IntegerType), Literal(3))) as Symbol("c2")) - val optimized = Optimize(originalQuery.analyze) + val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = testRelation @@ -163,7 +163,7 @@ class ConstantFoldingSuite extends PlanTest { Rand + Literal(1) as Symbol("c1"), Sum('a) as Symbol("c2")) - val optimized = Optimize(originalQuery.analyze) + val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = testRelation @@ -210,7 +210,7 @@ class ConstantFoldingSuite extends PlanTest { Contains("abc", Literal.create(null, StringType)) as 'c20 ) - val optimized = Optimize(originalQuery.analyze) + val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = testRelation diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConvertToLocalRelationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConvertToLocalRelationSuite.scala index cf42d43823399..6841bd9890c97 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConvertToLocalRelationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConvertToLocalRelationSuite.scala @@ -49,7 +49,7 @@ class ConvertToLocalRelationSuite extends PlanTest { UnresolvedAttribute("a").as("a1"), (UnresolvedAttribute("b") + 1).as("b1")) - val optimized = Optimize(projectOnLocal.analyze) + val optimized = Optimize.execute(projectOnLocal.analyze) comparePlans(optimized, correctAnswer) } diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ExpressionOptimizationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ExpressionOptimizationSuite.scala index 2f3704be59a9d..a4a3a66b8b229 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ExpressionOptimizationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ExpressionOptimizationSuite.scala @@ -30,7 +30,7 @@ class ExpressionOptimizationSuite extends ExpressionEvaluationSuite { expected: Any, inputRow: Row = EmptyRow): Unit = { val plan = Project(Alias(expression, s"Optimized($expression)")() :: Nil, OneRowRelation) - val optimizedPlan = DefaultOptimizer(plan) + val optimizedPlan = DefaultOptimizer.execute(plan) super.checkEvaluation(optimizedPlan.expressions.head, expected, inputRow) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala index 45cf695d20b01..aa9708b164efa 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala @@ -50,7 +50,7 @@ class FilterPushdownSuite extends PlanTest { .subquery('y) .select('a) - val optimized = Optimize(originalQuery.analyze) + val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = testRelation .select('a.attr) @@ -65,7 +65,7 @@ class FilterPushdownSuite extends PlanTest { .groupBy('a)('a, Count('b)) .select('a) - val optimized = Optimize(originalQuery.analyze) + val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = testRelation .select('a) @@ -81,7 +81,7 @@ class FilterPushdownSuite extends PlanTest { .groupBy('a)('a as 'c, Count('b)) .select('c) - val optimized = Optimize(originalQuery.analyze) + val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = testRelation .select('a) @@ -98,7 +98,7 @@ class FilterPushdownSuite extends PlanTest { .select('a) .where('a === 1) - val optimized = Optimize(originalQuery.analyze) + val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = testRelation .where('a === 1) @@ -115,7 +115,7 @@ class FilterPushdownSuite extends PlanTest { .where('e === 1) .analyze - val optimized = Optimize(originalQuery.analyze) + val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = testRelation .where('a + 'b === 1) @@ -131,7 +131,7 @@ class FilterPushdownSuite extends PlanTest { .where('a === 1) .where('a === 2) - val optimized = Optimize(originalQuery.analyze) + val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = testRelation .where('a === 1 && 'a === 2) @@ -152,7 +152,7 @@ class FilterPushdownSuite extends PlanTest { .where("y.b".attr === 2) } - val optimized = Optimize(originalQuery.analyze) + val optimized = Optimize.execute(originalQuery.analyze) val left = testRelation.where('b === 1) val right = testRelation.where('b === 2) val correctAnswer = @@ -170,7 +170,7 @@ class FilterPushdownSuite extends PlanTest { .where("x.b".attr === 1) } - val optimized = Optimize(originalQuery.analyze) + val optimized = Optimize.execute(originalQuery.analyze) val left = testRelation.where('b === 1) val right = testRelation val correctAnswer = @@ -188,7 +188,7 @@ class FilterPushdownSuite extends PlanTest { .where("x.b".attr === 1 && "y.b".attr === 2) } - 
val optimized = Optimize(originalQuery.analyze) + val optimized = Optimize.execute(originalQuery.analyze) val left = testRelation.where('b === 1) val right = testRelation.where('b === 2) val correctAnswer = @@ -206,7 +206,7 @@ class FilterPushdownSuite extends PlanTest { .where("x.b".attr === 1 && "y.b".attr === 2) } - val optimized = Optimize(originalQuery.analyze) + val optimized = Optimize.execute(originalQuery.analyze) val left = testRelation.where('b === 1) val correctAnswer = left.join(y, LeftOuter).where("y.b".attr === 2).analyze @@ -223,7 +223,7 @@ class FilterPushdownSuite extends PlanTest { .where("x.b".attr === 1 && "y.b".attr === 2) } - val optimized = Optimize(originalQuery.analyze) + val optimized = Optimize.execute(originalQuery.analyze) val right = testRelation.where('b === 2).subquery('d) val correctAnswer = x.join(right, RightOuter).where("x.b".attr === 1).analyze @@ -240,7 +240,7 @@ class FilterPushdownSuite extends PlanTest { .where("x.b".attr === 2 && "y.b".attr === 2) } - val optimized = Optimize(originalQuery.analyze) + val optimized = Optimize.execute(originalQuery.analyze) val left = testRelation.where('b === 2).subquery('d) val correctAnswer = left.join(y, LeftOuter, Some("d.b".attr === 1)).where("y.b".attr === 2).analyze @@ -257,7 +257,7 @@ class FilterPushdownSuite extends PlanTest { .where("x.b".attr === 2 && "y.b".attr === 2) } - val optimized = Optimize(originalQuery.analyze) + val optimized = Optimize.execute(originalQuery.analyze) val right = testRelation.where('b === 2).subquery('d) val correctAnswer = x.join(right, RightOuter, Some("d.b".attr === 1)).where("x.b".attr === 2).analyze @@ -274,7 +274,7 @@ class FilterPushdownSuite extends PlanTest { .where("x.b".attr === 2 && "y.b".attr === 2) } - val optimized = Optimize(originalQuery.analyze) + val optimized = Optimize.execute(originalQuery.analyze) val left = testRelation.where('b === 2).subquery('l) val right = testRelation.where('b === 1).subquery('r) val correctAnswer = @@ -292,7 +292,7 @@ class FilterPushdownSuite extends PlanTest { .where("x.b".attr === 2 && "y.b".attr === 2) } - val optimized = Optimize(originalQuery.analyze) + val optimized = Optimize.execute(originalQuery.analyze) val right = testRelation.where('b === 2).subquery('r) val correctAnswer = x.join(right, RightOuter, Some("r.b".attr === 1)).where("x.b".attr === 2).analyze @@ -309,7 +309,7 @@ class FilterPushdownSuite extends PlanTest { .where("x.b".attr === 2 && "y.b".attr === 2 && "x.c".attr === "y.c".attr) } - val optimized = Optimize(originalQuery.analyze) + val optimized = Optimize.execute(originalQuery.analyze) val left = testRelation.where('b === 2).subquery('l) val right = testRelation.where('b === 1).subquery('r) val correctAnswer = @@ -327,7 +327,7 @@ class FilterPushdownSuite extends PlanTest { .where("x.b".attr === 2 && "y.b".attr === 2 && "x.c".attr === "y.c".attr) } - val optimized = Optimize(originalQuery.analyze) + val optimized = Optimize.execute(originalQuery.analyze) val left = testRelation.subquery('l) val right = testRelation.where('b === 2).subquery('r) val correctAnswer = @@ -346,7 +346,7 @@ class FilterPushdownSuite extends PlanTest { .where("x.b".attr === 2 && "y.b".attr === 2 && "x.c".attr === "y.c".attr) } - val optimized = Optimize(originalQuery.analyze) + val optimized = Optimize.execute(originalQuery.analyze) val left = testRelation.where('b === 2).subquery('l) val right = testRelation.where('b === 1).subquery('r) val correctAnswer = @@ -365,7 +365,7 @@ class FilterPushdownSuite extends PlanTest { 
.where("x.b".attr === 2 && "y.b".attr === 2 && "x.c".attr === "y.c".attr) } - val optimized = Optimize(originalQuery.analyze) + val optimized = Optimize.execute(originalQuery.analyze) val left = testRelation.where('a === 3).subquery('l) val right = testRelation.where('b === 2).subquery('r) val correctAnswer = @@ -382,7 +382,7 @@ class FilterPushdownSuite extends PlanTest { val originalQuery = { x.join(y, condition = Some("x.b".attr === "y.b".attr)) } - val optimized = Optimize(originalQuery.analyze) + val optimized = Optimize.execute(originalQuery.analyze) comparePlans(analysis.EliminateSubQueries(originalQuery.analyze), optimized) } @@ -396,7 +396,7 @@ class FilterPushdownSuite extends PlanTest { .where(("x.b".attr === "y.b".attr) && ("x.a".attr === 1) && ("y.a".attr === 1)) } - val optimized = Optimize(originalQuery.analyze) + val optimized = Optimize.execute(originalQuery.analyze) val left = testRelation.where('a === 1).subquery('x) val right = testRelation.where('a === 1).subquery('y) val correctAnswer = @@ -415,7 +415,7 @@ class FilterPushdownSuite extends PlanTest { .where(("x.b".attr === "y.b".attr) && ("x.a".attr === 1)) } - val optimized = Optimize(originalQuery.analyze) + val optimized = Optimize.execute(originalQuery.analyze) val left = testRelation.where('a === 1).subquery('x) val right = testRelation.subquery('y) val correctAnswer = @@ -436,7 +436,7 @@ class FilterPushdownSuite extends PlanTest { ("z.a".attr >= 3) && ("z.a".attr === "x.b".attr)) } - val optimized = Optimize(originalQuery.analyze) + val optimized = Optimize.execute(originalQuery.analyze) val lleft = testRelation.where('a >= 3).subquery('z) val left = testRelation.where('a === 1).subquery('x) val right = testRelation.subquery('y) @@ -457,7 +457,7 @@ class FilterPushdownSuite extends PlanTest { .generate(Explode('c_arr), true, false, Some("arr")) .where(('b >= 5) && ('a > 6)) } - val optimized = Optimize(originalQuery.analyze) + val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = { testRelationWithArrayType .where(('b >= 5) && ('a > 6)) @@ -474,7 +474,7 @@ class FilterPushdownSuite extends PlanTest { .generate(generator, true, false, Some("arr")) .where(('b >= 5) && ('c > 6)) } - val optimized = Optimize(originalQuery.analyze) + val optimized = Optimize.execute(originalQuery.analyze) val referenceResult = { testRelationWithArrayType .where('b >= 5) @@ -502,7 +502,7 @@ class FilterPushdownSuite extends PlanTest { .generate(Explode('c_arr), true, false, Some("arr")) .where(('c > 6) || ('b > 5)).analyze } - val optimized = Optimize(originalQuery) + val optimized = Optimize.execute(originalQuery) comparePlans(optimized, originalQuery) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LikeSimplificationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LikeSimplificationSuite.scala index b10577c8001e2..b3df487c84dc8 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LikeSimplificationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LikeSimplificationSuite.scala @@ -41,7 +41,7 @@ class LikeSimplificationSuite extends PlanTest { testRelation .where(('a like "abc%") || ('a like "abc\\%")) - val optimized = Optimize(originalQuery.analyze) + val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = testRelation .where(StartsWith('a, "abc") || ('a like "abc\\%")) .analyze @@ -54,7 +54,7 @@ class LikeSimplificationSuite extends PlanTest { testRelation 
.where('a like "%xyz") - val optimized = Optimize(originalQuery.analyze) + val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = testRelation .where(EndsWith('a, "xyz")) .analyze @@ -67,7 +67,7 @@ class LikeSimplificationSuite extends PlanTest { testRelation .where(('a like "%mn%") || ('a like "%mn\\%")) - val optimized = Optimize(originalQuery.analyze) + val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = testRelation .where(Contains('a, "mn") || ('a like "%mn\\%")) .analyze @@ -80,7 +80,7 @@ class LikeSimplificationSuite extends PlanTest { testRelation .where(('a like "") || ('a like "abc")) - val optimized = Optimize(originalQuery.analyze) + val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = testRelation .where(('a === "") || ('a === "abc")) .analyze diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala index 966bc9ada1e6e..3eb399e68e70c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala @@ -49,7 +49,7 @@ class OptimizeInSuite extends PlanTest { .where(In(UnresolvedAttribute("a"), Seq(Literal(1),Literal(2)))) .analyze - val optimized = Optimize(originalQuery.analyze) + val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = testRelation .where(InSet(UnresolvedAttribute("a"), HashSet[Any]() + 1 + 2)) @@ -64,7 +64,7 @@ class OptimizeInSuite extends PlanTest { .where(In(UnresolvedAttribute("a"), Seq(Literal(1),Literal(2), UnresolvedAttribute("b")))) .analyze - val optimized = Optimize(originalQuery.analyze) + val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = testRelation .where(In(UnresolvedAttribute("a"), Seq(Literal(1),Literal(2), UnresolvedAttribute("b")))) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyCaseConversionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyCaseConversionExpressionsSuite.scala index 22992fb6f50d4..6b1e53cd42b24 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyCaseConversionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyCaseConversionExpressionsSuite.scala @@ -41,7 +41,7 @@ class SimplifyCaseConversionExpressionsSuite extends PlanTest { testRelation .select(Upper(Upper('a)) as 'u) - val optimized = Optimize(originalQuery.analyze) + val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = testRelation .select(Upper('a) as 'u) @@ -55,7 +55,7 @@ class SimplifyCaseConversionExpressionsSuite extends PlanTest { testRelation .select(Upper(Lower('a)) as 'u) - val optimized = Optimize(originalQuery.analyze) + val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = testRelation .select(Upper('a) as 'u) @@ -69,7 +69,7 @@ class SimplifyCaseConversionExpressionsSuite extends PlanTest { testRelation .select(Lower(Upper('a)) as 'l) - val optimized = Optimize(originalQuery.analyze) + val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = testRelation .select(Lower('a) as 'l) .analyze @@ -82,7 +82,7 @@ class SimplifyCaseConversionExpressionsSuite extends PlanTest { testRelation .select(Lower(Lower('a)) as 'l) - val optimized = 
Optimize(originalQuery.analyze) + val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = testRelation .select(Lower('a) as 'l) .analyze diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UnionPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UnionPushdownSuite.scala index a54751dfa9a12..a3ad200800b02 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UnionPushdownSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UnionPushdownSuite.scala @@ -17,10 +17,9 @@ package org.apache.spark.sql.catalyst.optimizer -import org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.catalyst.analysis.EliminateSubQueries +import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.catalyst.plans.{PlanTest, LeftOuter, RightOuter} import org.apache.spark.sql.catalyst.rules._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.dsl.expressions._ @@ -41,7 +40,7 @@ class UnionPushdownSuite extends PlanTest { test("union: filter to each side") { val query = testUnion.where('a === 1) - val optimized = Optimize(query.analyze) + val optimized = Optimize.execute(query.analyze) val correctAnswer = Union(testRelation.where('a === 1), testRelation2.where('d === 1)).analyze @@ -52,7 +51,7 @@ class UnionPushdownSuite extends PlanTest { test("union: project to each side") { val query = testUnion.select('b) - val optimized = Optimize(query.analyze) + val optimized = Optimize.execute(query.analyze) val correctAnswer = Union(testRelation.select('b), testRelation2.select('e)).analyze diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/RuleExecutorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/RuleExecutorSuite.scala index 4b2d45584045f..2a641c63f87bb 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/RuleExecutorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/RuleExecutorSuite.scala @@ -34,7 +34,7 @@ class RuleExecutorSuite extends FunSuite { val batches = Batch("once", Once, DecrementLiterals) :: Nil } - assert(ApplyOnce(Literal(10)) === Literal(9)) + assert(ApplyOnce.execute(Literal(10)) === Literal(9)) } test("to fixed point") { @@ -42,7 +42,7 @@ class RuleExecutorSuite extends FunSuite { val batches = Batch("fixedPoint", FixedPoint(100), DecrementLiterals) :: Nil } - assert(ToFixedPoint(Literal(10)) === Literal(0)) + assert(ToFixedPoint.execute(Literal(10)) === Literal(0)) } test("to maxIterations") { @@ -50,6 +50,6 @@ class RuleExecutorSuite extends FunSuite { val batches = Batch("fixedPoint", FixedPoint(10), DecrementLiterals) :: Nil } - assert(ToFixedPoint(Literal(100)) === Literal(90)) + assert(ToFixedPoint.execute(Literal(100)) === Literal(90)) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index bcd20c06c6dca..a279b0f07c38a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -132,16 +132,16 @@ class SQLContext(@transient val sparkContext: SparkContext) protected[sql] lazy val optimizer: Optimizer = DefaultOptimizer @transient - protected[sql] val ddlParser = new DDLParser(sqlParser.apply(_)) + protected[sql] val ddlParser = new 
DDLParser(sqlParser.parse(_)) @transient protected[sql] val sqlParser = { val fallback = new catalyst.SqlParser - new SparkSQLParser(fallback(_)) + new SparkSQLParser(fallback.parse(_)) } protected[sql] def parseSql(sql: String): LogicalPlan = { - ddlParser(sql, false).getOrElse(sqlParser(sql)) + ddlParser.parse(sql, false).getOrElse(sqlParser.parse(sql)) } protected[sql] def executeSql(sql: String): this.QueryExecution = executePlan(parseSql(sql)) @@ -1120,12 +1120,12 @@ class SQLContext(@transient val sparkContext: SparkContext) protected[sql] class QueryExecution(val logical: LogicalPlan) { def assertAnalyzed(): Unit = analyzer.checkAnalysis(analyzed) - lazy val analyzed: LogicalPlan = analyzer(logical) + lazy val analyzed: LogicalPlan = analyzer.execute(logical) lazy val withCachedData: LogicalPlan = { assertAnalyzed() cacheManager.useCachedData(analyzed) } - lazy val optimizedPlan: LogicalPlan = optimizer(withCachedData) + lazy val optimizedPlan: LogicalPlan = optimizer.execute(withCachedData) // TODO: Don't just pick the first one... lazy val sparkPlan: SparkPlan = { @@ -1134,7 +1134,7 @@ class SQLContext(@transient val sparkContext: SparkContext) } // executedPlan should not be used to initialize any SparkPlan. It should be // only used for execution. - lazy val executedPlan: SparkPlan = prepareForExecution(sparkPlan) + lazy val executedPlan: SparkPlan = prepareForExecution.execute(sparkPlan) /** Internal version of the RDD. Avoids copies and has no schema */ lazy val toRdd: RDD[Row] = executedPlan.execute() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index e159ffe66cb24..59c89800da00f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -144,7 +144,7 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ log.debug( s"Creating Projection: $expressions, inputSchema: $inputSchema, codegen:$codegenEnabled") if (codegenEnabled) { - GenerateProjection(expressions, inputSchema) + GenerateProjection.generate(expressions, inputSchema) } else { new InterpretedProjection(expressions, inputSchema) } @@ -156,7 +156,7 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ log.debug( s"Creating MutableProj: $expressions, inputSchema: $inputSchema, codegen:$codegenEnabled") if(codegenEnabled) { - GenerateMutableProjection(expressions, inputSchema) + GenerateMutableProjection.generate(expressions, inputSchema) } else { () => new InterpretedMutableProjection(expressions, inputSchema) } @@ -166,15 +166,15 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ protected def newPredicate( expression: Expression, inputSchema: Seq[Attribute]): (Row) => Boolean = { if (codegenEnabled) { - GeneratePredicate(expression, inputSchema) + GeneratePredicate.generate(expression, inputSchema) } else { - InterpretedPredicate(expression, inputSchema) + InterpretedPredicate.create(expression, inputSchema) } } protected def newOrdering(order: Seq[SortOrder], inputSchema: Seq[Attribute]): Ordering[Row] = { if (codegenEnabled) { - GenerateOrdering(order, inputSchema) + GenerateOrdering.generate(order, inputSchema) } else { new RowOrdering(order, inputSchema) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala index 83b1a83765153..56200f6b8c8a9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala @@ -59,7 +59,7 @@ case class BroadcastNestedLoopJoin( } @transient private lazy val boundCondition = - InterpretedPredicate( + InterpretedPredicate.create( condition .map(c => BindReferences.bindReference(c, left.output ++ right.output)) .getOrElse(Literal(true))) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala index 1fa7e7bd0406c..e06f63f94b78b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala @@ -45,7 +45,7 @@ case class LeftSemiJoinBNL( override def right: SparkPlan = broadcast @transient private lazy val boundCondition = - InterpretedPredicate( + InterpretedPredicate.create( condition .map(c => BindReferences.bindReference(c, left.output ++ right.output)) .getOrElse(Literal(true))) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala index af7b3c81ae7b2..88466f52bd4e9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala @@ -611,7 +611,7 @@ private[sql] case class ParquetRelation2( val rawPredicate = partitionPruningPredicates.reduceOption(expressions.And).getOrElse(Literal(true)) - val boundPredicate = InterpretedPredicate(rawPredicate transform { + val boundPredicate = InterpretedPredicate.create(rawPredicate transform { case a: AttributeReference => val index = partitionColumns.indexWhere(a.name == _.name) BoundReference(index, partitionColumns(index).dataType, nullable = true) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala index 78d494184e759..e7a0685e013d8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala @@ -38,9 +38,9 @@ private[sql] class DDLParser( parseQuery: String => LogicalPlan) extends AbstractSparkSQLParser with DataTypeParser with Logging { - def apply(input: String, exceptionOnError: Boolean): Option[LogicalPlan] = { + def parse(input: String, exceptionOnError: Boolean): Option[LogicalPlan] = { try { - Some(apply(input)) + Some(parse(input)) } catch { case ddlException: DDLException => throw ddlException case _ if !exceptionOnError => None diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index c4a73b3004076..dd06b2620c5ee 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -93,7 +93,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { if (conf.dialect == "sql") { super.sql(substituted) } else if (conf.dialect == "hiveql") { - val ddlPlan = ddlParserWithHiveQL(sqlText, exceptionOnError = false) + val ddlPlan = ddlParserWithHiveQL.parse(sqlText, exceptionOnError = false) 
DataFrame(this, ddlPlan.getOrElse(HiveQl.parseSql(substituted))) } else { sys.error(s"Unsupported SQL dialect: ${conf.dialect}. Try 'sql' or 'hiveql'") diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index 85061f22772dd..0ea6d57b816c6 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -144,7 +144,7 @@ private[hive] object HiveQl { protected val hqlParser = { val fallback = new ExtendedHiveQlParser - new SparkSQLParser(fallback(_)) + new SparkSQLParser(fallback.parse(_)) } /** @@ -240,7 +240,7 @@ private[hive] object HiveQl { /** Returns a LogicalPlan for a given HiveQL string. */ - def parseSql(sql: String): LogicalPlan = hqlParser(sql) + def parseSql(sql: String): LogicalPlan = hqlParser.parse(sql) val errorRegEx = "line (\\d+):(\\d+) (.*)".r diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala index a6f4fbe8aba06..be9249a8b1f44 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala @@ -119,9 +119,9 @@ private[hive] trait HiveStrategies { val inputData = new GenericMutableRow(relation.partitionKeys.size) val pruningCondition = if (codegenEnabled) { - GeneratePredicate(castedPredicate) + GeneratePredicate.generate(castedPredicate) } else { - InterpretedPredicate(castedPredicate) + InterpretedPredicate.create(castedPredicate) } val partitions = relation.hiveQlPartitions.filter { part => diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala index 6570fa1043900..9f17bca083d13 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala @@ -185,7 +185,7 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { logDebug(s"Query references test tables: ${referencedTestTables.mkString(", ")}") referencedTestTables.foreach(loadTestTable) // Proceed with analysis. 
- analyzer(logical) + analyzer.execute(logical) } } From fbe7106d75c6a1624d10793fba6759703bc5c6e6 Mon Sep 17 00:00:00 2001 From: szheng79 Date: Wed, 22 Apr 2015 13:02:55 -0700 Subject: [PATCH 029/110] [SPARK-7039][SQL]JDBCRDD: Add support on type NVARCHAR Issue: https://issues.apache.org/jira/browse/SPARK-7039 Add support to column type NVARCHAR in Sql Server java.sql.Types: http://docs.oracle.com/javase/7/docs/api/java/sql/Types.html Author: szheng79 Closes #5618 from szheng79/patch-1 and squashes the following commits: 10da99c [szheng79] Update JDBCRDD.scala eab0bd8 [szheng79] Add support on type NVARCHAR --- sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala index b9022fcd9e3ad..8b1edec20feee 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala @@ -60,6 +60,7 @@ private[sql] object JDBCRDD extends Logging { case java.sql.Types.NCLOB => StringType case java.sql.Types.NULL => null case java.sql.Types.NUMERIC => DecimalType.Unlimited + case java.sql.Types.NVARCHAR => StringType case java.sql.Types.OTHER => null case java.sql.Types.REAL => DoubleType case java.sql.Types.REF => StringType From baf865ddc2cff9b99d6aeab9861e030da511257f Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 22 Apr 2015 15:26:58 -0700 Subject: [PATCH 030/110] [SPARK-7059][SQL] Create a DataFrame join API to facilitate equijoin. Author: Reynold Xin Closes #5638 from rxin/joinUsing and squashes the following commits: 13e9cc9 [Reynold Xin] Code review + Python. b1bd914 [Reynold Xin] [SPARK-7059][SQL] Create a DataFrame join API to facilitate equijoin and self join. --- python/pyspark/sql/dataframe.py | 9 ++++- .../org/apache/spark/sql/DataFrame.scala | 37 +++++++++++++++++ .../org/apache/spark/sql/DataFrameSuite.scala | 40 ++++++++++++++----- 3 files changed, 74 insertions(+), 12 deletions(-) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index ca9bf8efb945c..c8c30ce4022c8 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -459,16 +459,23 @@ def join(self, other, joinExprs=None, joinType=None): The following performs a full outer join between ``df1`` and ``df2``. :param other: Right side of the join - :param joinExprs: Join expression + :param joinExprs: a string for join column name, or a join expression (Column). + If joinExprs is a string indicating the name of the join column, + the column must exist on both sides, and this performs an inner equi-join. :param joinType: str, default 'inner'. One of `inner`, `outer`, `left_outer`, `right_outer`, `semijoin`. 
>>> df.join(df2, df.name == df2.name, 'outer').select(df.name, df2.height).collect() [Row(name=None, height=80), Row(name=u'Alice', height=None), Row(name=u'Bob', height=85)] + + >>> df.join(df2, 'name').select(df.name, df2.height).collect() + [Row(name=u'Bob', height=85)] """ if joinExprs is None: jdf = self._jdf.join(other._jdf) + elif isinstance(joinExprs, basestring): + jdf = self._jdf.join(other._jdf, joinExprs) else: assert isinstance(joinExprs, Column), "joinExprs should be Column" if joinType is None: diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 03d9834d1d131..ca6ae482eb2ab 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -342,6 +342,43 @@ class DataFrame private[sql]( Join(logicalPlan, right.logicalPlan, joinType = Inner, None) } + /** + * Inner equi-join with another [[DataFrame]] using the given column. + * + * Different from other join functions, the join column will only appear once in the output, + * i.e. similar to SQL's `JOIN USING` syntax. + * + * {{{ + * // Joining df1 and df2 using the column "user_id" + * df1.join(df2, "user_id") + * }}} + * + * Note that if you perform a self-join using this function without aliasing the input + * [[DataFrame]]s, you will NOT be able to reference any columns after the join, since + * there is no way to disambiguate which side of the join you would like to reference. + * + * @param right Right side of the join operation. + * @param usingColumn Name of the column to join on. This column must exist on both sides. + * @group dfops + */ + def join(right: DataFrame, usingColumn: String): DataFrame = { + // Analyze the self join. The assumption is that the analyzer will disambiguate left vs right + // by creating a new instance for one of the branch. + val joined = sqlContext.executePlan( + Join(logicalPlan, right.logicalPlan, joinType = Inner, None)).analyzed.asInstanceOf[Join] + + // Project only one of the join column. + val joinedCol = joined.right.resolve(usingColumn) + Project( + joined.output.filterNot(_ == joinedCol), + Join( + joined.left, + joined.right, + joinType = Inner, + Some(EqualTo(joined.left.resolve(usingColumn), joined.right.resolve(usingColumn)))) + ) + } + /** * Inner join with another [[DataFrame]], using the given join expression. 
* diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index b9b6a400ae195..5ec06d448e50f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -109,15 +109,6 @@ class DataFrameSuite extends QueryTest { assert(testData.head(2).head.schema === testData.schema) } - test("self join") { - val df1 = testData.select(testData("key")).as('df1) - val df2 = testData.select(testData("key")).as('df2) - - checkAnswer( - df1.join(df2, $"df1.key" === $"df2.key"), - sql("SELECT a.key, b.key FROM testData a JOIN testData b ON a.key = b.key").collect().toSeq) - } - test("simple explode") { val df = Seq(Tuple1("a b c"), Tuple1("d e")).toDF("words") @@ -127,8 +118,35 @@ class DataFrameSuite extends QueryTest { ) } - test("self join with aliases") { - val df = Seq(1,2,3).map(i => (i, i.toString)).toDF("int", "str") + test("join - join using") { + val df = Seq(1, 2, 3).map(i => (i, i.toString)).toDF("int", "str") + val df2 = Seq(1, 2, 3).map(i => (i, (i + 1).toString)).toDF("int", "str") + + checkAnswer( + df.join(df2, "int"), + Row(1, "1", "2") :: Row(2, "2", "3") :: Row(3, "3", "4") :: Nil) + } + + test("join - join using self join") { + val df = Seq(1, 2, 3).map(i => (i, i.toString)).toDF("int", "str") + + // self join + checkAnswer( + df.join(df, "int"), + Row(1, "1", "1") :: Row(2, "2", "2") :: Row(3, "3", "3") :: Nil) + } + + test("join - self join") { + val df1 = testData.select(testData("key")).as('df1) + val df2 = testData.select(testData("key")).as('df2) + + checkAnswer( + df1.join(df2, $"df1.key" === $"df2.key"), + sql("SELECT a.key, b.key FROM testData a JOIN testData b ON a.key = b.key").collect().toSeq) + } + + test("join - using aliases after self join") { + val df = Seq(1, 2, 3).map(i => (i, i.toString)).toDF("int", "str") checkAnswer( df.as('x).join(df.as('y), $"x.str" === $"y.str").groupBy("x.str").count(), Row("1", 1) :: Row("2", 1) :: Row("3", 1) :: Nil) From f4f39981f4f5e88c30eec7d0b107e2c3cdc268c9 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Wed, 22 Apr 2015 17:22:26 -0700 Subject: [PATCH 031/110] [SPARK-6827] [MLLIB] Wrap FPGrowthModel.freqItemsets and make it consistent with Java API Make PySpark ```FPGrowthModel.freqItemsets``` consistent with Java/Scala API like ```MatrixFactorizationModel.userFeatures``` It return a RDD with each tuple is composed of an array and a long value. I think it's difficult to implement namedtuples to wrap the output because items of freqItemsets can be any type with arbitrary length which is tedious to impelement corresponding SerDe function. Author: Yanbo Liang Closes #5614 from yanboliang/spark-6827 and squashes the following commits: da8c404 [Yanbo Liang] use namedtuple 5532e78 [Yanbo Liang] Wrap FPGrowthModel.freqItemsets and make it consistent with Java API --- python/pyspark/mllib/fpm.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/python/pyspark/mllib/fpm.py b/python/pyspark/mllib/fpm.py index 628ccc01cf3cc..d8df02bdbaba9 100644 --- a/python/pyspark/mllib/fpm.py +++ b/python/pyspark/mllib/fpm.py @@ -15,6 +15,10 @@ # limitations under the License. 
# +import numpy +from numpy import array +from collections import namedtuple + from pyspark import SparkContext from pyspark.rdd import ignore_unicode_prefix from pyspark.mllib.common import JavaModelWrapper, callMLlibFunc, inherit_doc @@ -36,14 +40,14 @@ class FPGrowthModel(JavaModelWrapper): >>> rdd = sc.parallelize(data, 2) >>> model = FPGrowth.train(rdd, 0.6, 2) >>> sorted(model.freqItemsets().collect()) - [([u'a'], 4), ([u'c'], 3), ([u'c', u'a'], 3)] + [FreqItemset(items=[u'a'], freq=4), FreqItemset(items=[u'c'], freq=3), ... """ def freqItemsets(self): """ - Get the frequent itemsets of this model + Returns the frequent itemsets of this model. """ - return self.call("getFreqItemsets") + return self.call("getFreqItemsets").map(lambda x: (FPGrowth.FreqItemset(x[0], x[1]))) class FPGrowth(object): @@ -67,6 +71,11 @@ def train(cls, data, minSupport=0.3, numPartitions=-1): model = callMLlibFunc("trainFPGrowthModel", data, float(minSupport), int(numPartitions)) return FPGrowthModel(model) + class FreqItemset(namedtuple("FreqItemset", ["items", "freq"])): + """ + Represents an (items, freq) tuple. + """ + def _test(): import doctest From 04525c077c638a7e615c294ba988e35036554f5f Mon Sep 17 00:00:00 2001 From: Daoyuan Wang Date: Wed, 22 Apr 2015 19:14:28 -0700 Subject: [PATCH 032/110] [SPARK-6967] [SQL] fix date type convertion in jdbcrdd This pr convert java.sql.Date type into Int for JDBCRDD. Author: Daoyuan Wang Closes #5590 from adrian-wang/datebug and squashes the following commits: f897b81 [Daoyuan Wang] add a test case 3c9184c [Daoyuan Wang] fix date type convertion in jdbcrdd --- .../src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala | 4 ++-- .../test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala | 7 +++++++ 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala index 8b1edec20feee..b975191d41963 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala @@ -350,8 +350,8 @@ private[sql] class JDBCRDD( val pos = i + 1 conversions(i) match { case BooleanConversion => mutableRow.setBoolean(i, rs.getBoolean(pos)) - // TODO(davies): convert Date into Int - case DateConversion => mutableRow.update(i, rs.getDate(pos)) + case DateConversion => + mutableRow.update(i, DateUtils.fromJavaDate(rs.getDate(pos))) case DecimalConversion => mutableRow.update(i, rs.getBigDecimal(pos)) case DoubleConversion => mutableRow.setDouble(i, rs.getDouble(pos)) case FloatConversion => mutableRow.setFloat(i, rs.getFloat(pos)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index 3596b183d4328..db096af4535a9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -249,6 +249,13 @@ class JDBCSuite extends FunSuite with BeforeAndAfter { assert(rows(0).getAs[java.sql.Timestamp](2).getNanos === 543543543) } + test("test DATE types") { + val rows = TestSQLContext.jdbc(urlWithUserAndPass, "TEST.TIMETYPES").collect() + val cachedRows = TestSQLContext.jdbc(urlWithUserAndPass, "TEST.TIMETYPES").cache().collect() + assert(rows(0).getAs[java.sql.Date](1) === java.sql.Date.valueOf("1996-01-01")) + assert(cachedRows(0).getAs[java.sql.Date](1) === java.sql.Date.valueOf("1996-01-01")) + } + test("H2 
floating-point types") { val rows = sql("SELECT * FROM flttypes").collect() assert(rows(0).getDouble(0) === 1.00000000000000022) // Yes, I meant ==. From b69c4f9b2e8544f1b178db2aefbcaa166f76cb7a Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 22 Apr 2015 21:24:22 -0700 Subject: [PATCH 033/110] Disable flaky test: ReceiverSuite "block generator throttling". --- .../test/scala/org/apache/spark/streaming/ReceiverSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala index e7aee6eadbfc7..b84129fd70dd4 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala @@ -155,7 +155,7 @@ class ReceiverSuite extends TestSuiteBase with Timeouts with Serializable { assert(recordedData.toSet === generatedData.toSet) } - test("block generator throttling") { + ignore("block generator throttling") { val blockGeneratorListener = new FakeBlockGeneratorListener val blockIntervalMs = 100 val maxRate = 1001 From 1b85e08509a0e19dc35b6ab869977254156cdaf1 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 22 Apr 2015 21:35:12 -0700 Subject: [PATCH 034/110] [MLlib] UnaryTransformer nullability should not depend on PrimitiveType. Author: Reynold Xin Closes #5644 from rxin/mllib-nullable and squashes the following commits: a727e5b [Reynold Xin] [MLlib] UnaryTransformer nullability should not depend on primitive types. --- mllib/src/main/scala/org/apache/spark/ml/Transformer.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala b/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala index 7fb87fe452ee6..0acda71ec6045 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala @@ -94,7 +94,7 @@ private[ml] abstract class UnaryTransformer[IN, OUT, T <: UnaryTransformer[IN, O throw new IllegalArgumentException(s"Output column ${map(outputCol)} already exists.") } val outputFields = schema.fields :+ - StructField(map(outputCol), outputDataType, !outputDataType.isPrimitive) + StructField(map(outputCol), outputDataType, nullable = false) StructType(outputFields) } From d20686066e978dd12e618e3978f109f05bc412fe Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 22 Apr 2015 21:35:42 -0700 Subject: [PATCH 035/110] [SPARK-7066][MLlib] VectorAssembler should use NumericType not NativeType. Author: Reynold Xin Closes #5642 from rxin/mllib-native-type and squashes the following commits: e23af5b [Reynold Xin] Remove StringType 7cbb205 [Reynold Xin] [SPARK-7066][MLlib] VectorAssembler should use NumericType and StringType, not NativeType. 
--- .../scala/org/apache/spark/ml/feature/VectorAssembler.scala | 5 +++-- .../main/scala/org/apache/spark/sql/types/dataTypes.scala | 4 ++-- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala index e567e069e7c0b..fd16d3d6c268b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala @@ -55,7 +55,8 @@ class VectorAssembler extends Transformer with HasInputCols with HasOutputCol { schema(c).dataType match { case DoubleType => UnresolvedAttribute(c) case t if t.isInstanceOf[VectorUDT] => UnresolvedAttribute(c) - case _: NativeType => Alias(Cast(UnresolvedAttribute(c), DoubleType), s"${c}_double_$uid")() + case _: NumericType => + Alias(Cast(UnresolvedAttribute(c), DoubleType), s"${c}_double_$uid")() } } dataset.select(col("*"), assembleFunc(new Column(CreateStruct(args))).as(map(outputCol))) @@ -67,7 +68,7 @@ class VectorAssembler extends Transformer with HasInputCols with HasOutputCol { val outputColName = map(outputCol) val inputDataTypes = inputColNames.map(name => schema(name).dataType) inputDataTypes.foreach { - case _: NativeType => + case _: NumericType => case t if t.isInstanceOf[VectorUDT] => case other => throw new IllegalArgumentException(s"Data type $other is not supported.") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala index 7cd7bd1914c95..ddf9d664c6826 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala @@ -299,7 +299,7 @@ class NullType private() extends DataType { case object NullType extends NullType -protected[spark] object NativeType { +protected[sql] object NativeType { val all = Seq( IntegerType, BooleanType, LongType, DoubleType, FloatType, ShortType, ByteType, StringType) @@ -327,7 +327,7 @@ protected[sql] object PrimitiveType { } } -protected[spark] abstract class NativeType extends DataType { +protected[sql] abstract class NativeType extends DataType { private[sql] type JvmType @transient private[sql] val tag: TypeTag[JvmType] private[sql] val ordering: Ordering[JvmType] From 03e85b4a11899f37424cd6e1f8d71f1d704c90bb Mon Sep 17 00:00:00 2001 From: Kay Ousterhout Date: Wed, 22 Apr 2015 21:42:09 -0700 Subject: [PATCH 036/110] [SPARK-7046] Remove InputMetrics from BlockResult This is a code cleanup. The BlockResult class originally contained an InputMetrics object so that InputMetrics could directly be used as the InputMetrics for the whole task. Now we copy the fields out of here, and the presence of this object is confusing because it's only a partial input metrics (it doesn't include the records read). Because this object is no longer useful (and is confusing), it should be removed. 
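The simplification described above amounts to the caller reading plain fields off the result and folding them into the metrics object it already owns. A REPL-style sketch with simplified, made-up types (these are not Spark's own classes):

case class SimpleBlockResult(data: Iterator[Any], readMethod: String, bytes: Long)

class SimpleInputMetrics {
  var bytesRead: Long = 0L
  def incBytesRead(n: Long): Unit = bytesRead += n
}

def recordRead(result: SimpleBlockResult, metrics: SimpleInputMetrics): Iterator[Any] = {
  // copy the scalar field out of the result; no nested, partially-filled metrics object
  metrics.incBytesRead(result.bytes)
  result.data
}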
Author: Kay Ousterhout Closes #5627 from kayousterhout/SPARK-7046 and squashes the following commits: bf64bbe [Kay Ousterhout] Import fix a08ca19 [Kay Ousterhout] [SPARK-7046] Remove InputMetrics from BlockResult --- .../main/scala/org/apache/spark/CacheManager.scala | 5 ++--- .../org/apache/spark/storage/BlockManager.scala | 9 +++------ .../org/apache/spark/storage/BlockManagerSuite.scala | 12 ++++++------ 3 files changed, 11 insertions(+), 15 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/CacheManager.scala b/core/src/main/scala/org/apache/spark/CacheManager.scala index a96d754744a05..4d20c7369376e 100644 --- a/core/src/main/scala/org/apache/spark/CacheManager.scala +++ b/core/src/main/scala/org/apache/spark/CacheManager.scala @@ -44,10 +44,9 @@ private[spark] class CacheManager(blockManager: BlockManager) extends Logging { blockManager.get(key) match { case Some(blockResult) => // Partition is already materialized, so just return its values - val inputMetrics = blockResult.inputMetrics val existingMetrics = context.taskMetrics - .getInputMetricsForReadMethod(inputMetrics.readMethod) - existingMetrics.incBytesRead(inputMetrics.bytesRead) + .getInputMetricsForReadMethod(blockResult.readMethod) + existingMetrics.incBytesRead(blockResult.bytes) val iter = blockResult.data.asInstanceOf[Iterator[T]] new InterruptibleIterator[T](context, iter) { diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index 145a9c1ae3391..55718e584c195 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -29,7 +29,7 @@ import scala.util.Random import sun.nio.ch.DirectBuffer import org.apache.spark._ -import org.apache.spark.executor._ +import org.apache.spark.executor.{DataReadMethod, ShuffleWriteMetrics} import org.apache.spark.io.CompressionCodec import org.apache.spark.network._ import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer} @@ -50,11 +50,8 @@ private[spark] case class ArrayValues(buffer: Array[Any]) extends BlockValues /* Class for returning a fetched block and associated metrics. 
*/ private[spark] class BlockResult( val data: Iterator[Any], - readMethod: DataReadMethod.Value, - bytes: Long) { - val inputMetrics = new InputMetrics(readMethod) - inputMetrics.incBytesRead(bytes) -} + val readMethod: DataReadMethod.Value, + val bytes: Long) /** * Manager running on every node (driver and executors) which provides interfaces for putting and diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index 545722b050ee8..7d82a7c66ad1a 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -428,19 +428,19 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfterEach val list1Get = store.get("list1") assert(list1Get.isDefined, "list1 expected to be in store") assert(list1Get.get.data.size === 2) - assert(list1Get.get.inputMetrics.bytesRead === list1SizeEstimate) - assert(list1Get.get.inputMetrics.readMethod === DataReadMethod.Memory) + assert(list1Get.get.bytes === list1SizeEstimate) + assert(list1Get.get.readMethod === DataReadMethod.Memory) val list2MemoryGet = store.get("list2memory") assert(list2MemoryGet.isDefined, "list2memory expected to be in store") assert(list2MemoryGet.get.data.size === 3) - assert(list2MemoryGet.get.inputMetrics.bytesRead === list2SizeEstimate) - assert(list2MemoryGet.get.inputMetrics.readMethod === DataReadMethod.Memory) + assert(list2MemoryGet.get.bytes === list2SizeEstimate) + assert(list2MemoryGet.get.readMethod === DataReadMethod.Memory) val list2DiskGet = store.get("list2disk") assert(list2DiskGet.isDefined, "list2memory expected to be in store") assert(list2DiskGet.get.data.size === 3) // We don't know the exact size of the data on disk, but it should certainly be > 0. - assert(list2DiskGet.get.inputMetrics.bytesRead > 0) - assert(list2DiskGet.get.inputMetrics.readMethod === DataReadMethod.Disk) + assert(list2DiskGet.get.bytes > 0) + assert(list2DiskGet.get.readMethod === DataReadMethod.Disk) } test("in-memory LRU storage") { From d9e70f331fc3999d615ede49fc69a993dc65f272 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 22 Apr 2015 22:18:56 -0700 Subject: [PATCH 037/110] [HOTFIX][SQL] Fix broken cached test Added in #5475. Pointed as broken in #5639. /cc marmbrus Author: Liang-Chi Hsieh Closes #5640 from viirya/fix_cached_test and squashes the following commits: c0cf69a [Liang-Chi Hsieh] Fix broken cached test. 
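The diff below wraps the accumulator bookkeeping in Accumulators.synchronized and asserts an exact before/after delta instead of a loose >= comparison. A REPL-style sketch of that pattern against a hypothetical registry (not Spark's Accumulators object):

import scala.collection.mutable

object Registry {
  val originals = mutable.Map.empty[Long, String]
}

// Capture the size, run the operation, and check the delta all under the
// registry's own lock, so a concurrent cleanup thread cannot shift the count
// between the two reads.
def assertAddsExactly(expected: Int)(op: => Unit): Unit = Registry.synchronized {
  val before = Registry.originals.size
  op
  assert(Registry.originals.size == before + expected)
}

assertAddsExactly(2) {
  Registry.originals(1L) = "t1"
  Registry.originals(2L) = "t2"
}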
--- .../apache/spark/sql/CachedTableSuite.scala | 21 ++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala index 01e3b8671071e..0772e5e187425 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala @@ -300,19 +300,26 @@ class CachedTableSuite extends QueryTest { } test("Clear accumulators when uncacheTable to prevent memory leaking") { - val accsSize = Accumulators.originals.size - sql("SELECT key FROM testData LIMIT 10").registerTempTable("t1") sql("SELECT key FROM testData LIMIT 5").registerTempTable("t2") - cacheTable("t1") - cacheTable("t2") + + Accumulators.synchronized { + val accsSize = Accumulators.originals.size + cacheTable("t1") + cacheTable("t2") + assert((accsSize + 2) == Accumulators.originals.size) + } + sql("SELECT * FROM t1").count() sql("SELECT * FROM t2").count() sql("SELECT * FROM t1").count() sql("SELECT * FROM t2").count() - uncacheTable("t1") - uncacheTable("t2") - assert(accsSize >= Accumulators.originals.size) + Accumulators.synchronized { + val accsSize = Accumulators.originals.size + uncacheTable("t1") + uncacheTable("t2") + assert((accsSize - 2) == Accumulators.originals.size) + } } } From 2d33323cadbf58dd1d05ffff998d18cad6a896cd Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 22 Apr 2015 23:54:48 -0700 Subject: [PATCH 038/110] [MLlib] Add support for BooleanType to VectorAssembler. Author: Reynold Xin Closes #5648 from rxin/vectorAssembler-boolean and squashes the following commits: 1bf3d40 [Reynold Xin] [MLlib] Add support for BooleanType to VectorAssembler. 
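A spark-shell style usage sketch of the behaviour this enables (column names are made up, and it assumes the sqlContext that spark-shell provides): a boolean input column can now be listed directly and is cast to 0.0 / 1.0 when the vector is assembled.

import org.apache.spark.ml.feature.VectorAssembler

val df = sqlContext.createDataFrame(Seq(
  (1, 2.5, true),
  (2, 0.0, false)
)).toDF("id", "score", "flag")

val assembler = new VectorAssembler()
  .setInputCols(Array("id", "score", "flag"))
  .setOutputCol("features")

// the "features" column should come out roughly as [1.0,2.5,1.0] and [2.0,0.0,0.0]
assembler.transform(df).show()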
--- .../scala/org/apache/spark/ml/feature/VectorAssembler.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala index fd16d3d6c268b..7b2a451ca5ee5 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala @@ -55,7 +55,7 @@ class VectorAssembler extends Transformer with HasInputCols with HasOutputCol { schema(c).dataType match { case DoubleType => UnresolvedAttribute(c) case t if t.isInstanceOf[VectorUDT] => UnresolvedAttribute(c) - case _: NumericType => + case _: NumericType | BooleanType => Alias(Cast(UnresolvedAttribute(c), DoubleType), s"${c}_double_$uid")() } } @@ -68,7 +68,7 @@ class VectorAssembler extends Transformer with HasInputCols with HasOutputCol { val outputColName = map(outputCol) val inputDataTypes = inputColNames.map(name => schema(name).dataType) inputDataTypes.foreach { - case _: NumericType => + case _: NumericType | BooleanType => case t if t.isInstanceOf[VectorUDT] => case other => throw new IllegalArgumentException(s"Data type $other is not supported.") From 29163c520087e89ca322521db1dd8656d86a6f0e Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 22 Apr 2015 23:55:20 -0700 Subject: [PATCH 039/110] [SPARK-7068][SQL] Remove PrimitiveType Author: Reynold Xin Closes #5646 from rxin/remove-primitive-type and squashes the following commits: 01b673d [Reynold Xin] [SPARK-7068][SQL] Remove PrimitiveType --- .../apache/spark/sql/types/dataTypes.scala | 70 ++++++++----------- .../spark/sql/parquet/ParquetConverter.scala | 11 +-- .../sql/parquet/ParquetTableOperations.scala | 2 +- .../spark/sql/parquet/ParquetTypes.scala | 6 +- .../apache/spark/sql/parquet/newParquet.scala | 13 ++-- 5 files changed, 48 insertions(+), 54 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala index ddf9d664c6826..42e26e05996dd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala @@ -41,6 +41,21 @@ import org.apache.spark.util.Utils object DataType { def fromJson(json: String): DataType = parseDataType(parse(json)) + private val nonDecimalNameToType = { + (Seq(NullType, DateType, TimestampType, BinaryType) ++ NativeType.all) + .map(t => t.typeName -> t).toMap + } + + /** Given the string representation of a type, return its DataType */ + private def nameToType(name: String): DataType = { + val FIXED_DECIMAL = """decimal\(\s*(\d+)\s*,\s*(\d+)\s*\)""".r + name match { + case "decimal" => DecimalType.Unlimited + case FIXED_DECIMAL(precision, scale) => DecimalType(precision.toInt, scale.toInt) + case other => nonDecimalNameToType(other) + } + } + private object JSortedObject { def unapplySeq(value: JValue): Option[List[(String, JValue)]] = value match { case JObject(seq) => Some(seq.toList.sortBy(_._1)) @@ -51,7 +66,7 @@ object DataType { // NOTE: Map fields must be sorted in alphabetical order to keep consistent with the Python side. 
private def parseDataType(json: JValue): DataType = json match { case JString(name) => - PrimitiveType.nameToType(name) + nameToType(name) case JSortedObject( ("containsNull", JBool(n)), @@ -190,13 +205,11 @@ object DataType { equalsIgnoreNullability(leftKeyType, rightKeyType) && equalsIgnoreNullability(leftValueType, rightValueType) case (StructType(leftFields), StructType(rightFields)) => - leftFields.size == rightFields.size && - leftFields.zip(rightFields) - .forall{ - case (left, right) => - left.name == right.name && equalsIgnoreNullability(left.dataType, right.dataType) - } - case (left, right) => left == right + leftFields.length == rightFields.length && + leftFields.zip(rightFields).forall { case (l, r) => + l.name == r.name && equalsIgnoreNullability(l.dataType, r.dataType) + } + case (l, r) => l == r } } @@ -225,12 +238,11 @@ object DataType { equalsIgnoreCompatibleNullability(fromValue, toValue) case (StructType(fromFields), StructType(toFields)) => - fromFields.size == toFields.size && - fromFields.zip(toFields).forall { - case (fromField, toField) => - fromField.name == toField.name && - (toField.nullable || !fromField.nullable) && - equalsIgnoreCompatibleNullability(fromField.dataType, toField.dataType) + fromFields.length == toFields.length && + fromFields.zip(toFields).forall { case (fromField, toField) => + fromField.name == toField.name && + (toField.nullable || !fromField.nullable) && + equalsIgnoreCompatibleNullability(fromField.dataType, toField.dataType) } case (fromDataType, toDataType) => fromDataType == toDataType @@ -256,8 +268,6 @@ abstract class DataType { /** The default size of a value of this data type. */ def defaultSize: Int - def isPrimitive: Boolean = false - def typeName: String = this.getClass.getSimpleName.stripSuffix("$").dropRight(4).toLowerCase private[sql] def jsonValue: JValue = typeName @@ -307,26 +317,6 @@ protected[sql] object NativeType { } -protected[sql] trait PrimitiveType extends DataType { - override def isPrimitive: Boolean = true -} - - -protected[sql] object PrimitiveType { - private val nonDecimals = Seq(NullType, DateType, TimestampType, BinaryType) ++ NativeType.all - private val nonDecimalNameToType = nonDecimals.map(t => t.typeName -> t).toMap - - /** Given the string representation of a type, return its DataType */ - private[sql] def nameToType(name: String): DataType = { - val FIXED_DECIMAL = """decimal\(\s*(\d+)\s*,\s*(\d+)\s*\)""".r - name match { - case "decimal" => DecimalType.Unlimited - case FIXED_DECIMAL(precision, scale) => DecimalType(precision.toInt, scale.toInt) - case other => nonDecimalNameToType(other) - } - } -} - protected[sql] abstract class NativeType extends DataType { private[sql] type JvmType @transient private[sql] val tag: TypeTag[JvmType] @@ -346,7 +336,7 @@ protected[sql] abstract class NativeType extends DataType { * @group dataType */ @DeveloperApi -class StringType private() extends NativeType with PrimitiveType { +class StringType private() extends NativeType { // The companion object and this class is separated so the companion object also subclasses // this type. Otherwise, the companion object would be of type "StringType$" in byte code. // Defined with a private constructor so the companion object is the only possible instantiation. 
@@ -373,7 +363,7 @@ case object StringType extends StringType * @group dataType */ @DeveloperApi -class BinaryType private() extends NativeType with PrimitiveType { +class BinaryType private() extends NativeType { // The companion object and this class is separated so the companion object also subclasses // this type. Otherwise, the companion object would be of type "BinaryType$" in byte code. // Defined with a private constructor so the companion object is the only possible instantiation. @@ -407,7 +397,7 @@ case object BinaryType extends BinaryType *@group dataType */ @DeveloperApi -class BooleanType private() extends NativeType with PrimitiveType { +class BooleanType private() extends NativeType { // The companion object and this class is separated so the companion object also subclasses // this type. Otherwise, the companion object would be of type "BooleanType$" in byte code. // Defined with a private constructor so the companion object is the only possible instantiation. @@ -492,7 +482,7 @@ case object DateType extends DateType * * @group dataType */ -abstract class NumericType extends NativeType with PrimitiveType { +abstract class NumericType extends NativeType { // Unfortunately we can't get this implicitly as that breaks Spark Serialization. In order for // implicitly[Numeric[JvmType]] to be valid, we have to change JvmType from a type variable to a // type parameter and and add a numeric annotation (i.e., [JvmType : Numeric]). This gets diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala index bc108e37dfb0f..116424539da11 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala @@ -146,7 +146,8 @@ private[sql] object CatalystConverter { } } // All other primitive types use the default converter - case ctype: PrimitiveType => { // note: need the type tag here! + case ctype: DataType if ParquetTypesConverter.isPrimitiveType(ctype) => { + // note: need the type tag here! 
new CatalystPrimitiveConverter(parent, fieldIndex) } case _ => throw new RuntimeException( @@ -324,9 +325,9 @@ private[parquet] class CatalystGroupConverter( override def start(): Unit = { current = ArrayBuffer.fill(size)(null) - converters.foreach { - converter => if (!converter.isPrimitive) { - converter.asInstanceOf[CatalystConverter].clearBuffer + converters.foreach { converter => + if (!converter.isPrimitive) { + converter.asInstanceOf[CatalystConverter].clearBuffer() } } } @@ -612,7 +613,7 @@ private[parquet] class CatalystArrayConverter( override def start(): Unit = { if (!converter.isPrimitive) { - converter.asInstanceOf[CatalystConverter].clearBuffer + converter.asInstanceOf[CatalystConverter].clearBuffer() } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala index 1c868da23e060..a938b77578686 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala @@ -268,7 +268,7 @@ private[sql] case class InsertIntoParquetTable( val job = new Job(sqlContext.sparkContext.hadoopConfiguration) val writeSupport = - if (child.output.map(_.dataType).forall(_.isPrimitive)) { + if (child.output.map(_.dataType).forall(ParquetTypesConverter.isPrimitiveType)) { log.debug("Initializing MutableRowWriteSupport") classOf[org.apache.spark.sql.parquet.MutableRowWriteSupport] } else { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala index 60e1bec4db8e5..1dc819b5d7b9b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala @@ -48,8 +48,10 @@ private[parquet] case class ParquetTypeInfo( length: Option[Int] = None) private[parquet] object ParquetTypesConverter extends Logging { - def isPrimitiveType(ctype: DataType): Boolean = - classOf[PrimitiveType] isAssignableFrom ctype.getClass + def isPrimitiveType(ctype: DataType): Boolean = ctype match { + case _: NumericType | BooleanType | StringType | BinaryType => true + case _: DataType => false + } def toPrimitiveDataType( parquetType: ParquetPrimitiveType, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala index 88466f52bd4e9..85e60733bc57a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala @@ -634,12 +634,13 @@ private[sql] case class ParquetRelation2( // before calling execute(). 
val job = new Job(sqlContext.sparkContext.hadoopConfiguration) - val writeSupport = if (parquetSchema.map(_.dataType).forall(_.isPrimitive)) { - log.debug("Initializing MutableRowWriteSupport") - classOf[MutableRowWriteSupport] - } else { - classOf[RowWriteSupport] - } + val writeSupport = + if (parquetSchema.map(_.dataType).forall(ParquetTypesConverter.isPrimitiveType)) { + log.debug("Initializing MutableRowWriteSupport") + classOf[MutableRowWriteSupport] + } else { + classOf[RowWriteSupport] + } ParquetOutputFormat.setWriteSupportClass(job, writeSupport) From f60bece14f98450b4a71b00d7b58525f06e1f9ed Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Thu, 23 Apr 2015 01:43:40 -0700 Subject: [PATCH 040/110] [SPARK-7069][SQL] Rename NativeType -> AtomicType. Also renamed JvmType to InternalType. Author: Reynold Xin Closes #5651 from rxin/native-to-atomic-type and squashes the following commits: cbd4028 [Reynold Xin] [SPARK-7069][SQL] Rename NativeType -> AtomicType. --- .../spark/sql/catalyst/ScalaReflection.scala | 24 ++-- .../sql/catalyst/expressions/arithmetic.scala | 4 +- .../expressions/codegen/CodeGenerator.scala | 18 ++- .../codegen/GenerateProjection.scala | 4 +- .../sql/catalyst/expressions/predicates.scala | 10 +- .../spark/sql/catalyst/expressions/rows.scala | 6 +- .../apache/spark/sql/types/dataTypes.scala | 114 +++++++++--------- .../spark/sql/columnar/ColumnAccessor.scala | 2 +- .../spark/sql/columnar/ColumnBuilder.scala | 4 +- .../spark/sql/columnar/ColumnType.scala | 6 +- .../CompressibleColumnAccessor.scala | 4 +- .../CompressibleColumnBuilder.scala | 4 +- .../compression/CompressionScheme.scala | 10 +- .../compression/compressionSchemes.scala | 42 +++---- .../org/apache/spark/sql/json/JsonRDD.scala | 6 +- .../spark/sql/parquet/ParquetConverter.scala | 12 +- .../sql/parquet/ParquetTableSupport.scala | 2 +- .../spark/sql/columnar/ColumnStatsSuite.scala | 6 +- .../spark/sql/columnar/ColumnTypeSuite.scala | 8 +- .../sql/columnar/ColumnarTestUtils.scala | 6 +- .../compression/DictionaryEncodingSuite.scala | 4 +- .../compression/IntegralDeltaSuite.scala | 6 +- .../compression/RunLengthEncodingSuite.scala | 4 +- .../TestCompressibleColumnBuilder.scala | 6 +- 24 files changed, 159 insertions(+), 153 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index d9521953cad73..c52965507c715 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql.catalyst -import java.sql.Timestamp - import org.apache.spark.util.Utils import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.LocalRelation @@ -110,7 +108,7 @@ trait ScalaReflection { StructField(p.name.toString, dataType, nullable) }), nullable = true) case t if t <:< typeOf[String] => Schema(StringType, nullable = true) - case t if t <:< typeOf[Timestamp] => Schema(TimestampType, nullable = true) + case t if t <:< typeOf[java.sql.Timestamp] => Schema(TimestampType, nullable = true) case t if t <:< typeOf[java.sql.Date] => Schema(DateType, nullable = true) case t if t <:< typeOf[BigDecimal] => Schema(DecimalType.Unlimited, nullable = true) case t if t <:< typeOf[java.math.BigDecimal] => Schema(DecimalType.Unlimited, nullable = true) @@ -136,20 +134,20 @@ trait ScalaReflection { def typeOfObject: 
PartialFunction[Any, DataType] = { // The data type can be determined without ambiguity. - case obj: BooleanType.JvmType => BooleanType - case obj: BinaryType.JvmType => BinaryType + case obj: Boolean => BooleanType + case obj: Array[Byte] => BinaryType case obj: String => StringType - case obj: StringType.JvmType => StringType - case obj: ByteType.JvmType => ByteType - case obj: ShortType.JvmType => ShortType - case obj: IntegerType.JvmType => IntegerType - case obj: LongType.JvmType => LongType - case obj: FloatType.JvmType => FloatType - case obj: DoubleType.JvmType => DoubleType + case obj: UTF8String => StringType + case obj: Byte => ByteType + case obj: Short => ShortType + case obj: Int => IntegerType + case obj: Long => LongType + case obj: Float => FloatType + case obj: Double => DoubleType case obj: java.sql.Date => DateType case obj: java.math.BigDecimal => DecimalType.Unlimited case obj: Decimal => DecimalType.Unlimited - case obj: TimestampType.JvmType => TimestampType + case obj: java.sql.Timestamp => TimestampType case null => NullType // For other cases, there is no obvious mapping from the type of the given object to a // Catalyst data type. A user should provide his/her specific rules diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index 566b34f7c3a6a..140ccd8d3796f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -346,7 +346,7 @@ case class MaxOf(left: Expression, right: Expression) extends Expression { } lazy val ordering = left.dataType match { - case i: NativeType => i.ordering.asInstanceOf[Ordering[Any]] + case i: AtomicType => i.ordering.asInstanceOf[Ordering[Any]] case other => sys.error(s"Type $other does not support ordered operations") } @@ -391,7 +391,7 @@ case class MinOf(left: Expression, right: Expression) extends Expression { } lazy val ordering = left.dataType match { - case i: NativeType => i.ordering.asInstanceOf[Ordering[Any]] + case i: AtomicType => i.ordering.asInstanceOf[Ordering[Any]] case other => sys.error(s"Type $other does not support ordered operations") } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index eeffedb558c1b..cbe520347385d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -623,7 +623,7 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin protected def getColumn(inputRow: TermName, dataType: DataType, ordinal: Int) = { dataType match { case StringType => q"$inputRow($ordinal).asInstanceOf[org.apache.spark.sql.types.UTF8String]" - case dt @ NativeType() => q"$inputRow.${accessorForType(dt)}($ordinal)" + case dt: DataType if isNativeType(dt) => q"$inputRow.${accessorForType(dt)}($ordinal)" case _ => q"$inputRow.apply($ordinal).asInstanceOf[${termForType(dataType)}]" } } @@ -635,7 +635,8 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin value: TermName) = { dataType match { case StringType => q"$destinationRow.update($ordinal, $value)" - case dt @ NativeType() => 
q"$destinationRow.${mutatorForType(dt)}($ordinal, $value)" + case dt: DataType if isNativeType(dt) => + q"$destinationRow.${mutatorForType(dt)}($ordinal, $value)" case _ => q"$destinationRow.update($ordinal, $value)" } } @@ -675,7 +676,18 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin } protected def termForType(dt: DataType) = dt match { - case n: NativeType => n.tag + case n: AtomicType => n.tag case _ => typeTag[Any] } + + /** + * List of data types that have special accessors and setters in [[Row]]. + */ + protected val nativeTypes = + Seq(IntegerType, BooleanType, LongType, DoubleType, FloatType, ShortType, ByteType, StringType) + + /** + * Returns true if the data type has a special accessor and setter in [[Row]]. + */ + protected def isNativeType(dt: DataType) = nativeTypes.contains(dt) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala index d491babc2bff0..584f938445c8c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala @@ -109,7 +109,7 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { q"override def update(i: Int, value: Any): Unit = { ..$cases; $accessorFailure }" } - val specificAccessorFunctions = NativeType.all.map { dataType => + val specificAccessorFunctions = nativeTypes.map { dataType => val ifStatements = expressions.zipWithIndex.flatMap { // getString() is not used by expressions case (e, i) if e.dataType == dataType && dataType != StringType => @@ -135,7 +135,7 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { } } - val specificMutatorFunctions = NativeType.all.map { dataType => + val specificMutatorFunctions = nativeTypes.map { dataType => val ifStatements = expressions.zipWithIndex.flatMap { // setString() is not used by expressions case (e, i) if e.dataType == dataType && dataType != StringType => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 46522eb9c1264..9cb00cb2732ff 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.analysis.UnresolvedException import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.types.{DataType, BinaryType, BooleanType, NativeType} +import org.apache.spark.sql.types.{DataType, BinaryType, BooleanType, AtomicType} object InterpretedPredicate { def create(expression: Expression, inputSchema: Seq[Attribute]): (Row => Boolean) = @@ -211,7 +211,7 @@ case class LessThan(left: Expression, right: Expression) extends BinaryCompariso s"Types do not match ${left.dataType} != ${right.dataType}") } left.dataType match { - case i: NativeType => i.ordering.asInstanceOf[Ordering[Any]] + case i: AtomicType => i.ordering.asInstanceOf[Ordering[Any]] case other => sys.error(s"Type $other does not support ordered operations") } } @@ -240,7 
+240,7 @@ case class LessThanOrEqual(left: Expression, right: Expression) extends BinaryCo s"Types do not match ${left.dataType} != ${right.dataType}") } left.dataType match { - case i: NativeType => i.ordering.asInstanceOf[Ordering[Any]] + case i: AtomicType => i.ordering.asInstanceOf[Ordering[Any]] case other => sys.error(s"Type $other does not support ordered operations") } } @@ -269,7 +269,7 @@ case class GreaterThan(left: Expression, right: Expression) extends BinaryCompar s"Types do not match ${left.dataType} != ${right.dataType}") } left.dataType match { - case i: NativeType => i.ordering.asInstanceOf[Ordering[Any]] + case i: AtomicType => i.ordering.asInstanceOf[Ordering[Any]] case other => sys.error(s"Type $other does not support ordered operations") } } @@ -298,7 +298,7 @@ case class GreaterThanOrEqual(left: Expression, right: Expression) extends Binar s"Types do not match ${left.dataType} != ${right.dataType}") } left.dataType match { - case i: NativeType => i.ordering.asInstanceOf[Ordering[Any]] + case i: AtomicType => i.ordering.asInstanceOf[Ordering[Any]] case other => sys.error(s"Type $other does not support ordered operations") } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala index 981373477a4bc..5fd892c42e69c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.expressions -import org.apache.spark.sql.types.{UTF8String, DataType, StructType, NativeType} +import org.apache.spark.sql.types.{UTF8String, DataType, StructType, AtomicType} /** * An extended interface to [[Row]] that allows the values for each column to be updated. 
Setting @@ -227,9 +227,9 @@ class RowOrdering(ordering: Seq[SortOrder]) extends Ordering[Row] { return if (order.direction == Ascending) 1 else -1 } else { val comparison = order.dataType match { - case n: NativeType if order.direction == Ascending => + case n: AtomicType if order.direction == Ascending => n.ordering.asInstanceOf[Ordering[Any]].compare(left, right) - case n: NativeType if order.direction == Descending => + case n: AtomicType if order.direction == Descending => n.ordering.asInstanceOf[Ordering[Any]].reverse.compare(left, right) case other => sys.error(s"Type $other does not support ordered operations") } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala index 42e26e05996dd..87c7b7599366a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala @@ -42,7 +42,8 @@ object DataType { def fromJson(json: String): DataType = parseDataType(parse(json)) private val nonDecimalNameToType = { - (Seq(NullType, DateType, TimestampType, BinaryType) ++ NativeType.all) + Seq(NullType, DateType, TimestampType, BinaryType, + IntegerType, BooleanType, LongType, DoubleType, FloatType, ShortType, ByteType, StringType) .map(t => t.typeName -> t).toMap } @@ -309,22 +310,17 @@ class NullType private() extends DataType { case object NullType extends NullType -protected[sql] object NativeType { - val all = Seq( - IntegerType, BooleanType, LongType, DoubleType, FloatType, ShortType, ByteType, StringType) - - def unapply(dt: DataType): Boolean = all.contains(dt) -} - - -protected[sql] abstract class NativeType extends DataType { - private[sql] type JvmType - @transient private[sql] val tag: TypeTag[JvmType] - private[sql] val ordering: Ordering[JvmType] +/** + * An internal type used to represent everything that is not null, UDTs, arrays, structs, and maps. + */ +protected[sql] abstract class AtomicType extends DataType { + private[sql] type InternalType + @transient private[sql] val tag: TypeTag[InternalType] + private[sql] val ordering: Ordering[InternalType] @transient private[sql] val classTag = ScalaReflectionLock.synchronized { val mirror = runtimeMirror(Utils.getSparkClassLoader) - ClassTag[JvmType](mirror.runtimeClass(tag.tpe)) + ClassTag[InternalType](mirror.runtimeClass(tag.tpe)) } } @@ -336,13 +332,13 @@ protected[sql] abstract class NativeType extends DataType { * @group dataType */ @DeveloperApi -class StringType private() extends NativeType { +class StringType private() extends AtomicType { // The companion object and this class is separated so the companion object also subclasses // this type. Otherwise, the companion object would be of type "StringType$" in byte code. // Defined with a private constructor so the companion object is the only possible instantiation. - private[sql] type JvmType = UTF8String - @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] } - private[sql] val ordering = implicitly[Ordering[JvmType]] + private[sql] type InternalType = UTF8String + @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[InternalType] } + private[sql] val ordering = implicitly[Ordering[InternalType]] /** * The default size of a value of the StringType is 4096 bytes. 
@@ -363,13 +359,13 @@ case object StringType extends StringType * @group dataType */ @DeveloperApi -class BinaryType private() extends NativeType { +class BinaryType private() extends AtomicType { // The companion object and this class is separated so the companion object also subclasses // this type. Otherwise, the companion object would be of type "BinaryType$" in byte code. // Defined with a private constructor so the companion object is the only possible instantiation. - private[sql] type JvmType = Array[Byte] - @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] } - private[sql] val ordering = new Ordering[JvmType] { + private[sql] type InternalType = Array[Byte] + @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[InternalType] } + private[sql] val ordering = new Ordering[InternalType] { def compare(x: Array[Byte], y: Array[Byte]): Int = { for (i <- 0 until x.length; if i < y.length) { val res = x(i).compareTo(y(i)) @@ -397,13 +393,13 @@ case object BinaryType extends BinaryType *@group dataType */ @DeveloperApi -class BooleanType private() extends NativeType { +class BooleanType private() extends AtomicType { // The companion object and this class is separated so the companion object also subclasses // this type. Otherwise, the companion object would be of type "BooleanType$" in byte code. // Defined with a private constructor so the companion object is the only possible instantiation. - private[sql] type JvmType = Boolean - @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] } - private[sql] val ordering = implicitly[Ordering[JvmType]] + private[sql] type InternalType = Boolean + @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[InternalType] } + private[sql] val ordering = implicitly[Ordering[InternalType]] /** * The default size of a value of the BooleanType is 1 byte. @@ -424,15 +420,15 @@ case object BooleanType extends BooleanType * @group dataType */ @DeveloperApi -class TimestampType private() extends NativeType { +class TimestampType private() extends AtomicType { // The companion object and this class is separated so the companion object also subclasses // this type. Otherwise, the companion object would be of type "TimestampType$" in byte code. // Defined with a private constructor so the companion object is the only possible instantiation. - private[sql] type JvmType = Timestamp + private[sql] type InternalType = Timestamp - @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] } + @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[InternalType] } - private[sql] val ordering = new Ordering[JvmType] { + private[sql] val ordering = new Ordering[InternalType] { def compare(x: Timestamp, y: Timestamp): Int = x.compareTo(y) } @@ -455,15 +451,15 @@ case object TimestampType extends TimestampType * @group dataType */ @DeveloperApi -class DateType private() extends NativeType { +class DateType private() extends AtomicType { // The companion object and this class is separated so the companion object also subclasses // this type. Otherwise, the companion object would be of type "DateType$" in byte code. // Defined with a private constructor so the companion object is the only possible instantiation. 
- private[sql] type JvmType = Int + private[sql] type InternalType = Int - @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] } + @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[InternalType] } - private[sql] val ordering = implicitly[Ordering[JvmType]] + private[sql] val ordering = implicitly[Ordering[InternalType]] /** * The default size of a value of the DateType is 4 bytes. @@ -482,13 +478,13 @@ case object DateType extends DateType * * @group dataType */ -abstract class NumericType extends NativeType { +abstract class NumericType extends AtomicType { // Unfortunately we can't get this implicitly as that breaks Spark Serialization. In order for // implicitly[Numeric[JvmType]] to be valid, we have to change JvmType from a type variable to a // type parameter and and add a numeric annotation (i.e., [JvmType : Numeric]). This gets // desugared by the compiler into an argument to the objects constructor. This means there is no // longer an no argument constructor and thus the JVM cannot serialize the object anymore. - private[sql] val numeric: Numeric[JvmType] + private[sql] val numeric: Numeric[InternalType] } @@ -507,7 +503,7 @@ protected[sql] object IntegralType { protected[sql] sealed abstract class IntegralType extends NumericType { - private[sql] val integral: Integral[JvmType] + private[sql] val integral: Integral[InternalType] } @@ -522,11 +518,11 @@ class LongType private() extends IntegralType { // The companion object and this class is separated so the companion object also subclasses // this type. Otherwise, the companion object would be of type "LongType$" in byte code. // Defined with a private constructor so the companion object is the only possible instantiation. - private[sql] type JvmType = Long - @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] } + private[sql] type InternalType = Long + @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[InternalType] } private[sql] val numeric = implicitly[Numeric[Long]] private[sql] val integral = implicitly[Integral[Long]] - private[sql] val ordering = implicitly[Ordering[JvmType]] + private[sql] val ordering = implicitly[Ordering[InternalType]] /** * The default size of a value of the LongType is 8 bytes. @@ -552,11 +548,11 @@ class IntegerType private() extends IntegralType { // The companion object and this class is separated so the companion object also subclasses // this type. Otherwise, the companion object would be of type "IntegerType$" in byte code. // Defined with a private constructor so the companion object is the only possible instantiation. - private[sql] type JvmType = Int - @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] } + private[sql] type InternalType = Int + @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[InternalType] } private[sql] val numeric = implicitly[Numeric[Int]] private[sql] val integral = implicitly[Integral[Int]] - private[sql] val ordering = implicitly[Ordering[JvmType]] + private[sql] val ordering = implicitly[Ordering[InternalType]] /** * The default size of a value of the IntegerType is 4 bytes. @@ -582,11 +578,11 @@ class ShortType private() extends IntegralType { // The companion object and this class is separated so the companion object also subclasses // this type. Otherwise, the companion object would be of type "ShortType$" in byte code. 
// Defined with a private constructor so the companion object is the only possible instantiation. - private[sql] type JvmType = Short - @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] } + private[sql] type InternalType = Short + @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[InternalType] } private[sql] val numeric = implicitly[Numeric[Short]] private[sql] val integral = implicitly[Integral[Short]] - private[sql] val ordering = implicitly[Ordering[JvmType]] + private[sql] val ordering = implicitly[Ordering[InternalType]] /** * The default size of a value of the ShortType is 2 bytes. @@ -612,11 +608,11 @@ class ByteType private() extends IntegralType { // The companion object and this class is separated so the companion object also subclasses // this type. Otherwise, the companion object would be of type "ByteType$" in byte code. // Defined with a private constructor so the companion object is the only possible instantiation. - private[sql] type JvmType = Byte - @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] } + private[sql] type InternalType = Byte + @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[InternalType] } private[sql] val numeric = implicitly[Numeric[Byte]] private[sql] val integral = implicitly[Integral[Byte]] - private[sql] val ordering = implicitly[Ordering[JvmType]] + private[sql] val ordering = implicitly[Ordering[InternalType]] /** * The default size of a value of the ByteType is 1 byte. @@ -641,8 +637,8 @@ protected[sql] object FractionalType { protected[sql] sealed abstract class FractionalType extends NumericType { - private[sql] val fractional: Fractional[JvmType] - private[sql] val asIntegral: Integral[JvmType] + private[sql] val fractional: Fractional[InternalType] + private[sql] val asIntegral: Integral[InternalType] } @@ -665,8 +661,8 @@ case class DecimalType(precisionInfo: Option[PrecisionInfo]) extends FractionalT /** No-arg constructor for kryo. */ protected def this() = this(null) - private[sql] type JvmType = Decimal - @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] } + private[sql] type InternalType = Decimal + @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[InternalType] } private[sql] val numeric = Decimal.DecimalIsFractional private[sql] val fractional = Decimal.DecimalIsFractional private[sql] val ordering = Decimal.DecimalIsFractional @@ -743,11 +739,11 @@ class DoubleType private() extends FractionalType { // The companion object and this class is separated so the companion object also subclasses // this type. Otherwise, the companion object would be of type "DoubleType$" in byte code. // Defined with a private constructor so the companion object is the only possible instantiation. 
- private[sql] type JvmType = Double - @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] } + private[sql] type InternalType = Double + @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[InternalType] } private[sql] val numeric = implicitly[Numeric[Double]] private[sql] val fractional = implicitly[Fractional[Double]] - private[sql] val ordering = implicitly[Ordering[JvmType]] + private[sql] val ordering = implicitly[Ordering[InternalType]] private[sql] val asIntegral = DoubleAsIfIntegral /** @@ -772,11 +768,11 @@ class FloatType private() extends FractionalType { // The companion object and this class is separated so the companion object also subclasses // this type. Otherwise, the companion object would be of type "FloatType$" in byte code. // Defined with a private constructor so the companion object is the only possible instantiation. - private[sql] type JvmType = Float - @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] } + private[sql] type InternalType = Float + @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[InternalType] } private[sql] val numeric = implicitly[Numeric[Float]] private[sql] val fractional = implicitly[Fractional[Float]] - private[sql] val ordering = implicitly[Ordering[JvmType]] + private[sql] val ordering = implicitly[Ordering[InternalType]] private[sql] val asIntegral = FloatAsIfIntegral /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnAccessor.scala index f615fb33a7c35..64449b2659b4b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnAccessor.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnAccessor.scala @@ -61,7 +61,7 @@ private[sql] abstract class BasicColumnAccessor[T <: DataType, JvmType]( protected def underlyingBuffer = buffer } -private[sql] abstract class NativeColumnAccessor[T <: NativeType]( +private[sql] abstract class NativeColumnAccessor[T <: AtomicType]( override protected val buffer: ByteBuffer, override protected val columnType: NativeColumnType[T]) extends BasicColumnAccessor(buffer, columnType) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala index 00ed70430b84d..aa10af400c815 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala @@ -84,10 +84,10 @@ private[sql] abstract class ComplexColumnBuilder[T <: DataType, JvmType]( extends BasicColumnBuilder[T, JvmType](columnStats, columnType) with NullableColumnBuilder -private[sql] abstract class NativeColumnBuilder[T <: NativeType]( +private[sql] abstract class NativeColumnBuilder[T <: AtomicType]( override val columnStats: ColumnStats, override val columnType: NativeColumnType[T]) - extends BasicColumnBuilder[T, T#JvmType](columnStats, columnType) + extends BasicColumnBuilder[T, T#InternalType](columnStats, columnType) with NullableColumnBuilder with AllCompressionSchemes with CompressibleColumnBuilder[T] diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala index 1b9e0df2dcb5e..20be5ca9d0046 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala @@ -101,16 +101,16 @@ private[sql] sealed abstract class ColumnType[T <: DataType, JvmType]( override def toString: String = getClass.getSimpleName.stripSuffix("$") } -private[sql] abstract class NativeColumnType[T <: NativeType]( +private[sql] abstract class NativeColumnType[T <: AtomicType]( val dataType: T, typeId: Int, defaultSize: Int) - extends ColumnType[T, T#JvmType](typeId, defaultSize) { + extends ColumnType[T, T#InternalType](typeId, defaultSize) { /** * Scala TypeTag. Can be used to create primitive arrays and hash tables. */ - def scalaTag: TypeTag[dataType.JvmType] = dataType.tag + def scalaTag: TypeTag[dataType.InternalType] = dataType.tag } private[sql] object INT extends NativeColumnType(IntegerType, 0, 4) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnAccessor.scala index d0b602a834dfe..cb205defbb1ad 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnAccessor.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnAccessor.scala @@ -19,9 +19,9 @@ package org.apache.spark.sql.columnar.compression import org.apache.spark.sql.catalyst.expressions.MutableRow import org.apache.spark.sql.columnar.{ColumnAccessor, NativeColumnAccessor} -import org.apache.spark.sql.types.NativeType +import org.apache.spark.sql.types.AtomicType -private[sql] trait CompressibleColumnAccessor[T <: NativeType] extends ColumnAccessor { +private[sql] trait CompressibleColumnAccessor[T <: AtomicType] extends ColumnAccessor { this: NativeColumnAccessor[T] => private var decoder: Decoder[T] = _ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnBuilder.scala index b9cfc5df550d1..8e2a1af6dae78 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnBuilder.scala @@ -22,7 +22,7 @@ import java.nio.{ByteBuffer, ByteOrder} import org.apache.spark.Logging import org.apache.spark.sql.Row import org.apache.spark.sql.columnar.{ColumnBuilder, NativeColumnBuilder} -import org.apache.spark.sql.types.NativeType +import org.apache.spark.sql.types.AtomicType /** * A stackable trait that builds optionally compressed byte buffer for a column. 
Memory layout of @@ -41,7 +41,7 @@ import org.apache.spark.sql.types.NativeType * header body * }}} */ -private[sql] trait CompressibleColumnBuilder[T <: NativeType] +private[sql] trait CompressibleColumnBuilder[T <: AtomicType] extends ColumnBuilder with Logging { this: NativeColumnBuilder[T] with WithCompressionSchemes => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressionScheme.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressionScheme.scala index 879d29bcfa6f6..17c2d9b111188 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressionScheme.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressionScheme.scala @@ -22,9 +22,9 @@ import java.nio.{ByteBuffer, ByteOrder} import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.expressions.MutableRow import org.apache.spark.sql.columnar.{ColumnType, NativeColumnType} -import org.apache.spark.sql.types.NativeType +import org.apache.spark.sql.types.AtomicType -private[sql] trait Encoder[T <: NativeType] { +private[sql] trait Encoder[T <: AtomicType] { def gatherCompressibilityStats(row: Row, ordinal: Int): Unit = {} def compressedSize: Int @@ -38,7 +38,7 @@ private[sql] trait Encoder[T <: NativeType] { def compress(from: ByteBuffer, to: ByteBuffer): ByteBuffer } -private[sql] trait Decoder[T <: NativeType] { +private[sql] trait Decoder[T <: AtomicType] { def next(row: MutableRow, ordinal: Int): Unit def hasNext: Boolean @@ -49,9 +49,9 @@ private[sql] trait CompressionScheme { def supports(columnType: ColumnType[_, _]): Boolean - def encoder[T <: NativeType](columnType: NativeColumnType[T]): Encoder[T] + def encoder[T <: AtomicType](columnType: NativeColumnType[T]): Encoder[T] - def decoder[T <: NativeType](buffer: ByteBuffer, columnType: NativeColumnType[T]): Decoder[T] + def decoder[T <: AtomicType](buffer: ByteBuffer, columnType: NativeColumnType[T]): Decoder[T] } private[sql] trait WithCompressionSchemes { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/compressionSchemes.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/compressionSchemes.scala index 8727d71c48bb7..534ae90ddbc8b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/compressionSchemes.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/compressionSchemes.scala @@ -35,16 +35,16 @@ private[sql] case object PassThrough extends CompressionScheme { override def supports(columnType: ColumnType[_, _]): Boolean = true - override def encoder[T <: NativeType](columnType: NativeColumnType[T]): Encoder[T] = { + override def encoder[T <: AtomicType](columnType: NativeColumnType[T]): Encoder[T] = { new this.Encoder[T](columnType) } - override def decoder[T <: NativeType]( + override def decoder[T <: AtomicType]( buffer: ByteBuffer, columnType: NativeColumnType[T]): Decoder[T] = { new this.Decoder(buffer, columnType) } - class Encoder[T <: NativeType](columnType: NativeColumnType[T]) extends compression.Encoder[T] { + class Encoder[T <: AtomicType](columnType: NativeColumnType[T]) extends compression.Encoder[T] { override def uncompressedSize: Int = 0 override def compressedSize: Int = 0 @@ -56,7 +56,7 @@ private[sql] case object PassThrough extends CompressionScheme { } } - class Decoder[T <: NativeType](buffer: ByteBuffer, columnType: NativeColumnType[T]) + class Decoder[T <: AtomicType](buffer: ByteBuffer, columnType: NativeColumnType[T]) 
extends compression.Decoder[T] { override def next(row: MutableRow, ordinal: Int): Unit = { @@ -70,11 +70,11 @@ private[sql] case object PassThrough extends CompressionScheme { private[sql] case object RunLengthEncoding extends CompressionScheme { override val typeId = 1 - override def encoder[T <: NativeType](columnType: NativeColumnType[T]): Encoder[T] = { + override def encoder[T <: AtomicType](columnType: NativeColumnType[T]): Encoder[T] = { new this.Encoder[T](columnType) } - override def decoder[T <: NativeType]( + override def decoder[T <: AtomicType]( buffer: ByteBuffer, columnType: NativeColumnType[T]): Decoder[T] = { new this.Decoder(buffer, columnType) } @@ -84,7 +84,7 @@ private[sql] case object RunLengthEncoding extends CompressionScheme { case _ => false } - class Encoder[T <: NativeType](columnType: NativeColumnType[T]) extends compression.Encoder[T] { + class Encoder[T <: AtomicType](columnType: NativeColumnType[T]) extends compression.Encoder[T] { private var _uncompressedSize = 0 private var _compressedSize = 0 @@ -152,12 +152,12 @@ private[sql] case object RunLengthEncoding extends CompressionScheme { } } - class Decoder[T <: NativeType](buffer: ByteBuffer, columnType: NativeColumnType[T]) + class Decoder[T <: AtomicType](buffer: ByteBuffer, columnType: NativeColumnType[T]) extends compression.Decoder[T] { private var run = 0 private var valueCount = 0 - private var currentValue: T#JvmType = _ + private var currentValue: T#InternalType = _ override def next(row: MutableRow, ordinal: Int): Unit = { if (valueCount == run) { @@ -181,12 +181,12 @@ private[sql] case object DictionaryEncoding extends CompressionScheme { // 32K unique values allowed val MAX_DICT_SIZE = Short.MaxValue - override def decoder[T <: NativeType](buffer: ByteBuffer, columnType: NativeColumnType[T]) + override def decoder[T <: AtomicType](buffer: ByteBuffer, columnType: NativeColumnType[T]) : Decoder[T] = { new this.Decoder(buffer, columnType) } - override def encoder[T <: NativeType](columnType: NativeColumnType[T]): Encoder[T] = { + override def encoder[T <: AtomicType](columnType: NativeColumnType[T]): Encoder[T] = { new this.Encoder[T](columnType) } @@ -195,7 +195,7 @@ private[sql] case object DictionaryEncoding extends CompressionScheme { case _ => false } - class Encoder[T <: NativeType](columnType: NativeColumnType[T]) extends compression.Encoder[T] { + class Encoder[T <: AtomicType](columnType: NativeColumnType[T]) extends compression.Encoder[T] { // Size of the input, uncompressed, in bytes. Note that we only count until the dictionary // overflows. private var _uncompressedSize = 0 @@ -208,7 +208,7 @@ private[sql] case object DictionaryEncoding extends CompressionScheme { private var count = 0 // The reverse mapping of _dictionary, i.e. mapping encoded integer to the value itself. - private var values = new mutable.ArrayBuffer[T#JvmType](1024) + private var values = new mutable.ArrayBuffer[T#InternalType](1024) // The dictionary that maps a value to the encoded short integer. private val dictionary = mutable.HashMap.empty[Any, Short] @@ -268,14 +268,14 @@ private[sql] case object DictionaryEncoding extends CompressionScheme { override def compressedSize: Int = if (overflow) Int.MaxValue else dictionarySize + count * 2 } - class Decoder[T <: NativeType](buffer: ByteBuffer, columnType: NativeColumnType[T]) + class Decoder[T <: AtomicType](buffer: ByteBuffer, columnType: NativeColumnType[T]) extends compression.Decoder[T] { private val dictionary = { // TODO Can we clean up this mess? 
Maybe move this to `DataType`? implicit val classTag = { val mirror = runtimeMirror(Utils.getSparkClassLoader) - ClassTag[T#JvmType](mirror.runtimeClass(columnType.scalaTag.tpe)) + ClassTag[T#InternalType](mirror.runtimeClass(columnType.scalaTag.tpe)) } Array.fill(buffer.getInt()) { @@ -296,12 +296,12 @@ private[sql] case object BooleanBitSet extends CompressionScheme { val BITS_PER_LONG = 64 - override def decoder[T <: NativeType](buffer: ByteBuffer, columnType: NativeColumnType[T]) + override def decoder[T <: AtomicType](buffer: ByteBuffer, columnType: NativeColumnType[T]) : compression.Decoder[T] = { new this.Decoder(buffer).asInstanceOf[compression.Decoder[T]] } - override def encoder[T <: NativeType](columnType: NativeColumnType[T]): compression.Encoder[T] = { + override def encoder[T <: AtomicType](columnType: NativeColumnType[T]): compression.Encoder[T] = { (new this.Encoder).asInstanceOf[compression.Encoder[T]] } @@ -384,12 +384,12 @@ private[sql] case object BooleanBitSet extends CompressionScheme { private[sql] case object IntDelta extends CompressionScheme { override def typeId: Int = 4 - override def decoder[T <: NativeType](buffer: ByteBuffer, columnType: NativeColumnType[T]) + override def decoder[T <: AtomicType](buffer: ByteBuffer, columnType: NativeColumnType[T]) : compression.Decoder[T] = { new Decoder(buffer, INT).asInstanceOf[compression.Decoder[T]] } - override def encoder[T <: NativeType](columnType: NativeColumnType[T]): compression.Encoder[T] = { + override def encoder[T <: AtomicType](columnType: NativeColumnType[T]): compression.Encoder[T] = { (new Encoder).asInstanceOf[compression.Encoder[T]] } @@ -464,12 +464,12 @@ private[sql] case object IntDelta extends CompressionScheme { private[sql] case object LongDelta extends CompressionScheme { override def typeId: Int = 5 - override def decoder[T <: NativeType](buffer: ByteBuffer, columnType: NativeColumnType[T]) + override def decoder[T <: AtomicType](buffer: ByteBuffer, columnType: NativeColumnType[T]) : compression.Decoder[T] = { new Decoder(buffer, LONG).asInstanceOf[compression.Decoder[T]] } - override def encoder[T <: NativeType](columnType: NativeColumnType[T]): compression.Encoder[T] = { + override def encoder[T <: AtomicType](columnType: NativeColumnType[T]): compression.Encoder[T] = { (new Encoder).asInstanceOf[compression.Encoder[T]] } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala index 29de7401dda71..6e94e7056eb0b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala @@ -183,7 +183,7 @@ private[sql] object JsonRDD extends Logging { private def typeOfPrimitiveValue: PartialFunction[Any, DataType] = { // For Integer values, use LongType by default. 
val useLongType: PartialFunction[Any, DataType] = { - case value: IntegerType.JvmType => LongType + case value: IntegerType.InternalType => LongType } useLongType orElse ScalaReflection.typeOfObject orElse { @@ -411,11 +411,11 @@ private[sql] object JsonRDD extends Logging { desiredType match { case StringType => UTF8String(toString(value)) case _ if value == null || value == "" => null // guard the non string type - case IntegerType => value.asInstanceOf[IntegerType.JvmType] + case IntegerType => value.asInstanceOf[IntegerType.InternalType] case LongType => toLong(value) case DoubleType => toDouble(value) case DecimalType() => toDecimal(value) - case BooleanType => value.asInstanceOf[BooleanType.JvmType] + case BooleanType => value.asInstanceOf[BooleanType.InternalType] case NullType => null case ArrayType(elementType, _) => value.asInstanceOf[Seq[Any]].map(enforceCorrectType(_, elementType)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala index 116424539da11..36cb5e03bbca7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala @@ -90,7 +90,7 @@ private[sql] object CatalystConverter { createConverter(field.copy(dataType = udt.sqlType), fieldIndex, parent) } // For native JVM types we use a converter with native arrays - case ArrayType(elementType: NativeType, false) => { + case ArrayType(elementType: AtomicType, false) => { new CatalystNativeArrayConverter(elementType, fieldIndex, parent) } // This is for other types of arrays, including those with nested fields @@ -118,19 +118,19 @@ private[sql] object CatalystConverter { case ShortType => { new CatalystPrimitiveConverter(parent, fieldIndex) { override def addInt(value: Int): Unit = - parent.updateShort(fieldIndex, value.asInstanceOf[ShortType.JvmType]) + parent.updateShort(fieldIndex, value.asInstanceOf[ShortType.InternalType]) } } case ByteType => { new CatalystPrimitiveConverter(parent, fieldIndex) { override def addInt(value: Int): Unit = - parent.updateByte(fieldIndex, value.asInstanceOf[ByteType.JvmType]) + parent.updateByte(fieldIndex, value.asInstanceOf[ByteType.InternalType]) } } case DateType => { new CatalystPrimitiveConverter(parent, fieldIndex) { override def addInt(value: Int): Unit = - parent.updateDate(fieldIndex, value.asInstanceOf[DateType.JvmType]) + parent.updateDate(fieldIndex, value.asInstanceOf[DateType.InternalType]) } } case d: DecimalType => { @@ -637,13 +637,13 @@ private[parquet] class CatalystArrayConverter( * @param capacity The (initial) capacity of the buffer */ private[parquet] class CatalystNativeArrayConverter( - val elementType: NativeType, + val elementType: AtomicType, val index: Int, protected[parquet] val parent: CatalystConverter, protected[parquet] var capacity: Int = CatalystArrayConverter.INITIAL_ARRAY_SIZE) extends CatalystConverter { - type NativeType = elementType.JvmType + type NativeType = elementType.InternalType private var buffer: Array[NativeType] = elementType.classTag.newArray(capacity) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala index e05a4c20b0d41..c45c431438efc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala @@ -189,7 +189,7 @@ private[parquet] class RowWriteSupport extends WriteSupport[Row] with Logging { case t @ StructType(_) => writeStruct( t, value.asInstanceOf[CatalystConverter.StructScalaType[_]]) - case _ => writePrimitive(schema.asInstanceOf[NativeType], value) + case _ => writePrimitive(schema.asInstanceOf[AtomicType], value) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala index fec487f1d2c82..7cefcf44061ce 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala @@ -34,7 +34,7 @@ class ColumnStatsSuite extends FunSuite { testColumnStats(classOf[DateColumnStats], DATE, Row(Int.MaxValue, Int.MinValue, 0)) testColumnStats(classOf[TimestampColumnStats], TIMESTAMP, Row(null, null, 0)) - def testColumnStats[T <: NativeType, U <: ColumnStats]( + def testColumnStats[T <: AtomicType, U <: ColumnStats]( columnStatsClass: Class[U], columnType: NativeColumnType[T], initialStatistics: Row): Unit = { @@ -55,8 +55,8 @@ class ColumnStatsSuite extends FunSuite { val rows = Seq.fill(10)(makeRandomRow(columnType)) ++ Seq.fill(10)(makeNullRow(1)) rows.foreach(columnStats.gatherStats(_, 0)) - val values = rows.take(10).map(_(0).asInstanceOf[T#JvmType]) - val ordering = columnType.dataType.ordering.asInstanceOf[Ordering[T#JvmType]] + val values = rows.take(10).map(_(0).asInstanceOf[T#InternalType]) + val ordering = columnType.dataType.ordering.asInstanceOf[Ordering[T#InternalType]] val stats = columnStats.collectedStatistics assertResult(values.min(ordering), "Wrong lower bound")(stats(0)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala index b48bed1871c50..1e105e259dce7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala @@ -196,12 +196,12 @@ class ColumnTypeSuite extends FunSuite with Logging { } } - def testNativeColumnType[T <: NativeType]( + def testNativeColumnType[T <: AtomicType]( columnType: NativeColumnType[T], - putter: (ByteBuffer, T#JvmType) => Unit, - getter: (ByteBuffer) => T#JvmType): Unit = { + putter: (ByteBuffer, T#InternalType) => Unit, + getter: (ByteBuffer) => T#InternalType): Unit = { - testColumnType[T, T#JvmType](columnType, putter, getter) + testColumnType[T, T#InternalType](columnType, putter, getter) } def testColumnType[T <: DataType, JvmType]( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala index f76314b9dab5e..75d993e563e06 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala @@ -24,7 +24,7 @@ import scala.util.Random import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.expressions.GenericMutableRow -import org.apache.spark.sql.types.{UTF8String, DataType, Decimal, NativeType} +import org.apache.spark.sql.types.{UTF8String, DataType, Decimal, AtomicType} object ColumnarTestUtils { def makeNullRow(length: Int): GenericMutableRow = { @@ -91,9 +91,9 @@ object 
ColumnarTestUtils { row } - def makeUniqueValuesAndSingleValueRows[T <: NativeType]( + def makeUniqueValuesAndSingleValueRows[T <: AtomicType]( columnType: NativeColumnType[T], - count: Int): (Seq[T#JvmType], Seq[GenericMutableRow]) = { + count: Int): (Seq[T#InternalType], Seq[GenericMutableRow]) = { val values = makeUniqueRandomValues(columnType, count) val rows = values.map { value => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/DictionaryEncodingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/DictionaryEncodingSuite.scala index c82d9799359c7..64b70552eb047 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/DictionaryEncodingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/DictionaryEncodingSuite.scala @@ -24,14 +24,14 @@ import org.scalatest.FunSuite import org.apache.spark.sql.catalyst.expressions.GenericMutableRow import org.apache.spark.sql.columnar._ import org.apache.spark.sql.columnar.ColumnarTestUtils._ -import org.apache.spark.sql.types.NativeType +import org.apache.spark.sql.types.AtomicType class DictionaryEncodingSuite extends FunSuite { testDictionaryEncoding(new IntColumnStats, INT) testDictionaryEncoding(new LongColumnStats, LONG) testDictionaryEncoding(new StringColumnStats, STRING) - def testDictionaryEncoding[T <: NativeType]( + def testDictionaryEncoding[T <: AtomicType]( columnStats: ColumnStats, columnType: NativeColumnType[T]) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/IntegralDeltaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/IntegralDeltaSuite.scala index 88011631ee4e3..bfd99f143bedc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/IntegralDeltaSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/IntegralDeltaSuite.scala @@ -33,7 +33,7 @@ class IntegralDeltaSuite extends FunSuite { columnType: NativeColumnType[I], scheme: CompressionScheme) { - def skeleton(input: Seq[I#JvmType]) { + def skeleton(input: Seq[I#InternalType]) { // ------------- // Tests encoder // ------------- @@ -120,13 +120,13 @@ class IntegralDeltaSuite extends FunSuite { case LONG => Seq(2: Long, 1: Long, 2: Long, 130: Long) } - skeleton(input.map(_.asInstanceOf[I#JvmType])) + skeleton(input.map(_.asInstanceOf[I#InternalType])) } test(s"$scheme: long random series") { // Have to workaround with `Any` since no `ClassTag[I#JvmType]` available here. 
val input = Array.fill[Any](10000)(makeRandomValue(columnType)) - skeleton(input.map(_.asInstanceOf[I#JvmType])) + skeleton(input.map(_.asInstanceOf[I#InternalType])) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/RunLengthEncodingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/RunLengthEncodingSuite.scala index 08df1db375097..fde7a4595be0e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/RunLengthEncodingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/RunLengthEncodingSuite.scala @@ -22,7 +22,7 @@ import org.scalatest.FunSuite import org.apache.spark.sql.catalyst.expressions.GenericMutableRow import org.apache.spark.sql.columnar._ import org.apache.spark.sql.columnar.ColumnarTestUtils._ -import org.apache.spark.sql.types.NativeType +import org.apache.spark.sql.types.AtomicType class RunLengthEncodingSuite extends FunSuite { testRunLengthEncoding(new NoopColumnStats, BOOLEAN) @@ -32,7 +32,7 @@ class RunLengthEncodingSuite extends FunSuite { testRunLengthEncoding(new LongColumnStats, LONG) testRunLengthEncoding(new StringColumnStats, STRING) - def testRunLengthEncoding[T <: NativeType]( + def testRunLengthEncoding[T <: AtomicType]( columnStats: ColumnStats, columnType: NativeColumnType[T]) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/TestCompressibleColumnBuilder.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/TestCompressibleColumnBuilder.scala index fc8ff3b41d0e6..5268dfe0aa03e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/TestCompressibleColumnBuilder.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/TestCompressibleColumnBuilder.scala @@ -18,9 +18,9 @@ package org.apache.spark.sql.columnar.compression import org.apache.spark.sql.columnar._ -import org.apache.spark.sql.types.NativeType +import org.apache.spark.sql.types.AtomicType -class TestCompressibleColumnBuilder[T <: NativeType]( +class TestCompressibleColumnBuilder[T <: AtomicType]( override val columnStats: ColumnStats, override val columnType: NativeColumnType[T], override val schemes: Seq[CompressionScheme]) @@ -32,7 +32,7 @@ class TestCompressibleColumnBuilder[T <: NativeType]( } object TestCompressibleColumnBuilder { - def apply[T <: NativeType]( + def apply[T <: AtomicType]( columnStats: ColumnStats, columnType: NativeColumnType[T], scheme: CompressionScheme): TestCompressibleColumnBuilder[T] = { From a7d65d38f934c5c751ba32aa7ab648c6d16044ab Mon Sep 17 00:00:00 2001 From: Prashant Sharma Date: Thu, 23 Apr 2015 16:45:26 +0530 Subject: [PATCH 041/110] [HOTFIX] [SQL] Fix compilation for scala 2.11. Author: Prashant Sharma Closes #5652 from ScrapCodes/hf/compilation-fix-scala-2.11 and squashes the following commits: 819ff06 [Prashant Sharma] [HOTFIX] Fix compilation for scala 2.11. 
--- .../test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java index fc3ed4a708d46..e02c84872c628 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java @@ -162,7 +162,7 @@ public void testCreateDataFrameFromJavaBeans() { Buffer outputBuffer = (Buffer) first.getJavaMap(2).get("hello"); Assert.assertArrayEquals( bean.getC().get("hello"), - Ints.toArray(JavaConversions.asJavaList(outputBuffer))); + Ints.toArray(JavaConversions.bufferAsJavaList(outputBuffer))); Seq d = first.getAs(3); Assert.assertEquals(bean.getD().size(), d.length()); for (int i = 0; i < d.length(); i++) { From 975f53e4f978759db7639cd08498ad8cd0ae2a56 Mon Sep 17 00:00:00 2001 From: Prabeesh K Date: Thu, 23 Apr 2015 10:33:13 -0700 Subject: [PATCH 042/110] [minor][streaming]fixed scala string interpolation error Author: Prabeesh K Closes #5653 from prabeesh/fix and squashes the following commits: 9d7a9f5 [Prabeesh K] fixed scala string interpolation error --- .../org/apache/spark/examples/streaming/MQTTWordCount.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/MQTTWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/MQTTWordCount.scala index f40caad322f59..85b9a54b40baf 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/MQTTWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/MQTTWordCount.scala @@ -56,7 +56,7 @@ object MQTTPublisher { while (true) { try { msgtopic.publish(message) - println(s"Published data. topic: {msgtopic.getName()}; Message: {message}") + println(s"Published data. 
topic: ${msgtopic.getName()}; Message: $message") } catch { case e: MqttException if e.getReasonCode == MqttException.REASON_CODE_MAX_INFLIGHT => Thread.sleep(10) From cc48e6387abdd909921cb58e0588cdf226556bcd Mon Sep 17 00:00:00 2001 From: Cheng Hao Date: Thu, 23 Apr 2015 10:35:22 -0700 Subject: [PATCH 043/110] [SPARK-7044] [SQL] Fix the deadlock in script transformation Author: Cheng Hao Closes #5625 from chenghao-intel/transform and squashes the following commits: 5ec1dd2 [Cheng Hao] fix the deadlock issue in ScriptTransform --- .../hive/execution/ScriptTransformation.scala | 33 ++++++++++++------- .../sql/hive/execution/SQLQuerySuite.scala | 8 +++++ 2 files changed, 29 insertions(+), 12 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala index cab0fdd35723a..3eddda3b28c66 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala @@ -145,20 +145,29 @@ case class ScriptTransformation( val dataOutputStream = new DataOutputStream(outputStream) val outputProjection = new InterpretedProjection(input, child.output) - iter - .map(outputProjection) - .foreach { row => - if (inputSerde == null) { - val data = row.mkString("", ioschema.inputRowFormatMap("TOK_TABLEROWFORMATFIELD"), - ioschema.inputRowFormatMap("TOK_TABLEROWFORMATLINES")).getBytes("utf-8") - - outputStream.write(data) - } else { - val writable = inputSerde.serialize(row.asInstanceOf[GenericRow].values, inputSoi) - prepareWritable(writable).write(dataOutputStream) + // Put the write(output to the pipeline) into a single thread + // and keep the collector as remain in the main thread. + // otherwise it will causes deadlock if the data size greater than + // the pipeline / buffer capacity. 
+ new Thread(new Runnable() { + override def run(): Unit = { + iter + .map(outputProjection) + .foreach { row => + if (inputSerde == null) { + val data = row.mkString("", ioschema.inputRowFormatMap("TOK_TABLEROWFORMATFIELD"), + ioschema.inputRowFormatMap("TOK_TABLEROWFORMATLINES")).getBytes("utf-8") + + outputStream.write(data) + } else { + val writable = inputSerde.serialize(row.asInstanceOf[GenericRow].values, inputSoi) + prepareWritable(writable).write(dataOutputStream) + } } + outputStream.close() } - outputStream.close() + }).start() + iterator } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 47b4cb9ca61ff..4f8d0ac0e7656 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -561,4 +561,12 @@ class SQLQuerySuite extends QueryTest { sql("select d from dn union all select d * 2 from dn") .queryExecution.analyzed } + + test("test script transform") { + val data = (1 to 100000).map { i => (i, i, i) } + data.toDF("d1", "d2", "d3").registerTempTable("script_trans") + assert(100000 === + sql("SELECT TRANSFORM (d1, d2, d3) USING 'cat' AS (a,b,c) FROM script_trans") + .queryExecution.toRdd.count()) + } } From 534f2a43625fbf1a3a65d09550a19875cd1dce43 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Thu, 23 Apr 2015 11:29:34 -0700 Subject: [PATCH 044/110] [SPARK-6752][Streaming] Allow StreamingContext to be recreated from checkpoint and existing SparkContext Currently if you want to create a StreamingContext from checkpoint information, the system will create a new SparkContext. This prevent StreamingContext to be recreated from checkpoints in managed environments where SparkContext is precreated. The solution in this PR: Introduce the following methods on StreamingContext 1. `new StreamingContext(checkpointDirectory, sparkContext)` Recreate StreamingContext from checkpoint using the provided SparkContext 2. `StreamingContext.getOrCreate(checkpointDirectory, sparkContext, createFunction: SparkContext => StreamingContext)` If checkpoint file exists, then recreate StreamingContext using the provided SparkContext (that is, call 1.), else create StreamingContext using the provided createFunction TODO: the corresponding Java and Python API has to be added as well. Author: Tathagata Das Closes #5428 from tdas/SPARK-6752 and squashes the following commits: 94db63c [Tathagata Das] Fix long line. 524f519 [Tathagata Das] Many changes based on PR comments. eabd092 [Tathagata Das] Added Function0, Java API and unit tests for StreamingContext.getOrCreate 36a7823 [Tathagata Das] Minor changes. 
204814e [Tathagata Das] Added StreamingContext.getOrCreate with existing SparkContext --- .../spark/api/java/function/Function0.java | 27 +++ .../apache/spark/streaming/Checkpoint.scala | 26 ++- .../spark/streaming/StreamingContext.scala | 85 ++++++++-- .../api/java/JavaStreamingContext.scala | 119 ++++++++++++- .../apache/spark/streaming/JavaAPISuite.java | 145 ++++++++++++---- .../spark/streaming/CheckpointSuite.scala | 3 +- .../streaming/StreamingContextSuite.scala | 159 ++++++++++++++++++ 7 files changed, 503 insertions(+), 61 deletions(-) create mode 100644 core/src/main/java/org/apache/spark/api/java/function/Function0.java diff --git a/core/src/main/java/org/apache/spark/api/java/function/Function0.java b/core/src/main/java/org/apache/spark/api/java/function/Function0.java new file mode 100644 index 0000000000000..38e410c5debe6 --- /dev/null +++ b/core/src/main/java/org/apache/spark/api/java/function/Function0.java @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.java.function; + +import java.io.Serializable; + +/** + * A zero-argument function that returns an R. 
+ */ +public interface Function0 extends Serializable { + public R call() throws Exception; +} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala index 0a50485118588..7bfae253c3a0c 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala @@ -77,7 +77,8 @@ object Checkpoint extends Logging { } /** Get checkpoint files present in the give directory, ordered by oldest-first */ - def getCheckpointFiles(checkpointDir: String, fs: FileSystem): Seq[Path] = { + def getCheckpointFiles(checkpointDir: String, fsOption: Option[FileSystem] = None): Seq[Path] = { + def sortFunc(path1: Path, path2: Path): Boolean = { val (time1, bk1) = path1.getName match { case REGEX(x, y) => (x.toLong, !y.isEmpty) } val (time2, bk2) = path2.getName match { case REGEX(x, y) => (x.toLong, !y.isEmpty) } @@ -85,6 +86,7 @@ object Checkpoint extends Logging { } val path = new Path(checkpointDir) + val fs = fsOption.getOrElse(path.getFileSystem(new Configuration())) if (fs.exists(path)) { val statuses = fs.listStatus(path) if (statuses != null) { @@ -160,7 +162,7 @@ class CheckpointWriter( } // Delete old checkpoint files - val allCheckpointFiles = Checkpoint.getCheckpointFiles(checkpointDir, fs) + val allCheckpointFiles = Checkpoint.getCheckpointFiles(checkpointDir, Some(fs)) if (allCheckpointFiles.size > 10) { allCheckpointFiles.take(allCheckpointFiles.size - 10).foreach(file => { logInfo("Deleting " + file) @@ -234,15 +236,24 @@ class CheckpointWriter( private[streaming] object CheckpointReader extends Logging { - def read(checkpointDir: String, conf: SparkConf, hadoopConf: Configuration): Option[Checkpoint] = - { + /** + * Read checkpoint files present in the given checkpoint directory. If there are no checkpoint + * files, then return None, else try to return the latest valid checkpoint object. If no + * checkpoint files could be read correctly, then return None (if ignoreReadError = true), + * or throw exception (if ignoreReadError = false). + */ + def read( + checkpointDir: String, + conf: SparkConf, + hadoopConf: Configuration, + ignoreReadError: Boolean = false): Option[Checkpoint] = { val checkpointPath = new Path(checkpointDir) // TODO(rxin): Why is this a def?! 
def fs: FileSystem = checkpointPath.getFileSystem(hadoopConf) // Try to find the checkpoint files - val checkpointFiles = Checkpoint.getCheckpointFiles(checkpointDir, fs).reverse + val checkpointFiles = Checkpoint.getCheckpointFiles(checkpointDir, Some(fs)).reverse if (checkpointFiles.isEmpty) { return None } @@ -282,7 +293,10 @@ object CheckpointReader extends Logging { }) // If none of checkpoint files could be read, then throw exception - throw new SparkException("Failed to read checkpoint from directory " + checkpointPath) + if (!ignoreReadError) { + throw new SparkException(s"Failed to read checkpoint from directory $checkpointPath") + } + None } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala index f57f295874645..90c8b47aebce0 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala @@ -107,6 +107,19 @@ class StreamingContext private[streaming] ( */ def this(path: String) = this(path, new Configuration) + /** + * Recreate a StreamingContext from a checkpoint file using an existing SparkContext. + * @param path Path to the directory that was specified as the checkpoint directory + * @param sparkContext Existing SparkContext + */ + def this(path: String, sparkContext: SparkContext) = { + this( + sparkContext, + CheckpointReader.read(path, sparkContext.conf, sparkContext.hadoopConfiguration).get, + null) + } + + if (sc_ == null && cp_ == null) { throw new Exception("Spark Streaming cannot be initialized with " + "both SparkContext and checkpoint as null") @@ -115,10 +128,12 @@ class StreamingContext private[streaming] ( private[streaming] val isCheckpointPresent = (cp_ != null) private[streaming] val sc: SparkContext = { - if (isCheckpointPresent) { + if (sc_ != null) { + sc_ + } else if (isCheckpointPresent) { new SparkContext(cp_.createSparkConf()) } else { - sc_ + throw new SparkException("Cannot create StreamingContext without a SparkContext") } } @@ -129,7 +144,7 @@ class StreamingContext private[streaming] ( private[streaming] val conf = sc.conf - private[streaming] val env = SparkEnv.get + private[streaming] val env = sc.env private[streaming] val graph: DStreamGraph = { if (isCheckpointPresent) { @@ -174,7 +189,9 @@ class StreamingContext private[streaming] ( /** Register streaming source to metrics system */ private val streamingSource = new StreamingSource(this) - SparkEnv.get.metricsSystem.registerSource(streamingSource) + assert(env != null) + assert(env.metricsSystem != null) + env.metricsSystem.registerSource(streamingSource) /** Enumeration to identify current state of the StreamingContext */ private[streaming] object StreamingContextState extends Enumeration { @@ -621,19 +638,59 @@ object StreamingContext extends Logging { hadoopConf: Configuration = new Configuration(), createOnError: Boolean = false ): StreamingContext = { - val checkpointOption = try { - CheckpointReader.read(checkpointPath, new SparkConf(), hadoopConf) - } catch { - case e: Exception => - if (createOnError) { - None - } else { - throw e - } - } + val checkpointOption = CheckpointReader.read( + checkpointPath, new SparkConf(), hadoopConf, createOnError) checkpointOption.map(new StreamingContext(null, _, null)).getOrElse(creatingFunc()) } + + /** + * Either recreate a StreamingContext from checkpoint data or create a new StreamingContext. 
+ * If checkpoint data exists in the provided `checkpointPath`, then StreamingContext will be + * recreated from the checkpoint data. If the data does not exist, then the StreamingContext + * will be created by called the provided `creatingFunc` on the provided `sparkContext`. Note + * that the SparkConf configuration in the checkpoint data will not be restored as the + * SparkContext has already been created. + * + * @param checkpointPath Checkpoint directory used in an earlier StreamingContext program + * @param creatingFunc Function to create a new StreamingContext using the given SparkContext + * @param sparkContext SparkContext using which the StreamingContext will be created + */ + def getOrCreate( + checkpointPath: String, + creatingFunc: SparkContext => StreamingContext, + sparkContext: SparkContext + ): StreamingContext = { + getOrCreate(checkpointPath, creatingFunc, sparkContext, createOnError = false) + } + + /** + * Either recreate a StreamingContext from checkpoint data or create a new StreamingContext. + * If checkpoint data exists in the provided `checkpointPath`, then StreamingContext will be + * recreated from the checkpoint data. If the data does not exist, then the StreamingContext + * will be created by called the provided `creatingFunc` on the provided `sparkContext`. Note + * that the SparkConf configuration in the checkpoint data will not be restored as the + * SparkContext has already been created. + * + * @param checkpointPath Checkpoint directory used in an earlier StreamingContext program + * @param creatingFunc Function to create a new StreamingContext using the given SparkContext + * @param sparkContext SparkContext using which the StreamingContext will be created + * @param createOnError Whether to create a new StreamingContext if there is an + * error in reading checkpoint data. By default, an exception will be + * thrown on error. + */ + def getOrCreate( + checkpointPath: String, + creatingFunc: SparkContext => StreamingContext, + sparkContext: SparkContext, + createOnError: Boolean + ): StreamingContext = { + val checkpointOption = CheckpointReader.read( + checkpointPath, sparkContext.conf, sparkContext.hadoopConfiguration, createOnError) + checkpointOption.map(new StreamingContext(sparkContext, _, null)) + .getOrElse(creatingFunc(sparkContext)) + } + /** * Find the JAR from which a given class was loaded, to make it easy for users to pass * their JARs to StreamingContext. 
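Before the matching Java-friendly overloads in JavaStreamingContext below, a minimal usage sketch of the new Scala getOrCreate(checkpointPath, creatingFunc, sparkContext) overload added above (the checkpoint path, app settings, and the body of createContext are hypothetical; only the method signature comes from this patch). As the scaladoc above notes, the SparkConf stored in the checkpoint is not restored in this mode because the SparkContext already exists.

    import org.apache.spark.{SparkConf, SparkContext}
    import org.apache.spark.streaming.{Seconds, StreamingContext}

    object GetOrCreateSketch {
      def main(args: Array[String]): Unit = {
        val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("sketch"))
        val checkpointDir = "/tmp/streaming-checkpoint"  // hypothetical path

        // Invoked only when no usable checkpoint exists; on recovery the existing `sc` is reused.
        def createContext(sparkContext: SparkContext): StreamingContext = {
          val ssc = new StreamingContext(sparkContext, Seconds(1))
          ssc.checkpoint(checkpointDir)
          // ... define input DStreams and output operations here ...
          ssc
        }

        val ssc = StreamingContext.getOrCreate(checkpointDir, createContext _, sc)
        ssc.start()
        ssc.awaitTermination()
      }
    }
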
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala index 4095a7cc84946..572d7d8e8753d 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala @@ -32,13 +32,14 @@ import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.{JavaPairRDD, JavaRDD, JavaSparkContext} import org.apache.spark.api.java.function.{Function => JFunction, Function2 => JFunction2} +import org.apache.spark.api.java.function.{Function0 => JFunction0} import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming._ import org.apache.spark.streaming.scheduler.StreamingListener -import org.apache.hadoop.conf.Configuration -import org.apache.spark.streaming.dstream.{PluggableInputDStream, ReceiverInputDStream, DStream} +import org.apache.spark.streaming.dstream.DStream import org.apache.spark.streaming.receiver.Receiver +import org.apache.hadoop.conf.Configuration /** * A Java-friendly version of [[org.apache.spark.streaming.StreamingContext]] which is the main @@ -655,6 +656,7 @@ object JavaStreamingContext { * @param checkpointPath Checkpoint directory used in an earlier JavaStreamingContext program * @param factory JavaStreamingContextFactory object to create a new JavaStreamingContext */ + @deprecated("use getOrCreate without JavaStreamingContextFactor", "1.4.0") def getOrCreate( checkpointPath: String, factory: JavaStreamingContextFactory @@ -676,6 +678,7 @@ object JavaStreamingContext { * @param hadoopConf Hadoop configuration if necessary for reading from any HDFS compatible * file system */ + @deprecated("use getOrCreate without JavaStreamingContextFactory", "1.4.0") def getOrCreate( checkpointPath: String, hadoopConf: Configuration, @@ -700,6 +703,7 @@ object JavaStreamingContext { * @param createOnError Whether to create a new JavaStreamingContext if there is an * error in reading checkpoint data. */ + @deprecated("use getOrCreate without JavaStreamingContextFactory", "1.4.0") def getOrCreate( checkpointPath: String, hadoopConf: Configuration, @@ -712,6 +716,117 @@ object JavaStreamingContext { new JavaStreamingContext(ssc) } + /** + * Either recreate a StreamingContext from checkpoint data or create a new StreamingContext. + * If checkpoint data exists in the provided `checkpointPath`, then StreamingContext will be + * recreated from the checkpoint data. If the data does not exist, then the provided factory + * will be used to create a JavaStreamingContext. + * + * @param checkpointPath Checkpoint directory used in an earlier JavaStreamingContext program + * @param creatingFunc Function to create a new JavaStreamingContext + */ + def getOrCreate( + checkpointPath: String, + creatingFunc: JFunction0[JavaStreamingContext] + ): JavaStreamingContext = { + val ssc = StreamingContext.getOrCreate(checkpointPath, () => { + creatingFunc.call().ssc + }) + new JavaStreamingContext(ssc) + } + + /** + * Either recreate a StreamingContext from checkpoint data or create a new StreamingContext. + * If checkpoint data exists in the provided `checkpointPath`, then StreamingContext will be + * recreated from the checkpoint data. If the data does not exist, then the provided factory + * will be used to create a JavaStreamingContext. 
+ * + * @param checkpointPath Checkpoint directory used in an earlier StreamingContext program + * @param creatingFunc Function to create a new JavaStreamingContext + * @param hadoopConf Hadoop configuration if necessary for reading from any HDFS compatible + * file system + */ + def getOrCreate( + checkpointPath: String, + creatingFunc: JFunction0[JavaStreamingContext], + hadoopConf: Configuration + ): JavaStreamingContext = { + val ssc = StreamingContext.getOrCreate(checkpointPath, () => { + creatingFunc.call().ssc + }, hadoopConf) + new JavaStreamingContext(ssc) + } + + /** + * Either recreate a StreamingContext from checkpoint data or create a new StreamingContext. + * If checkpoint data exists in the provided `checkpointPath`, then StreamingContext will be + * recreated from the checkpoint data. If the data does not exist, then the provided factory + * will be used to create a JavaStreamingContext. + * + * @param checkpointPath Checkpoint directory used in an earlier StreamingContext program + * @param creatingFunc Function to create a new JavaStreamingContext + * @param hadoopConf Hadoop configuration if necessary for reading from any HDFS compatible + * file system + * @param createOnError Whether to create a new JavaStreamingContext if there is an + * error in reading checkpoint data. + */ + def getOrCreate( + checkpointPath: String, + creatingFunc: JFunction0[JavaStreamingContext], + hadoopConf: Configuration, + createOnError: Boolean + ): JavaStreamingContext = { + val ssc = StreamingContext.getOrCreate(checkpointPath, () => { + creatingFunc.call().ssc + }, hadoopConf, createOnError) + new JavaStreamingContext(ssc) + } + + /** + * Either recreate a StreamingContext from checkpoint data or create a new StreamingContext. + * If checkpoint data exists in the provided `checkpointPath`, then StreamingContext will be + * recreated from the checkpoint data. If the data does not exist, then the provided factory + * will be used to create a JavaStreamingContext. + * + * @param checkpointPath Checkpoint directory used in an earlier StreamingContext program + * @param creatingFunc Function to create a new JavaStreamingContext + * @param sparkContext SparkContext using which the StreamingContext will be created + */ + def getOrCreate( + checkpointPath: String, + creatingFunc: JFunction[JavaSparkContext, JavaStreamingContext], + sparkContext: JavaSparkContext + ): JavaStreamingContext = { + val ssc = StreamingContext.getOrCreate(checkpointPath, (sparkContext: SparkContext) => { + creatingFunc.call(new JavaSparkContext(sparkContext)).ssc + }, sparkContext.sc) + new JavaStreamingContext(ssc) + } + + /** + * Either recreate a StreamingContext from checkpoint data or create a new StreamingContext. + * If checkpoint data exists in the provided `checkpointPath`, then StreamingContext will be + * recreated from the checkpoint data. If the data does not exist, then the provided factory + * will be used to create a JavaStreamingContext. + * + * @param checkpointPath Checkpoint directory used in an earlier StreamingContext program + * @param creatingFunc Function to create a new JavaStreamingContext + * @param sparkContext SparkContext using which the StreamingContext will be created + * @param createOnError Whether to create a new JavaStreamingContext if there is an + * error in reading checkpoint data. 
+ */ + def getOrCreate( + checkpointPath: String, + creatingFunc: JFunction[JavaSparkContext, JavaStreamingContext], + sparkContext: JavaSparkContext, + createOnError: Boolean + ): JavaStreamingContext = { + val ssc = StreamingContext.getOrCreate(checkpointPath, (sparkContext: SparkContext) => { + creatingFunc.call(new JavaSparkContext(sparkContext)).ssc + }, sparkContext.sc, createOnError) + new JavaStreamingContext(ssc) + } + /** * Find the JAR from which a given class was loaded, to make it easy for users to pass * their JARs to StreamingContext. diff --git a/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java b/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java index 90340753a4eed..cb2e8380b4933 100644 --- a/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java +++ b/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java @@ -22,10 +22,12 @@ import java.nio.charset.Charset; import java.util.*; +import org.apache.commons.lang.mutable.MutableBoolean; import org.apache.hadoop.fs.Path; import org.apache.hadoop.io.LongWritable; import org.apache.hadoop.io.Text; import org.apache.hadoop.mapreduce.lib.input.TextInputFormat; + import scala.Tuple2; import org.junit.Assert; @@ -45,6 +47,7 @@ import org.apache.spark.storage.StorageLevel; import org.apache.spark.streaming.api.java.*; import org.apache.spark.util.Utils; +import org.apache.spark.SparkConf; // The test suite itself is Serializable so that anonymous Function implementations can be // serialized, as an alternative to converting these anonymous classes to static inner classes; @@ -929,7 +932,7 @@ public void testPairMap() { // Maps pair -> pair of different type public Tuple2 call(Tuple2 in) throws Exception { return in.swap(); } - }); + }); JavaTestUtils.attachTestOutputStream(reversed); List>> result = JavaTestUtils.runStreams(ssc, 2, 2); @@ -987,12 +990,12 @@ public void testPairMap2() { // Maps pair -> single JavaDStream> stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); JavaDStream reversed = pairStream.map( - new Function, Integer>() { - @Override - public Integer call(Tuple2 in) throws Exception { - return in._2(); - } - }); + new Function, Integer>() { + @Override + public Integer call(Tuple2 in) throws Exception { + return in._2(); + } + }); JavaTestUtils.attachTestOutputStream(reversed); List> result = JavaTestUtils.runStreams(ssc, 2, 2); @@ -1123,7 +1126,7 @@ public void testCombineByKey() { JavaPairDStream combined = pairStream.combineByKey( new Function() { - @Override + @Override public Integer call(Integer i) throws Exception { return i; } @@ -1144,14 +1147,14 @@ public void testCountByValue() { Arrays.asList("hello")); List>> expected = Arrays.asList( - Arrays.asList( - new Tuple2("hello", 1L), - new Tuple2("world", 1L)), - Arrays.asList( - new Tuple2("hello", 1L), - new Tuple2("moon", 1L)), - Arrays.asList( - new Tuple2("hello", 1L))); + Arrays.asList( + new Tuple2("hello", 1L), + new Tuple2("world", 1L)), + Arrays.asList( + new Tuple2("hello", 1L), + new Tuple2("moon", 1L)), + Arrays.asList( + new Tuple2("hello", 1L))); JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); JavaPairDStream counted = stream.countByValue(); @@ -1249,17 +1252,17 @@ public void testUpdateStateByKey() { JavaPairDStream updated = pairStream.updateStateByKey( new Function2, Optional, Optional>() { - @Override - public Optional call(List values, Optional state) { - int out = 
0; - if (state.isPresent()) { - out = out + state.get(); - } - for (Integer v: values) { - out = out + v; + @Override + public Optional call(List values, Optional state) { + int out = 0; + if (state.isPresent()) { + out = out + state.get(); + } + for (Integer v : values) { + out = out + v; + } + return Optional.of(out); } - return Optional.of(out); - } }); JavaTestUtils.attachTestOutputStream(updated); List>> result = JavaTestUtils.runStreams(ssc, 3, 3); @@ -1292,17 +1295,17 @@ public void testUpdateStateByKeyWithInitial() { JavaPairDStream updated = pairStream.updateStateByKey( new Function2, Optional, Optional>() { - @Override - public Optional call(List values, Optional state) { - int out = 0; - if (state.isPresent()) { - out = out + state.get(); - } - for (Integer v: values) { - out = out + v; + @Override + public Optional call(List values, Optional state) { + int out = 0; + if (state.isPresent()) { + out = out + state.get(); + } + for (Integer v : values) { + out = out + v; + } + return Optional.of(out); } - return Optional.of(out); - } }, new HashPartitioner(1), initialRDD); JavaTestUtils.attachTestOutputStream(updated); List>> result = JavaTestUtils.runStreams(ssc, 3, 3); @@ -1328,7 +1331,7 @@ public void testReduceByKeyAndWindowWithInverse() { JavaPairDStream reduceWindowed = pairStream.reduceByKeyAndWindow(new IntegerSum(), new IntegerDifference(), - new Duration(2000), new Duration(1000)); + new Duration(2000), new Duration(1000)); JavaTestUtils.attachTestOutputStream(reduceWindowed); List>> result = JavaTestUtils.runStreams(ssc, 3, 3); @@ -1707,6 +1710,74 @@ public Integer call(String s) throws Exception { Utils.deleteRecursively(tempDir); } + @SuppressWarnings("unchecked") + @Test + public void testContextGetOrCreate() throws InterruptedException { + + final SparkConf conf = new SparkConf() + .setMaster("local[2]") + .setAppName("test") + .set("newContext", "true"); + + File emptyDir = Files.createTempDir(); + emptyDir.deleteOnExit(); + StreamingContextSuite contextSuite = new StreamingContextSuite(); + String corruptedCheckpointDir = contextSuite.createCorruptedCheckpoint(); + String checkpointDir = contextSuite.createValidCheckpoint(); + + // Function to create JavaStreamingContext without any output operations + // (used to detect the new context) + final MutableBoolean newContextCreated = new MutableBoolean(false); + Function0 creatingFunc = new Function0() { + public JavaStreamingContext call() { + newContextCreated.setValue(true); + return new JavaStreamingContext(conf, Seconds.apply(1)); + } + }; + + newContextCreated.setValue(false); + ssc = JavaStreamingContext.getOrCreate(emptyDir.getAbsolutePath(), creatingFunc); + Assert.assertTrue("new context not created", newContextCreated.isTrue()); + ssc.stop(); + + newContextCreated.setValue(false); + ssc = JavaStreamingContext.getOrCreate(corruptedCheckpointDir, creatingFunc, + new org.apache.hadoop.conf.Configuration(), true); + Assert.assertTrue("new context not created", newContextCreated.isTrue()); + ssc.stop(); + + newContextCreated.setValue(false); + ssc = JavaStreamingContext.getOrCreate(checkpointDir, creatingFunc, + new org.apache.hadoop.conf.Configuration()); + Assert.assertTrue("old context not recovered", newContextCreated.isFalse()); + ssc.stop(); + + // Function to create JavaStreamingContext using existing JavaSparkContext + // without any output operations (used to detect the new context) + Function creatingFunc2 = + new Function() { + public JavaStreamingContext call(JavaSparkContext context) { + 
newContextCreated.setValue(true); + return new JavaStreamingContext(context, Seconds.apply(1)); + } + }; + + JavaSparkContext sc = new JavaSparkContext(conf); + newContextCreated.setValue(false); + ssc = JavaStreamingContext.getOrCreate(emptyDir.getAbsolutePath(), creatingFunc2, sc); + Assert.assertTrue("new context not created", newContextCreated.isTrue()); + ssc.stop(false); + + newContextCreated.setValue(false); + ssc = JavaStreamingContext.getOrCreate(corruptedCheckpointDir, creatingFunc2, sc, true); + Assert.assertTrue("new context not created", newContextCreated.isTrue()); + ssc.stop(false); + + newContextCreated.setValue(false); + ssc = JavaStreamingContext.getOrCreate(checkpointDir, creatingFunc2, sc); + Assert.assertTrue("old context not recovered", newContextCreated.isFalse()); + ssc.stop(); + } /* TEST DISABLED: Pending a discussion about checkpoint() semantics with TD @SuppressWarnings("unchecked") diff --git a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala index 54c30440a6e8d..6b0a3f91d4d06 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala @@ -430,9 +430,8 @@ class CheckpointSuite extends TestSuiteBase { assert(recordedFiles(ssc) === Seq(1, 2, 3) && batchCounter.getNumStartedBatches === 3) } // Wait for a checkpoint to be written - val fs = new Path(checkpointDir).getFileSystem(ssc.sc.hadoopConfiguration) eventually(eventuallyTimeout) { - assert(Checkpoint.getCheckpointFiles(checkpointDir, fs).size === 6) + assert(Checkpoint.getCheckpointFiles(checkpointDir).size === 6) } ssc.stop() // Check that we shut down while the third batch was being processed diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala index 58353a5f97c8a..4f193322ad33e 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala @@ -17,8 +17,10 @@ package org.apache.spark.streaming +import java.io.File import java.util.concurrent.atomic.AtomicInteger +import org.apache.commons.io.FileUtils import org.scalatest.{Assertions, BeforeAndAfter, FunSuite} import org.scalatest.concurrent.Timeouts import org.scalatest.concurrent.Eventually._ @@ -330,6 +332,139 @@ class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts w } } + test("getOrCreate") { + val conf = new SparkConf().setMaster(master).setAppName(appName) + + // Function to create StreamingContext that has a config to identify it to be new context + var newContextCreated = false + def creatingFunction(): StreamingContext = { + newContextCreated = true + new StreamingContext(conf, batchDuration) + } + + // Call ssc.stop after a body of code + def testGetOrCreate(body: => Unit): Unit = { + newContextCreated = false + try { + body + } finally { + if (ssc != null) { + ssc.stop() + } + ssc = null + } + } + + val emptyPath = Utils.createTempDir().getAbsolutePath() + + // getOrCreate should create new context with empty path + testGetOrCreate { + ssc = StreamingContext.getOrCreate(emptyPath, creatingFunction _) + assert(ssc != null, "no context created") + assert(newContextCreated, "new context not created") + } + + val corrutedCheckpointPath = createCorruptedCheckpoint() + + // getOrCreate 
should throw exception with fake checkpoint file and createOnError = false + intercept[Exception] { + ssc = StreamingContext.getOrCreate(corrutedCheckpointPath, creatingFunction _) + } + + // getOrCreate should throw exception with fake checkpoint file + intercept[Exception] { + ssc = StreamingContext.getOrCreate( + corrutedCheckpointPath, creatingFunction _, createOnError = false) + } + + // getOrCreate should create new context with fake checkpoint file and createOnError = true + testGetOrCreate { + ssc = StreamingContext.getOrCreate( + corrutedCheckpointPath, creatingFunction _, createOnError = true) + assert(ssc != null, "no context created") + assert(newContextCreated, "new context not created") + } + + val checkpointPath = createValidCheckpoint() + + // getOrCreate should recover context with checkpoint path, and recover old configuration + testGetOrCreate { + ssc = StreamingContext.getOrCreate(checkpointPath, creatingFunction _) + assert(ssc != null, "no context created") + assert(!newContextCreated, "old context not recovered") + assert(ssc.conf.get("someKey") === "someValue") + } + } + + test("getOrCreate with existing SparkContext") { + val conf = new SparkConf().setMaster(master).setAppName(appName) + sc = new SparkContext(conf) + + // Function to create StreamingContext that has a config to identify it to be new context + var newContextCreated = false + def creatingFunction(sparkContext: SparkContext): StreamingContext = { + newContextCreated = true + new StreamingContext(sparkContext, batchDuration) + } + + // Call ssc.stop(stopSparkContext = false) after a body of cody + def testGetOrCreate(body: => Unit): Unit = { + newContextCreated = false + try { + body + } finally { + if (ssc != null) { + ssc.stop(stopSparkContext = false) + } + ssc = null + } + } + + val emptyPath = Utils.createTempDir().getAbsolutePath() + + // getOrCreate should create new context with empty path + testGetOrCreate { + ssc = StreamingContext.getOrCreate(emptyPath, creatingFunction _, sc, createOnError = true) + assert(ssc != null, "no context created") + assert(newContextCreated, "new context not created") + assert(ssc.sparkContext === sc, "new StreamingContext does not use existing SparkContext") + } + + val corrutedCheckpointPath = createCorruptedCheckpoint() + + // getOrCreate should throw exception with fake checkpoint file and createOnError = false + intercept[Exception] { + ssc = StreamingContext.getOrCreate(corrutedCheckpointPath, creatingFunction _, sc) + } + + // getOrCreate should throw exception with fake checkpoint file + intercept[Exception] { + ssc = StreamingContext.getOrCreate( + corrutedCheckpointPath, creatingFunction _, sc, createOnError = false) + } + + // getOrCreate should create new context with fake checkpoint file and createOnError = true + testGetOrCreate { + ssc = StreamingContext.getOrCreate( + corrutedCheckpointPath, creatingFunction _, sc, createOnError = true) + assert(ssc != null, "no context created") + assert(newContextCreated, "new context not created") + assert(ssc.sparkContext === sc, "new StreamingContext does not use existing SparkContext") + } + + val checkpointPath = createValidCheckpoint() + + // StreamingContext.getOrCreate should recover context with checkpoint path + testGetOrCreate { + ssc = StreamingContext.getOrCreate(checkpointPath, creatingFunction _, sc) + assert(ssc != null, "no context created") + assert(!newContextCreated, "old context not recovered") + assert(ssc.sparkContext === sc, "new StreamingContext does not use existing SparkContext") + 
assert(!ssc.conf.contains("someKey"), + "recovered StreamingContext unexpectedly has old config") + } + } + test("DStream and generated RDD creation sites") { testPackage.test() } @@ -339,6 +474,30 @@ class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts w val inputStream = new TestInputStream(s, input, 1) inputStream } + + def createValidCheckpoint(): String = { + val testDirectory = Utils.createTempDir().getAbsolutePath() + val checkpointDirectory = Utils.createTempDir().getAbsolutePath() + val conf = new SparkConf().setMaster(master).setAppName(appName) + conf.set("someKey", "someValue") + ssc = new StreamingContext(conf, batchDuration) + ssc.checkpoint(checkpointDirectory) + ssc.textFileStream(testDirectory).foreachRDD { rdd => rdd.count() } + ssc.start() + eventually(timeout(10000 millis)) { + assert(Checkpoint.getCheckpointFiles(checkpointDirectory).size > 1) + } + ssc.stop() + checkpointDirectory + } + + def createCorruptedCheckpoint(): String = { + val checkpointDirectory = Utils.createTempDir().getAbsolutePath() + val fakeCheckpointFile = Checkpoint.checkpointFile(checkpointDirectory, Time(1000)) + FileUtils.write(new File(fakeCheckpointFile.toString()), "blablabla") + assert(Checkpoint.getCheckpointFiles(checkpointDirectory).nonEmpty) + checkpointDirectory + } } class TestException(msg: String) extends Exception(msg) From c1213e6a92e126ad886d9804cedaf6db3618e602 Mon Sep 17 00:00:00 2001 From: Vinod K C Date: Thu, 23 Apr 2015 12:00:23 -0700 Subject: [PATCH 045/110] [SPARK-7055][SQL]Use correct ClassLoader for JDBC Driver in JDBCRDD.getConnector Author: Vinod K C Closes #5633 from vinodkc/use_correct_classloader_driverload and squashes the following commits: 73c5380 [Vinod K C] Use correct ClassLoader for JDBC Driver --- .../src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala index b975191d41963..f326510042122 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala @@ -26,6 +26,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.expressions.{Row, SpecificMutableRow} import org.apache.spark.sql.types._ import org.apache.spark.sql.sources._ +import org.apache.spark.util.Utils private[sql] object JDBCRDD extends Logging { /** @@ -152,7 +153,7 @@ private[sql] object JDBCRDD extends Logging { def getConnector(driver: String, url: String, properties: Properties): () => Connection = { () => { try { - if (driver != null) Class.forName(driver) + if (driver != null) Utils.getContextOrSparkClassLoader.loadClass(driver) } catch { case e: ClassNotFoundException => { logWarning(s"Couldn't find class $driver", e); From 6afde2c7810c363083d0a699b1de02b54c13e6a9 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Thu, 23 Apr 2015 13:19:03 -0700 Subject: [PATCH 046/110] [SPARK-7058] Include RDD deserialization time in "task deserialization time" metric The web UI's "task deserialization time" metric is slightly misleading because it does not capture the time taken to deserialize the broadcasted RDD. Author: Josh Rosen Closes #5635 from JoshRosen/SPARK-7058 and squashes the following commits: ed90f75 [Josh Rosen] Update UI tooltip a3743b4 [Josh Rosen] Update comments. 
4f52910 [Josh Rosen] Roll back whitespace change e9cf9f4 [Josh Rosen] Remove unused variable 9f32e55 [Josh Rosen] Expose executorDeserializeTime on Task instead of pushing runtime calculation into Task. 21f5b47 [Josh Rosen] Don't double-count the broadcast deserialization time in task runtime 1752f0e [Josh Rosen] [SPARK-7058] Incorporate RDD deserialization time in task deserialization time metric --- .../main/scala/org/apache/spark/executor/Executor.scala | 8 ++++++-- .../scala/org/apache/spark/scheduler/ResultTask.scala | 2 ++ .../scala/org/apache/spark/scheduler/ShuffleMapTask.scala | 2 ++ core/src/main/scala/org/apache/spark/scheduler/Task.scala | 7 +++++++ core/src/main/scala/org/apache/spark/ui/ToolTips.scala | 4 +++- 5 files changed, 20 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index 5fc04df5d6a40..f57e215c3f2ed 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -220,8 +220,12 @@ private[spark] class Executor( val afterSerialization = System.currentTimeMillis() for (m <- task.metrics) { - m.setExecutorDeserializeTime(taskStart - deserializeStartTime) - m.setExecutorRunTime(taskFinish - taskStart) + // Deserialization happens in two parts: first, we deserialize a Task object, which + // includes the Partition. Second, Task.run() deserializes the RDD and function to be run. + m.setExecutorDeserializeTime( + (taskStart - deserializeStartTime) + task.executorDeserializeTime) + // We need to subtract Task.run()'s deserialization time to avoid double-counting + m.setExecutorRunTime((taskFinish - taskStart) - task.executorDeserializeTime) m.setJvmGCTime(computeTotalGcTime() - startGCTime) m.setResultSerializationTime(afterSerialization - beforeSerialization) } diff --git a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala index e074ce6ebff0b..c9a124113961f 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala @@ -53,9 +53,11 @@ private[spark] class ResultTask[T, U]( override def runTask(context: TaskContext): U = { // Deserialize the RDD and the func using the broadcast variables. + val deserializeStartTime = System.currentTimeMillis() val ser = SparkEnv.get.closureSerializer.newInstance() val (rdd, func) = ser.deserialize[(RDD[T], (TaskContext, Iterator[T]) => U)]( ByteBuffer.wrap(taskBinary.value), Thread.currentThread.getContextClassLoader) + _executorDeserializeTime = System.currentTimeMillis() - deserializeStartTime metrics = Some(context.taskMetrics) func(context, rdd.iterator(partition, context)) diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala index 6c7d00069acb2..bd3dd23dfe1ac 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala @@ -56,9 +56,11 @@ private[spark] class ShuffleMapTask( override def runTask(context: TaskContext): MapStatus = { // Deserialize the RDD using the broadcast variable. 
+ val deserializeStartTime = System.currentTimeMillis() val ser = SparkEnv.get.closureSerializer.newInstance() val (rdd, dep) = ser.deserialize[(RDD[_], ShuffleDependency[_, _, _])]( ByteBuffer.wrap(taskBinary.value), Thread.currentThread.getContextClassLoader) + _executorDeserializeTime = System.currentTimeMillis() - deserializeStartTime metrics = Some(context.taskMetrics) var writer: ShuffleWriter[Any, Any] = null diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala index 8b592867ee31d..b09b19e2ac9e7 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala @@ -87,11 +87,18 @@ private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) ex // initialized when kill() is invoked. @volatile @transient private var _killed = false + protected var _executorDeserializeTime: Long = 0 + /** * Whether the task has been killed. */ def killed: Boolean = _killed + /** + * Returns the amount of time spent deserializing the RDD and function to be run. + */ + def executorDeserializeTime: Long = _executorDeserializeTime + /** * Kills a task by setting the interrupted flag to true. This relies on the upper level Spark * code and user code to properly handle the flag. This function should be idempotent so it can diff --git a/core/src/main/scala/org/apache/spark/ui/ToolTips.scala b/core/src/main/scala/org/apache/spark/ui/ToolTips.scala index cae6870c2ab20..24f3236456248 100644 --- a/core/src/main/scala/org/apache/spark/ui/ToolTips.scala +++ b/core/src/main/scala/org/apache/spark/ui/ToolTips.scala @@ -24,7 +24,9 @@ private[spark] object ToolTips { scheduler delay is large, consider decreasing the size of tasks or decreasing the size of task results.""" - val TASK_DESERIALIZATION_TIME = "Time spent deserializing the task closure on the executor." + val TASK_DESERIALIZATION_TIME = + """Time spent deserializing the task closure on the executor, including the time to read the + broadcasted task.""" val SHUFFLE_READ_BLOCKED_TIME = "Time that the task spent blocked waiting for shuffle data to be read from remote machines." 
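To make the metric accounting in the Executor change above concrete, here is a small arithmetic sketch of how the two timings are combined so the RDD/broadcast deserialization done inside Task.run() is counted exactly once (the millisecond values are invented for illustration; the formulas mirror the patch):

    object DeserializeTimeSketch {
      def main(args: Array[String]): Unit = {
        val deserializeStartTime = 1000L  // executor starts deserializing the Task object
        val taskStart            = 1040L  // Task object + partition deserialized (40 ms)
        val taskFinish           = 1340L  // Task.run() returned (300 ms wall clock)
        val rddDeserializeTime   = 120L   // measured inside ResultTask/ShuffleMapTask.runTask()

        // Deserialization time now includes both phases; run time subtracts the second
        // phase so those 120 ms are not double-counted.
        val executorDeserializeTime = (taskStart - deserializeStartTime) + rddDeserializeTime  // 160 ms
        val executorRunTime         = (taskFinish - taskStart) - rddDeserializeTime            // 180 ms
        println(s"deserialize = $executorDeserializeTime ms, run = $executorRunTime ms")
      }
    }
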
From 3e91cc273d281053618bfa032bc610e2cf8d8e78 Mon Sep 17 00:00:00 2001 From: wizz Date: Thu, 23 Apr 2015 14:00:07 -0700 Subject: [PATCH 047/110] [SPARK-7085][MLlib] Fix miniBatchFraction parameter in train method called with 4 arguments Author: wizz Closes #5658 from kuromatsu-nobuyuki/SPARK-7085 and squashes the following commits: 6ec2d21 [wizz] Fix miniBatchFraction parameter in train method called with 4 arguments --- .../org/apache/spark/mllib/regression/RidgeRegression.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala index 8838ca8c14718..309f9af466457 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala @@ -171,7 +171,7 @@ object RidgeRegressionWithSGD { numIterations: Int, stepSize: Double, regParam: Double): RidgeRegressionModel = { - train(input, numIterations, stepSize, regParam, 0.01) + train(input, numIterations, stepSize, regParam, 1.0) } /** From baa83a9a6769c5e119438d65d7264dceb8d743d5 Mon Sep 17 00:00:00 2001 From: WangTaoTheTonic Date: Thu, 23 Apr 2015 17:20:17 -0400 Subject: [PATCH 048/110] [SPARK-6879] [HISTORYSERVER] check if app is completed before clean it up https://issues.apache.org/jira/browse/SPARK-6879 Use `applications` to replace `FileStatus`, and check if the app is completed before clean it up. If an exception was throwed, add it to `applications` to wait for the next loop. Author: WangTaoTheTonic Closes #5491 from WangTaoTheTonic/SPARK-6879 and squashes the following commits: 4a533eb [WangTaoTheTonic] treat ACE specially cb45105 [WangTaoTheTonic] rebase d4d5251 [WangTaoTheTonic] per Marcelo's comments d7455d8 [WangTaoTheTonic] slightly change when delete file b0abca5 [WangTaoTheTonic] use global var to store apps to clean 94adfe1 [WangTaoTheTonic] leave expired apps alone to be deleted 9872a9d [WangTaoTheTonic] use the right path fdef4d6 [WangTaoTheTonic] check if app is completed before clean it up --- .../deploy/history/FsHistoryProvider.scala | 32 ++++++++++++------- 1 file changed, 20 insertions(+), 12 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala index 9847d5944a390..a94ebf6e53750 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala @@ -35,7 +35,6 @@ import org.apache.spark.ui.SparkUI import org.apache.spark.util.{ThreadUtils, Utils} import org.apache.spark.{Logging, SecurityManager, SparkConf} - /** * A class that provides application history from event logs stored in the file system. * This provider checks for new finished applications in the background periodically and @@ -76,6 +75,9 @@ private[history] class FsHistoryProvider(conf: SparkConf) extends ApplicationHis @volatile private var applications: mutable.LinkedHashMap[String, FsApplicationHistoryInfo] = new mutable.LinkedHashMap() + // List of applications to be deleted by event log cleaner. + private var appsToClean = new mutable.ListBuffer[FsApplicationHistoryInfo] + // Constants used to parse Spark 1.0.0 log directories. 
private[history] val LOG_PREFIX = "EVENT_LOG_" private[history] val SPARK_VERSION_PREFIX = EventLoggingListener.SPARK_VERSION_KEY + "_" @@ -266,34 +268,40 @@ private[history] class FsHistoryProvider(conf: SparkConf) extends ApplicationHis */ private def cleanLogs(): Unit = { try { - val statusList = Option(fs.listStatus(new Path(logDir))).map(_.toSeq) - .getOrElse(Seq[FileStatus]()) val maxAge = conf.getTimeAsSeconds("spark.history.fs.cleaner.maxAge", "7d") * 1000 val now = System.currentTimeMillis() val appsToRetain = new mutable.LinkedHashMap[String, FsApplicationHistoryInfo]() + // Scan all logs from the log directory. + // Only completed applications older than the specified max age will be deleted. applications.values.foreach { info => - if (now - info.lastUpdated <= maxAge) { + if (now - info.lastUpdated <= maxAge || !info.completed) { appsToRetain += (info.id -> info) + } else { + appsToClean += info } } applications = appsToRetain - // Scan all logs from the log directory. - // Only directories older than the specified max age will be deleted - statusList.foreach { dir => + val leftToClean = new mutable.ListBuffer[FsApplicationHistoryInfo] + appsToClean.foreach { info => try { - if (now - dir.getModificationTime() > maxAge) { - // if path is a directory and set to true, - // the directory is deleted else throws an exception - fs.delete(dir.getPath, true) + val path = new Path(logDir, info.logPath) + if (fs.exists(path)) { + fs.delete(path, true) } } catch { - case t: IOException => logError(s"IOException in cleaning logs of $dir", t) + case e: AccessControlException => + logInfo(s"No permission to delete ${info.logPath}, ignoring.") + case t: IOException => + logError(s"IOException in cleaning logs of ${info.logPath}", t) + leftToClean += info } } + + appsToClean = leftToClean } catch { case t: Exception => logError("Exception in cleaning logs", t) } From 6d0749cae301ee4bf37632d657de48e75548a523 Mon Sep 17 00:00:00 2001 From: Tijo Thomas Date: Thu, 23 Apr 2015 17:23:15 -0400 Subject: [PATCH 049/110] [SPARK-7087] [BUILD] Fix path issue change version script Author: Tijo Thomas Closes #5656 from tijoparacka/FIX_PATHISSUE_CHANGE_VERSION_SCRIPT and squashes the following commits: ab4f4b1 [Tijo Thomas] removed whitespace 24478c9 [Tijo Thomas] modified to provide the spark base dir while searching for pom and also while changing the vesrion no 7b8e10b [Tijo Thomas] Modified for providing the base directories while finding the list of pom files and also while changing the version no --- dev/change-version-to-2.10.sh | 6 +++--- dev/change-version-to-2.11.sh | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/dev/change-version-to-2.10.sh b/dev/change-version-to-2.10.sh index 15e0c73b4295e..c4adb1f96b7d3 100755 --- a/dev/change-version-to-2.10.sh +++ b/dev/change-version-to-2.10.sh @@ -18,9 +18,9 @@ # # Note that this will not necessarily work as intended with non-GNU sed (e.g. OS X) - -find . -name 'pom.xml' | grep -v target \ +BASEDIR=$(dirname $0)/.. +find $BASEDIR -name 'pom.xml' | grep -v target \ | xargs -I {} sed -i -e 's/\(artifactId.*\)_2.11/\1_2.10/g' {} # Also update in parent POM -sed -i -e '0,/2.112.102.112.10 in parent POM -sed -i -e '0,/2.102.112.102.11 Date: Thu, 23 Apr 2015 14:46:54 -0700 Subject: [PATCH 050/110] [SPARK-7070] [MLLIB] LDA.setBeta should call setTopicConcentration. 
jkbradley Author: Xiangrui Meng Closes #5649 from mengxr/SPARK-7070 and squashes the following commits: c66023c [Xiangrui Meng] setBeta should call setTopicConcentration --- .../scala/org/apache/spark/mllib/clustering/LDA.scala | 2 +- .../org/apache/spark/mllib/clustering/LDASuite.scala | 8 ++++++++ 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala index 9d63a08e211bc..d006b39acb213 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala @@ -177,7 +177,7 @@ class LDA private ( def getBeta: Double = getTopicConcentration /** Alias for [[setTopicConcentration()]] */ - def setBeta(beta: Double): this.type = setBeta(beta) + def setBeta(beta: Double): this.type = setTopicConcentration(beta) /** * Maximum number of iterations for learning. diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala index 15de10fd13a19..cc747dabb9968 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala @@ -123,6 +123,14 @@ class LDASuite extends FunSuite with MLlibTestSparkContext { assert(termVertexIds.map(i => LDA.index2term(i.toLong)) === termIds) assert(termVertexIds.forall(i => LDA.isTermVertex((i.toLong, 0)))) } + + test("setter alias") { + val lda = new LDA().setAlpha(2.0).setBeta(3.0) + assert(lda.getAlpha === 2.0) + assert(lda.getDocConcentration === 2.0) + assert(lda.getBeta === 3.0) + assert(lda.getTopicConcentration === 3.0) + } } private[clustering] object LDASuite { From 6220d933e5ce4ba890f5d6a50a69b95d319dafb4 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Thu, 23 Apr 2015 14:48:19 -0700 Subject: [PATCH 051/110] [SQL] Break dataTypes.scala into multiple files. It was over 1000 lines of code, making it harder to find all the types. Only moved code around, and didn't change any. Author: Reynold Xin Closes #5670 from rxin/break-types and squashes the following commits: 8c59023 [Reynold Xin] Check in missing files. dcd5193 [Reynold Xin] [SQL] Break dataTypes.scala into multiple files. 
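Because the split keeps every type in the org.apache.spark.sql.types package, caller code compiles unchanged; a minimal sketch of typical usage that is unaffected by the file reorganization (illustrative only, not taken from the patch):

    import org.apache.spark.sql.types._

    object DataTypeSketch {
      // Pattern-match on the same public types; only their defining files moved.
      def describe(dt: DataType): String = dt match {
        case IntegerType        => "32-bit integer"
        case StringType         => "string"
        case ArrayType(elem, _) => s"array<${describe(elem)}>"
        case StructType(fields) =>
          fields.map(f => s"${f.name}: ${describe(f.dataType)}").mkString("struct<", ", ", ">")
        case other              => other.typeName
      }

      def main(args: Array[String]): Unit = {
        val schema = StructType(Seq(
          StructField("id", IntegerType, nullable = false),
          StructField("tags", ArrayType(StringType))))
        println(describe(schema))  // struct<id: 32-bit integer, tags: array<string>>
      }
    }
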
--- .../apache/spark/sql/types/ArrayType.scala | 74 + .../apache/spark/sql/types/BinaryType.scala | 63 + .../apache/spark/sql/types/BooleanType.scala | 51 + .../org/apache/spark/sql/types/ByteType.scala | 54 + .../org/apache/spark/sql/types/DataType.scala | 353 +++++ .../org/apache/spark/sql/types/DateType.scala | 54 + .../apache/spark/sql/types/DecimalType.scala | 110 ++ .../apache/spark/sql/types/DoubleType.scala | 53 + .../apache/spark/sql/types/FloatType.scala | 53 + .../apache/spark/sql/types/IntegerType.scala | 54 + .../org/apache/spark/sql/types/LongType.scala | 54 + .../org/apache/spark/sql/types/MapType.scala | 79 ++ .../org/apache/spark/sql/types/NullType.scala | 39 + .../apache/spark/sql/types/ShortType.scala | 53 + .../apache/spark/sql/types/StringType.scala | 50 + .../apache/spark/sql/types/StructField.scala | 54 + .../apache/spark/sql/types/StructType.scala | 263 ++++ .../spark/sql/types/TimestampType.scala | 57 + .../spark/sql/types/UserDefinedType.scala | 81 ++ .../apache/spark/sql/types/dataTypes.scala | 1224 ----------------- 20 files changed, 1649 insertions(+), 1224 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/types/BinaryType.scala create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/types/BooleanType.scala create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/types/ByteType.scala create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/types/DateType.scala create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/types/DoubleType.scala create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/types/FloatType.scala create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/types/IntegerType.scala create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/types/LongType.scala create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/types/NullType.scala create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/types/ShortType.scala create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/types/StringType.scala create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructField.scala create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/types/TimestampType.scala create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala delete mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala new file mode 100644 index 0000000000000..b116163faccad --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.types + +import org.json4s.JsonDSL._ + +import org.apache.spark.annotation.DeveloperApi + + +object ArrayType { + /** Construct a [[ArrayType]] object with the given element type. The `containsNull` is true. */ + def apply(elementType: DataType): ArrayType = ArrayType(elementType, containsNull = true) +} + + +/** + * :: DeveloperApi :: + * The data type for collections of multiple values. + * Internally these are represented as columns that contain a ``scala.collection.Seq``. + * + * Please use [[DataTypes.createArrayType()]] to create a specific instance. + * + * An [[ArrayType]] object comprises two fields, `elementType: [[DataType]]` and + * `containsNull: Boolean`. The field of `elementType` is used to specify the type of + * array elements. The field of `containsNull` is used to specify if the array has `null` values. + * + * @param elementType The data type of values. + * @param containsNull Indicates if values have `null` values + * + * @group dataType + */ +@DeveloperApi +case class ArrayType(elementType: DataType, containsNull: Boolean) extends DataType { + + /** No-arg constructor for kryo. */ + protected def this() = this(null, false) + + private[sql] def buildFormattedString(prefix: String, builder: StringBuilder): Unit = { + builder.append( + s"$prefix-- element: ${elementType.typeName} (containsNull = $containsNull)\n") + DataType.buildFormattedString(elementType, s"$prefix |", builder) + } + + override private[sql] def jsonValue = + ("type" -> typeName) ~ + ("elementType" -> elementType.jsonValue) ~ + ("containsNull" -> containsNull) + + /** + * The default size of a value of the ArrayType is 100 * the default size of the element type. + * (We assume that there are 100 elements). + */ + override def defaultSize: Int = 100 * elementType.defaultSize + + override def simpleString: String = s"array<${elementType.simpleString}>" + + private[spark] override def asNullable: ArrayType = + ArrayType(elementType.asNullable, containsNull = true) +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/BinaryType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/BinaryType.scala new file mode 100644 index 0000000000000..a581a9e9468ef --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/BinaryType.scala @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.types + +import scala.math.Ordering +import scala.reflect.runtime.universe.typeTag + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.sql.catalyst.ScalaReflectionLock + + +/** + * :: DeveloperApi :: + * The data type representing `Array[Byte]` values. + * Please use the singleton [[DataTypes.BinaryType]]. + * + * @group dataType + */ +@DeveloperApi +class BinaryType private() extends AtomicType { + // The companion object and this class is separated so the companion object also subclasses + // this type. Otherwise, the companion object would be of type "BinaryType$" in byte code. + // Defined with a private constructor so the companion object is the only possible instantiation. + + private[sql] type InternalType = Array[Byte] + + @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[InternalType] } + + private[sql] val ordering = new Ordering[InternalType] { + def compare(x: Array[Byte], y: Array[Byte]): Int = { + for (i <- 0 until x.length; if i < y.length) { + val res = x(i).compareTo(y(i)) + if (res != 0) return res + } + x.length - y.length + } + } + + /** + * The default size of a value of the BinaryType is 4096 bytes. + */ + override def defaultSize: Int = 4096 + + private[spark] override def asNullable: BinaryType = this +} + + +case object BinaryType extends BinaryType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/BooleanType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/BooleanType.scala new file mode 100644 index 0000000000000..a7f228cefa57a --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/BooleanType.scala @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.types + +import scala.math.Ordering +import scala.reflect.runtime.universe.typeTag + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.sql.catalyst.ScalaReflectionLock + + +/** + * :: DeveloperApi :: + * The data type representing `Boolean` values. Please use the singleton [[DataTypes.BooleanType]]. + * + *@group dataType + */ +@DeveloperApi +class BooleanType private() extends AtomicType { + // The companion object and this class is separated so the companion object also subclasses + // this type. 
Otherwise, the companion object would be of type "BooleanType$" in byte code. + // Defined with a private constructor so the companion object is the only possible instantiation. + private[sql] type InternalType = Boolean + @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[InternalType] } + private[sql] val ordering = implicitly[Ordering[InternalType]] + + /** + * The default size of a value of the BooleanType is 1 byte. + */ + override def defaultSize: Int = 1 + + private[spark] override def asNullable: BooleanType = this +} + + +case object BooleanType extends BooleanType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ByteType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ByteType.scala new file mode 100644 index 0000000000000..4d8685796ec76 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ByteType.scala @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.types + +import scala.math.{Ordering, Integral, Numeric} +import scala.reflect.runtime.universe.typeTag + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.sql.catalyst.ScalaReflectionLock + + +/** + * :: DeveloperApi :: + * The data type representing `Byte` values. Please use the singleton [[DataTypes.ByteType]]. + * + * @group dataType + */ +@DeveloperApi +class ByteType private() extends IntegralType { + // The companion object and this class is separated so the companion object also subclasses + // this type. Otherwise, the companion object would be of type "ByteType$" in byte code. + // Defined with a private constructor so the companion object is the only possible instantiation. + private[sql] type InternalType = Byte + @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[InternalType] } + private[sql] val numeric = implicitly[Numeric[Byte]] + private[sql] val integral = implicitly[Integral[Byte]] + private[sql] val ordering = implicitly[Ordering[InternalType]] + + /** + * The default size of a value of the ByteType is 1 byte. + */ + override def defaultSize: Int = 1 + + override def simpleString: String = "tinyint" + + private[spark] override def asNullable: ByteType = this +} + +case object ByteType extends ByteType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala new file mode 100644 index 0000000000000..e6bfcd9adfeb1 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala @@ -0,0 +1,353 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.types + +import scala.reflect.ClassTag +import scala.reflect.runtime.universe.{TypeTag, runtimeMirror} +import scala.util.parsing.combinator.RegexParsers + +import org.json4s._ +import org.json4s.JsonAST.JValue +import org.json4s.JsonDSL._ +import org.json4s.jackson.JsonMethods._ + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.sql.catalyst.ScalaReflectionLock +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.util.Utils + + +/** + * :: DeveloperApi :: + * The base type of all Spark SQL data types. + * + * @group dataType + */ +@DeveloperApi +abstract class DataType { + /** Matches any expression that evaluates to this DataType */ + def unapply(a: Expression): Boolean = a match { + case e: Expression if e.dataType == this => true + case _ => false + } + + /** The default size of a value of this data type. */ + def defaultSize: Int + + def typeName: String = this.getClass.getSimpleName.stripSuffix("$").dropRight(4).toLowerCase + + private[sql] def jsonValue: JValue = typeName + + def json: String = compact(render(jsonValue)) + + def prettyJson: String = pretty(render(jsonValue)) + + def simpleString: String = typeName + + /** Check if `this` and `other` are the same data type when ignoring nullability + * (`StructField.nullable`, `ArrayType.containsNull`, and `MapType.valueContainsNull`). + */ + private[spark] def sameType(other: DataType): Boolean = + DataType.equalsIgnoreNullability(this, other) + + /** Returns the same data type but set all nullability fields are true + * (`StructField.nullable`, `ArrayType.containsNull`, and `MapType.valueContainsNull`). + */ + private[spark] def asNullable: DataType +} + + +/** + * An internal type used to represent everything that is not null, UDTs, arrays, structs, and maps. + */ +protected[sql] abstract class AtomicType extends DataType { + private[sql] type InternalType + @transient private[sql] val tag: TypeTag[InternalType] + private[sql] val ordering: Ordering[InternalType] + + @transient private[sql] val classTag = ScalaReflectionLock.synchronized { + val mirror = runtimeMirror(Utils.getSparkClassLoader) + ClassTag[InternalType](mirror.runtimeClass(tag.tpe)) + } +} + + +/** + * :: DeveloperApi :: + * Numeric data types. + * + * @group dataType + */ +abstract class NumericType extends AtomicType { + // Unfortunately we can't get this implicitly as that breaks Spark Serialization. In order for + // implicitly[Numeric[JvmType]] to be valid, we have to change JvmType from a type variable to a + // type parameter and and add a numeric annotation (i.e., [JvmType : Numeric]). This gets + // desugared by the compiler into an argument to the objects constructor. 
This means there is no + // longer an no argument constructor and thus the JVM cannot serialize the object anymore. + private[sql] val numeric: Numeric[InternalType] +} + + +private[sql] object NumericType { + def unapply(e: Expression): Boolean = e.dataType.isInstanceOf[NumericType] +} + + +/** Matcher for any expressions that evaluate to [[IntegralType]]s */ +private[sql] object IntegralType { + def unapply(a: Expression): Boolean = a match { + case e: Expression if e.dataType.isInstanceOf[IntegralType] => true + case _ => false + } +} + + +private[sql] abstract class IntegralType extends NumericType { + private[sql] val integral: Integral[InternalType] +} + + + +/** Matcher for any expressions that evaluate to [[FractionalType]]s */ +private[sql] object FractionalType { + def unapply(a: Expression): Boolean = a match { + case e: Expression if e.dataType.isInstanceOf[FractionalType] => true + case _ => false + } +} + + +private[sql] abstract class FractionalType extends NumericType { + private[sql] val fractional: Fractional[InternalType] + private[sql] val asIntegral: Integral[InternalType] +} + + +object DataType { + + def fromJson(json: String): DataType = parseDataType(parse(json)) + + @deprecated("Use DataType.fromJson instead", "1.2.0") + def fromCaseClassString(string: String): DataType = CaseClassStringParser(string) + + private val nonDecimalNameToType = { + Seq(NullType, DateType, TimestampType, BinaryType, + IntegerType, BooleanType, LongType, DoubleType, FloatType, ShortType, ByteType, StringType) + .map(t => t.typeName -> t).toMap + } + + /** Given the string representation of a type, return its DataType */ + private def nameToType(name: String): DataType = { + val FIXED_DECIMAL = """decimal\(\s*(\d+)\s*,\s*(\d+)\s*\)""".r + name match { + case "decimal" => DecimalType.Unlimited + case FIXED_DECIMAL(precision, scale) => DecimalType(precision.toInt, scale.toInt) + case other => nonDecimalNameToType(other) + } + } + + private object JSortedObject { + def unapplySeq(value: JValue): Option[List[(String, JValue)]] = value match { + case JObject(seq) => Some(seq.toList.sortBy(_._1)) + case _ => None + } + } + + // NOTE: Map fields must be sorted in alphabetical order to keep consistent with the Python side. + private def parseDataType(json: JValue): DataType = json match { + case JString(name) => + nameToType(name) + + case JSortedObject( + ("containsNull", JBool(n)), + ("elementType", t: JValue), + ("type", JString("array"))) => + ArrayType(parseDataType(t), n) + + case JSortedObject( + ("keyType", k: JValue), + ("type", JString("map")), + ("valueContainsNull", JBool(n)), + ("valueType", v: JValue)) => + MapType(parseDataType(k), parseDataType(v), n) + + case JSortedObject( + ("fields", JArray(fields)), + ("type", JString("struct"))) => + StructType(fields.map(parseStructField)) + + case JSortedObject( + ("class", JString(udtClass)), + ("pyClass", _), + ("sqlType", _), + ("type", JString("udt"))) => + Class.forName(udtClass).newInstance().asInstanceOf[UserDefinedType[_]] + } + + private def parseStructField(json: JValue): StructField = json match { + case JSortedObject( + ("metadata", metadata: JObject), + ("name", JString(name)), + ("nullable", JBool(nullable)), + ("type", dataType: JValue)) => + StructField(name, parseDataType(dataType), nullable, Metadata.fromJObject(metadata)) + // Support reading schema when 'metadata' is missing. 
+ case JSortedObject( + ("name", JString(name)), + ("nullable", JBool(nullable)), + ("type", dataType: JValue)) => + StructField(name, parseDataType(dataType), nullable) + } + + private object CaseClassStringParser extends RegexParsers { + protected lazy val primitiveType: Parser[DataType] = + ( "StringType" ^^^ StringType + | "FloatType" ^^^ FloatType + | "IntegerType" ^^^ IntegerType + | "ByteType" ^^^ ByteType + | "ShortType" ^^^ ShortType + | "DoubleType" ^^^ DoubleType + | "LongType" ^^^ LongType + | "BinaryType" ^^^ BinaryType + | "BooleanType" ^^^ BooleanType + | "DateType" ^^^ DateType + | "DecimalType()" ^^^ DecimalType.Unlimited + | fixedDecimalType + | "TimestampType" ^^^ TimestampType + ) + + protected lazy val fixedDecimalType: Parser[DataType] = + ("DecimalType(" ~> "[0-9]+".r) ~ ("," ~> "[0-9]+".r <~ ")") ^^ { + case precision ~ scale => DecimalType(precision.toInt, scale.toInt) + } + + protected lazy val arrayType: Parser[DataType] = + "ArrayType" ~> "(" ~> dataType ~ "," ~ boolVal <~ ")" ^^ { + case tpe ~ _ ~ containsNull => ArrayType(tpe, containsNull) + } + + protected lazy val mapType: Parser[DataType] = + "MapType" ~> "(" ~> dataType ~ "," ~ dataType ~ "," ~ boolVal <~ ")" ^^ { + case t1 ~ _ ~ t2 ~ _ ~ valueContainsNull => MapType(t1, t2, valueContainsNull) + } + + protected lazy val structField: Parser[StructField] = + ("StructField(" ~> "[a-zA-Z0-9_]*".r) ~ ("," ~> dataType) ~ ("," ~> boolVal <~ ")") ^^ { + case name ~ tpe ~ nullable => + StructField(name, tpe, nullable = nullable) + } + + protected lazy val boolVal: Parser[Boolean] = + ( "true" ^^^ true + | "false" ^^^ false + ) + + protected lazy val structType: Parser[DataType] = + "StructType\\([A-zA-z]*\\(".r ~> repsep(structField, ",") <~ "))" ^^ { + case fields => StructType(fields) + } + + protected lazy val dataType: Parser[DataType] = + ( arrayType + | mapType + | structType + | primitiveType + ) + + /** + * Parses a string representation of a DataType. + * + * TODO: Generate parser as pickler... + */ + def apply(asString: String): DataType = parseAll(dataType, asString) match { + case Success(result, _) => result + case failure: NoSuccess => + throw new IllegalArgumentException(s"Unsupported dataType: $asString, $failure") + } + } + + protected[types] def buildFormattedString( + dataType: DataType, + prefix: String, + builder: StringBuilder): Unit = { + dataType match { + case array: ArrayType => + array.buildFormattedString(prefix, builder) + case struct: StructType => + struct.buildFormattedString(prefix, builder) + case map: MapType => + map.buildFormattedString(prefix, builder) + case _ => + } + } + + /** + * Compares two types, ignoring nullability of ArrayType, MapType, StructType. + */ + private[types] def equalsIgnoreNullability(left: DataType, right: DataType): Boolean = { + (left, right) match { + case (ArrayType(leftElementType, _), ArrayType(rightElementType, _)) => + equalsIgnoreNullability(leftElementType, rightElementType) + case (MapType(leftKeyType, leftValueType, _), MapType(rightKeyType, rightValueType, _)) => + equalsIgnoreNullability(leftKeyType, rightKeyType) && + equalsIgnoreNullability(leftValueType, rightValueType) + case (StructType(leftFields), StructType(rightFields)) => + leftFields.length == rightFields.length && + leftFields.zip(rightFields).forall { case (l, r) => + l.name == r.name && equalsIgnoreNullability(l.dataType, r.dataType) + } + case (l, r) => l == r + } + } + + /** + * Compares two types, ignoring compatible nullability of ArrayType, MapType, StructType. 
+ * + * Compatible nullability is defined as follows: + * - If `from` and `to` are ArrayTypes, `from` has a compatible nullability with `to` + * if and only if `to.containsNull` is true, or both of `from.containsNull` and + * `to.containsNull` are false. + * - If `from` and `to` are MapTypes, `from` has a compatible nullability with `to` + * if and only if `to.valueContainsNull` is true, or both of `from.valueContainsNull` and + * `to.valueContainsNull` are false. + * - If `from` and `to` are StructTypes, `from` has a compatible nullability with `to` + * if and only if for all every pair of fields, `to.nullable` is true, or both + * of `fromField.nullable` and `toField.nullable` are false. + */ + private[sql] def equalsIgnoreCompatibleNullability(from: DataType, to: DataType): Boolean = { + (from, to) match { + case (ArrayType(fromElement, fn), ArrayType(toElement, tn)) => + (tn || !fn) && equalsIgnoreCompatibleNullability(fromElement, toElement) + + case (MapType(fromKey, fromValue, fn), MapType(toKey, toValue, tn)) => + (tn || !fn) && + equalsIgnoreCompatibleNullability(fromKey, toKey) && + equalsIgnoreCompatibleNullability(fromValue, toValue) + + case (StructType(fromFields), StructType(toFields)) => + fromFields.length == toFields.length && + fromFields.zip(toFields).forall { case (fromField, toField) => + fromField.name == toField.name && + (toField.nullable || !fromField.nullable) && + equalsIgnoreCompatibleNullability(fromField.dataType, toField.dataType) + } + + case (fromDataType, toDataType) => fromDataType == toDataType + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DateType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DateType.scala new file mode 100644 index 0000000000000..03f0644bc784c --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DateType.scala @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.types + +import scala.math.Ordering +import scala.reflect.runtime.universe.typeTag + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.sql.catalyst.ScalaReflectionLock + + +/** + * :: DeveloperApi :: + * The data type representing `java.sql.Date` values. + * Please use the singleton [[DataTypes.DateType]]. + * + * @group dataType + */ +@DeveloperApi +class DateType private() extends AtomicType { + // The companion object and this class is separated so the companion object also subclasses + // this type. Otherwise, the companion object would be of type "DateType$" in byte code. + // Defined with a private constructor so the companion object is the only possible instantiation. 
+ private[sql] type InternalType = Int + + @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[InternalType] } + + private[sql] val ordering = implicitly[Ordering[InternalType]] + + /** + * The default size of a value of the DateType is 4 bytes. + */ + override def defaultSize: Int = 4 + + private[spark] override def asNullable: DateType = this +} + + +case object DateType extends DateType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala new file mode 100644 index 0000000000000..0f8cecd28f7df --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala @@ -0,0 +1,110 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.types + +import scala.reflect.runtime.universe.typeTag + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.sql.catalyst.ScalaReflectionLock +import org.apache.spark.sql.catalyst.expressions.Expression + + +/** Precision parameters for a Decimal */ +case class PrecisionInfo(precision: Int, scale: Int) + + +/** + * :: DeveloperApi :: + * The data type representing `java.math.BigDecimal` values. + * A Decimal that might have fixed precision and scale, or unlimited values for these. + * + * Please use [[DataTypes.createDecimalType()]] to create a specific instance. + * + * @group dataType + */ +@DeveloperApi +case class DecimalType(precisionInfo: Option[PrecisionInfo]) extends FractionalType { + + /** No-arg constructor for kryo. */ + protected def this() = this(null) + + private[sql] type InternalType = Decimal + @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[InternalType] } + private[sql] val numeric = Decimal.DecimalIsFractional + private[sql] val fractional = Decimal.DecimalIsFractional + private[sql] val ordering = Decimal.DecimalIsFractional + private[sql] val asIntegral = Decimal.DecimalAsIfIntegral + + def precision: Int = precisionInfo.map(_.precision).getOrElse(-1) + + def scale: Int = precisionInfo.map(_.scale).getOrElse(-1) + + override def typeName: String = precisionInfo match { + case Some(PrecisionInfo(precision, scale)) => s"decimal($precision,$scale)" + case None => "decimal" + } + + override def toString: String = precisionInfo match { + case Some(PrecisionInfo(precision, scale)) => s"DecimalType($precision,$scale)" + case None => "DecimalType()" + } + + /** + * The default size of a value of the DecimalType is 4096 bytes. 
+ */ + override def defaultSize: Int = 4096 + + override def simpleString: String = precisionInfo match { + case Some(PrecisionInfo(precision, scale)) => s"decimal($precision,$scale)" + case None => "decimal(10,0)" + } + + private[spark] override def asNullable: DecimalType = this +} + + +/** Extra factory methods and pattern matchers for Decimals */ +object DecimalType { + val Unlimited: DecimalType = DecimalType(None) + + object Fixed { + def unapply(t: DecimalType): Option[(Int, Int)] = + t.precisionInfo.map(p => (p.precision, p.scale)) + } + + object Expression { + def unapply(e: Expression): Option[(Int, Int)] = e.dataType match { + case t: DecimalType => t.precisionInfo.map(p => (p.precision, p.scale)) + case _ => None + } + } + + def apply(): DecimalType = Unlimited + + def apply(precision: Int, scale: Int): DecimalType = + DecimalType(Some(PrecisionInfo(precision, scale))) + + def unapply(t: DataType): Boolean = t.isInstanceOf[DecimalType] + + def unapply(e: Expression): Boolean = e.dataType.isInstanceOf[DecimalType] + + def isFixed(dataType: DataType): Boolean = dataType match { + case DecimalType.Fixed(_, _) => true + case _ => false + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DoubleType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DoubleType.scala new file mode 100644 index 0000000000000..66766623213c9 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DoubleType.scala @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.types + +import scala.math.{Ordering, Fractional, Numeric} +import scala.math.Numeric.DoubleAsIfIntegral +import scala.reflect.runtime.universe.typeTag + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.sql.catalyst.ScalaReflectionLock + +/** + * :: DeveloperApi :: + * The data type representing `Double` values. Please use the singleton [[DataTypes.DoubleType]]. + * + * @group dataType + */ +@DeveloperApi +class DoubleType private() extends FractionalType { + // The companion object and this class is separated so the companion object also subclasses + // this type. Otherwise, the companion object would be of type "DoubleType$" in byte code. + // Defined with a private constructor so the companion object is the only possible instantiation. 
+ private[sql] type InternalType = Double + @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[InternalType] } + private[sql] val numeric = implicitly[Numeric[Double]] + private[sql] val fractional = implicitly[Fractional[Double]] + private[sql] val ordering = implicitly[Ordering[InternalType]] + private[sql] val asIntegral = DoubleAsIfIntegral + + /** + * The default size of a value of the DoubleType is 8 bytes. + */ + override def defaultSize: Int = 8 + + private[spark] override def asNullable: DoubleType = this +} + +case object DoubleType extends DoubleType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/FloatType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/FloatType.scala new file mode 100644 index 0000000000000..1d5a2f4f6f86c --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/FloatType.scala @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.types + +import scala.math.Numeric.FloatAsIfIntegral +import scala.math.{Ordering, Fractional, Numeric} +import scala.reflect.runtime.universe.typeTag + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.sql.catalyst.ScalaReflectionLock + +/** + * :: DeveloperApi :: + * The data type representing `Float` values. Please use the singleton [[DataTypes.FloatType]]. + * + * @group dataType + */ +@DeveloperApi +class FloatType private() extends FractionalType { + // The companion object and this class is separated so the companion object also subclasses + // this type. Otherwise, the companion object would be of type "FloatType$" in byte code. + // Defined with a private constructor so the companion object is the only possible instantiation. + private[sql] type InternalType = Float + @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[InternalType] } + private[sql] val numeric = implicitly[Numeric[Float]] + private[sql] val fractional = implicitly[Fractional[Float]] + private[sql] val ordering = implicitly[Ordering[InternalType]] + private[sql] val asIntegral = FloatAsIfIntegral + + /** + * The default size of a value of the FloatType is 4 bytes. 
+ */ + override def defaultSize: Int = 4 + + private[spark] override def asNullable: FloatType = this +} + +case object FloatType extends FloatType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/IntegerType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/IntegerType.scala new file mode 100644 index 0000000000000..74e464c082873 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/IntegerType.scala @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.types + +import scala.math.{Ordering, Integral, Numeric} +import scala.reflect.runtime.universe.typeTag + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.sql.catalyst.ScalaReflectionLock + + +/** + * :: DeveloperApi :: + * The data type representing `Int` values. Please use the singleton [[DataTypes.IntegerType]]. + * + * @group dataType + */ +@DeveloperApi +class IntegerType private() extends IntegralType { + // The companion object and this class is separated so the companion object also subclasses + // this type. Otherwise, the companion object would be of type "IntegerType$" in byte code. + // Defined with a private constructor so the companion object is the only possible instantiation. + private[sql] type InternalType = Int + @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[InternalType] } + private[sql] val numeric = implicitly[Numeric[Int]] + private[sql] val integral = implicitly[Integral[Int]] + private[sql] val ordering = implicitly[Ordering[InternalType]] + + /** + * The default size of a value of the IntegerType is 4 bytes. + */ + override def defaultSize: Int = 4 + + override def simpleString: String = "int" + + private[spark] override def asNullable: IntegerType = this +} + +case object IntegerType extends IntegerType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/LongType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/LongType.scala new file mode 100644 index 0000000000000..390675782e5fd --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/LongType.scala @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.types + +import scala.math.{Ordering, Integral, Numeric} +import scala.reflect.runtime.universe.typeTag + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.sql.catalyst.ScalaReflectionLock + +/** + * :: DeveloperApi :: + * The data type representing `Long` values. Please use the singleton [[DataTypes.LongType]]. + * + * @group dataType + */ +@DeveloperApi +class LongType private() extends IntegralType { + // The companion object and this class is separated so the companion object also subclasses + // this type. Otherwise, the companion object would be of type "LongType$" in byte code. + // Defined with a private constructor so the companion object is the only possible instantiation. + private[sql] type InternalType = Long + @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[InternalType] } + private[sql] val numeric = implicitly[Numeric[Long]] + private[sql] val integral = implicitly[Integral[Long]] + private[sql] val ordering = implicitly[Ordering[InternalType]] + + /** + * The default size of a value of the LongType is 8 bytes. + */ + override def defaultSize: Int = 8 + + override def simpleString: String = "bigint" + + private[spark] override def asNullable: LongType = this +} + + +case object LongType extends LongType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala new file mode 100644 index 0000000000000..cfdf493074415 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala @@ -0,0 +1,79 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.types + +import org.json4s.JsonAST.JValue +import org.json4s.JsonDSL._ + + +/** + * :: DeveloperApi :: + * The data type for Maps. Keys in a map are not allowed to have `null` values. + * + * Please use [[DataTypes.createMapType()]] to create a specific instance. + * + * @param keyType The data type of map keys. + * @param valueType The data type of map values. + * @param valueContainsNull Indicates if map values have `null` values. + * + * @group dataType + */ +case class MapType( + keyType: DataType, + valueType: DataType, + valueContainsNull: Boolean) extends DataType { + + /** No-arg constructor for kryo. 
*/ + def this() = this(null, null, false) + + private[sql] def buildFormattedString(prefix: String, builder: StringBuilder): Unit = { + builder.append(s"$prefix-- key: ${keyType.typeName}\n") + builder.append(s"$prefix-- value: ${valueType.typeName} " + + s"(valueContainsNull = $valueContainsNull)\n") + DataType.buildFormattedString(keyType, s"$prefix |", builder) + DataType.buildFormattedString(valueType, s"$prefix |", builder) + } + + override private[sql] def jsonValue: JValue = + ("type" -> typeName) ~ + ("keyType" -> keyType.jsonValue) ~ + ("valueType" -> valueType.jsonValue) ~ + ("valueContainsNull" -> valueContainsNull) + + /** + * The default size of a value of the MapType is + * 100 * (the default size of the key type + the default size of the value type). + * (We assume that there are 100 elements). + */ + override def defaultSize: Int = 100 * (keyType.defaultSize + valueType.defaultSize) + + override def simpleString: String = s"map<${keyType.simpleString},${valueType.simpleString}>" + + private[spark] override def asNullable: MapType = + MapType(keyType.asNullable, valueType.asNullable, valueContainsNull = true) +} + + +object MapType { + /** + * Construct a [[MapType]] object with the given key type and value type. + * The `valueContainsNull` is true. + */ + def apply(keyType: DataType, valueType: DataType): MapType = + MapType(keyType: DataType, valueType: DataType, valueContainsNull = true) +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/NullType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/NullType.scala new file mode 100644 index 0000000000000..b64b07431fa96 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/NullType.scala @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.types + +import org.apache.spark.annotation.DeveloperApi + + +/** + * :: DeveloperApi :: + * The data type representing `NULL` values. Please use the singleton [[DataTypes.NullType]]. + * + * @group dataType + */ +@DeveloperApi +class NullType private() extends DataType { + // The companion object and this class is separated so the companion object also subclasses + // this type. Otherwise, the companion object would be of type "NullType$" in byte code. + // Defined with a private constructor so the companion object is the only possible instantiation. 
+ override def defaultSize: Int = 1 + + private[spark] override def asNullable: NullType = this +} + +case object NullType extends NullType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ShortType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ShortType.scala new file mode 100644 index 0000000000000..73e9ec780b0af --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ShortType.scala @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.types + +import scala.math.{Ordering, Integral, Numeric} +import scala.reflect.runtime.universe.typeTag + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.sql.catalyst.ScalaReflectionLock + +/** + * :: DeveloperApi :: + * The data type representing `Short` values. Please use the singleton [[DataTypes.ShortType]]. + * + * @group dataType + */ +@DeveloperApi +class ShortType private() extends IntegralType { + // The companion object and this class is separated so the companion object also subclasses + // this type. Otherwise, the companion object would be of type "ShortType$" in byte code. + // Defined with a private constructor so the companion object is the only possible instantiation. + private[sql] type InternalType = Short + @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[InternalType] } + private[sql] val numeric = implicitly[Numeric[Short]] + private[sql] val integral = implicitly[Integral[Short]] + private[sql] val ordering = implicitly[Ordering[InternalType]] + + /** + * The default size of a value of the ShortType is 2 bytes. + */ + override def defaultSize: Int = 2 + + override def simpleString: String = "smallint" + + private[spark] override def asNullable: ShortType = this +} + +case object ShortType extends ShortType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StringType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StringType.scala new file mode 100644 index 0000000000000..134ab0af4e0de --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StringType.scala @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.types + +import scala.math.Ordering +import scala.reflect.runtime.universe.typeTag + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.sql.catalyst.ScalaReflectionLock + +/** + * :: DeveloperApi :: + * The data type representing `String` values. Please use the singleton [[DataTypes.StringType]]. + * + * @group dataType + */ +@DeveloperApi +class StringType private() extends AtomicType { + // The companion object and this class is separated so the companion object also subclasses + // this type. Otherwise, the companion object would be of type "StringType$" in byte code. + // Defined with a private constructor so the companion object is the only possible instantiation. + private[sql] type InternalType = UTF8String + @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[InternalType] } + private[sql] val ordering = implicitly[Ordering[InternalType]] + + /** + * The default size of a value of the StringType is 4096 bytes. + */ + override def defaultSize: Int = 4096 + + private[spark] override def asNullable: StringType = this +} + +case object StringType extends StringType + diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructField.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructField.scala new file mode 100644 index 0000000000000..83570a5eaee61 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructField.scala @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.types + +import org.json4s.JsonAST.JValue +import org.json4s.JsonDSL._ + +/** + * A field inside a StructType. + * @param name The name of this field. + * @param dataType The data type of this field. + * @param nullable Indicates if values of this field can be `null` values. + * @param metadata The metadata of this field. The metadata should be preserved during + * transformation if the content of the column is not modified, e.g, in selection. + */ +case class StructField( + name: String, + dataType: DataType, + nullable: Boolean = true, + metadata: Metadata = Metadata.empty) { + + /** No-arg constructor for kryo. 
*/ + protected def this() = this(null, null) + + private[sql] def buildFormattedString(prefix: String, builder: StringBuilder): Unit = { + builder.append(s"$prefix-- $name: ${dataType.typeName} (nullable = $nullable)\n") + DataType.buildFormattedString(dataType, s"$prefix |", builder) + } + + // override the default toString to be compatible with legacy parquet files. + override def toString: String = s"StructField($name,$dataType,$nullable)" + + private[sql] def jsonValue: JValue = { + ("name" -> name) ~ + ("type" -> dataType.jsonValue) ~ + ("nullable" -> nullable) ~ + ("metadata" -> metadata.jsonValue) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala new file mode 100644 index 0000000000000..d80ffca18ec9a --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala @@ -0,0 +1,263 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.types + +import scala.collection.mutable.ArrayBuffer +import scala.math.max + +import org.json4s.JsonDSL._ + +import org.apache.spark.SparkException +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Attribute} + + +/** + * :: DeveloperApi :: + * A [[StructType]] object can be constructed by + * {{{ + * StructType(fields: Seq[StructField]) + * }}} + * For a [[StructType]] object, one or multiple [[StructField]]s can be extracted by names. + * If multiple [[StructField]]s are extracted, a [[StructType]] object will be returned. + * If a provided name does not have a matching field, it will be ignored. For the case + * of extracting a single StructField, a `null` will be returned. + * Example: + * {{{ + * import org.apache.spark.sql._ + * + * val struct = + * StructType( + * StructField("a", IntegerType, true) :: + * StructField("b", LongType, false) :: + * StructField("c", BooleanType, false) :: Nil) + * + * // Extract a single StructField. + * val singleField = struct("b") + * // singleField: StructField = StructField(b,LongType,false) + * + * // This struct does not have a field called "d". null will be returned. + * val nonExisting = struct("d") + * // nonExisting: StructField = null + * + * // Extract multiple StructFields. Field names are provided in a set. + * // A StructType object will be returned. + * val twoFields = struct(Set("b", "c")) + * // twoFields: StructType = + * // StructType(List(StructField(b,LongType,false), StructField(c,BooleanType,false))) + * + * // Any names without matching fields will be ignored. + * // For the case shown below, "d" will be ignored and + * // it is treated as struct(Set("b", "c")). 
+ * val ignoreNonExisting = struct(Set("b", "c", "d")) + * // ignoreNonExisting: StructType = + * // StructType(List(StructField(b,LongType,false), StructField(c,BooleanType,false))) + * }}} + * + * A [[org.apache.spark.sql.Row]] object is used as a value of the StructType. + * Example: + * {{{ + * import org.apache.spark.sql._ + * + * val innerStruct = + * StructType( + * StructField("f1", IntegerType, true) :: + * StructField("f2", LongType, false) :: + * StructField("f3", BooleanType, false) :: Nil) + * + * val struct = StructType( + * StructField("a", innerStruct, true) :: Nil) + * + * // Create a Row with the schema defined by struct + * val row = Row(Row(1, 2, true)) + * // row: Row = [[1,2,true]] + * }}} + * + * @group dataType + */ +@DeveloperApi +case class StructType(fields: Array[StructField]) extends DataType with Seq[StructField] { + + /** No-arg constructor for kryo. */ + protected def this() = this(null) + + /** Returns all field names in an array. */ + def fieldNames: Array[String] = fields.map(_.name) + + private lazy val fieldNamesSet: Set[String] = fieldNames.toSet + private lazy val nameToField: Map[String, StructField] = fields.map(f => f.name -> f).toMap + private lazy val nameToIndex: Map[String, Int] = fieldNames.zipWithIndex.toMap + + /** + * Extracts a [[StructField]] of the given name. If the [[StructType]] object does not + * have a name matching the given name, `null` will be returned. + */ + def apply(name: String): StructField = { + nameToField.getOrElse(name, + throw new IllegalArgumentException(s"""Field "$name" does not exist.""")) + } + + /** + * Returns a [[StructType]] containing [[StructField]]s of the given names, preserving the + * original order of fields. Those names which do not have matching fields will be ignored. + */ + def apply(names: Set[String]): StructType = { + val nonExistFields = names -- fieldNamesSet + if (nonExistFields.nonEmpty) { + throw new IllegalArgumentException( + s"Field ${nonExistFields.mkString(",")} does not exist.") + } + // Preserve the original order of fields. + StructType(fields.filter(f => names.contains(f.name))) + } + + /** + * Returns index of a given field + */ + def fieldIndex(name: String): Int = { + nameToIndex.getOrElse(name, + throw new IllegalArgumentException(s"""Field "$name" does not exist.""")) + } + + protected[sql] def toAttributes: Seq[AttributeReference] = + map(f => AttributeReference(f.name, f.dataType, f.nullable, f.metadata)()) + + def treeString: String = { + val builder = new StringBuilder + builder.append("root\n") + val prefix = " |" + fields.foreach(field => field.buildFormattedString(prefix, builder)) + + builder.toString() + } + + def printTreeString(): Unit = println(treeString) + + private[sql] def buildFormattedString(prefix: String, builder: StringBuilder): Unit = { + fields.foreach(field => field.buildFormattedString(prefix, builder)) + } + + override private[sql] def jsonValue = + ("type" -> typeName) ~ + ("fields" -> map(_.jsonValue)) + + override def apply(fieldIndex: Int): StructField = fields(fieldIndex) + + override def length: Int = fields.length + + override def iterator: Iterator[StructField] = fields.iterator + + /** + * The default size of a value of the StructType is the total default sizes of all field types. 
+ */ + override def defaultSize: Int = fields.map(_.dataType.defaultSize).sum + + override def simpleString: String = { + val fieldTypes = fields.map(field => s"${field.name}:${field.dataType.simpleString}") + s"struct<${fieldTypes.mkString(",")}>" + } + + /** + * Merges with another schema (`StructType`). For a struct field A from `this` and a struct field + * B from `that`, + * + * 1. If A and B have the same name and data type, they are merged to a field C with the same name + * and data type. C is nullable if and only if either A or B is nullable. + * 2. If A doesn't exist in `that`, it's included in the result schema. + * 3. If B doesn't exist in `this`, it's also included in the result schema. + * 4. Otherwise, `this` and `that` are considered as conflicting schemas and an exception would be + * thrown. + */ + private[sql] def merge(that: StructType): StructType = + StructType.merge(this, that).asInstanceOf[StructType] + + private[spark] override def asNullable: StructType = { + val newFields = fields.map { + case StructField(name, dataType, nullable, metadata) => + StructField(name, dataType.asNullable, nullable = true, metadata) + } + + StructType(newFields) + } +} + + +object StructType { + + def apply(fields: Seq[StructField]): StructType = StructType(fields.toArray) + + def apply(fields: java.util.List[StructField]): StructType = { + StructType(fields.toArray.asInstanceOf[Array[StructField]]) + } + + protected[sql] def fromAttributes(attributes: Seq[Attribute]): StructType = + StructType(attributes.map(a => StructField(a.name, a.dataType, a.nullable, a.metadata))) + + private[sql] def merge(left: DataType, right: DataType): DataType = + (left, right) match { + case (ArrayType(leftElementType, leftContainsNull), + ArrayType(rightElementType, rightContainsNull)) => + ArrayType( + merge(leftElementType, rightElementType), + leftContainsNull || rightContainsNull) + + case (MapType(leftKeyType, leftValueType, leftContainsNull), + MapType(rightKeyType, rightValueType, rightContainsNull)) => + MapType( + merge(leftKeyType, rightKeyType), + merge(leftValueType, rightValueType), + leftContainsNull || rightContainsNull) + + case (StructType(leftFields), StructType(rightFields)) => + val newFields = ArrayBuffer.empty[StructField] + + leftFields.foreach { + case leftField @ StructField(leftName, leftType, leftNullable, _) => + rightFields + .find(_.name == leftName) + .map { case rightField @ StructField(_, rightType, rightNullable, _) => + leftField.copy( + dataType = merge(leftType, rightType), + nullable = leftNullable || rightNullable) + } + .orElse(Some(leftField)) + .foreach(newFields += _) + } + + rightFields + .filterNot(f => leftFields.map(_.name).contains(f.name)) + .foreach(newFields += _) + + StructType(newFields) + + case (DecimalType.Fixed(leftPrecision, leftScale), + DecimalType.Fixed(rightPrecision, rightScale)) => + DecimalType( + max(leftScale, rightScale) + max(leftPrecision - leftScale, rightPrecision - rightScale), + max(leftScale, rightScale)) + + case (leftUdt: UserDefinedType[_], rightUdt: UserDefinedType[_]) + if leftUdt.userClass == rightUdt.userClass => leftUdt + + case (leftType, rightType) if leftType == rightType => + leftType + + case _ => + throw new SparkException(s"Failed to merge incompatible data types $left and $right") + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/TimestampType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/TimestampType.scala new file mode 100644 index 0000000000000..aebabfc475925 --- 
/dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/TimestampType.scala @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.types + +import java.sql.Timestamp + +import scala.math.Ordering +import scala.reflect.runtime.universe.typeTag + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.sql.catalyst.ScalaReflectionLock + + +/** + * :: DeveloperApi :: + * The data type representing `java.sql.Timestamp` values. + * Please use the singleton [[DataTypes.TimestampType]]. + * + * @group dataType + */ +@DeveloperApi +class TimestampType private() extends AtomicType { + // The companion object and this class is separated so the companion object also subclasses + // this type. Otherwise, the companion object would be of type "TimestampType$" in byte code. + // Defined with a private constructor so the companion object is the only possible instantiation. + private[sql] type InternalType = Timestamp + + @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[InternalType] } + + private[sql] val ordering = new Ordering[InternalType] { + def compare(x: Timestamp, y: Timestamp): Int = x.compareTo(y) + } + + /** + * The default size of a value of the TimestampType is 12 bytes. + */ + override def defaultSize: Int = 12 + + private[spark] override def asNullable: TimestampType = this +} + +case object TimestampType extends TimestampType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala new file mode 100644 index 0000000000000..6b20505c6009a --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.types + +import org.json4s.JsonAST.JValue +import org.json4s.JsonDSL._ + +import org.apache.spark.annotation.DeveloperApi + +/** + * ::DeveloperApi:: + * The data type for User Defined Types (UDTs). + * + * This interface allows a user to make their own classes more interoperable with SparkSQL; + * e.g., by creating a [[UserDefinedType]] for a class X, it becomes possible to create + * a `DataFrame` which has class X in the schema. + * + * For SparkSQL to recognize UDTs, the UDT must be annotated with + * [[SQLUserDefinedType]]. + * + * The conversion via `serialize` occurs when instantiating a `DataFrame` from another RDD. + * The conversion via `deserialize` occurs when reading from a `DataFrame`. + */ +@DeveloperApi +abstract class UserDefinedType[UserType] extends DataType with Serializable { + + /** Underlying storage type for this UDT */ + def sqlType: DataType + + /** Paired Python UDT class, if exists. */ + def pyUDT: String = null + + /** + * Convert the user type to a SQL datum + * + * TODO: Can we make this take obj: UserType? The issue is in + * CatalystTypeConverters.convertToCatalyst, where we need to convert Any to UserType. + */ + def serialize(obj: Any): Any + + /** Convert a SQL datum to the user type */ + def deserialize(datum: Any): UserType + + override private[sql] def jsonValue: JValue = { + ("type" -> "udt") ~ + ("class" -> this.getClass.getName) ~ + ("pyClass" -> pyUDT) ~ + ("sqlType" -> sqlType.jsonValue) + } + + /** + * Class object for the UserType + */ + def userClass: java.lang.Class[UserType] + + /** + * The default size of a value of the UserDefinedType is 4096 bytes. + */ + override def defaultSize: Int = 4096 + + /** + * For UDT, asNullable will not change the nullability of its internal sqlType and just returns + * itself. + */ + private[spark] override def asNullable: UserDefinedType[UserType] = this +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala deleted file mode 100644 index 87c7b7599366a..0000000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala +++ /dev/null @@ -1,1224 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.sql.types - -import java.sql.Timestamp - -import scala.collection.mutable.ArrayBuffer -import scala.math._ -import scala.math.Numeric.{FloatAsIfIntegral, DoubleAsIfIntegral} -import scala.reflect.ClassTag -import scala.reflect.runtime.universe.{TypeTag, runtimeMirror, typeTag} -import scala.util.parsing.combinator.RegexParsers - -import org.json4s._ -import org.json4s.JsonAST.JValue -import org.json4s.JsonDSL._ -import org.json4s.jackson.JsonMethods._ - -import org.apache.spark.SparkException -import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.sql.catalyst.ScalaReflectionLock -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression} -import org.apache.spark.util.Utils - - -object DataType { - def fromJson(json: String): DataType = parseDataType(parse(json)) - - private val nonDecimalNameToType = { - Seq(NullType, DateType, TimestampType, BinaryType, - IntegerType, BooleanType, LongType, DoubleType, FloatType, ShortType, ByteType, StringType) - .map(t => t.typeName -> t).toMap - } - - /** Given the string representation of a type, return its DataType */ - private def nameToType(name: String): DataType = { - val FIXED_DECIMAL = """decimal\(\s*(\d+)\s*,\s*(\d+)\s*\)""".r - name match { - case "decimal" => DecimalType.Unlimited - case FIXED_DECIMAL(precision, scale) => DecimalType(precision.toInt, scale.toInt) - case other => nonDecimalNameToType(other) - } - } - - private object JSortedObject { - def unapplySeq(value: JValue): Option[List[(String, JValue)]] = value match { - case JObject(seq) => Some(seq.toList.sortBy(_._1)) - case _ => None - } - } - - // NOTE: Map fields must be sorted in alphabetical order to keep consistent with the Python side. - private def parseDataType(json: JValue): DataType = json match { - case JString(name) => - nameToType(name) - - case JSortedObject( - ("containsNull", JBool(n)), - ("elementType", t: JValue), - ("type", JString("array"))) => - ArrayType(parseDataType(t), n) - - case JSortedObject( - ("keyType", k: JValue), - ("type", JString("map")), - ("valueContainsNull", JBool(n)), - ("valueType", v: JValue)) => - MapType(parseDataType(k), parseDataType(v), n) - - case JSortedObject( - ("fields", JArray(fields)), - ("type", JString("struct"))) => - StructType(fields.map(parseStructField)) - - case JSortedObject( - ("class", JString(udtClass)), - ("pyClass", _), - ("sqlType", _), - ("type", JString("udt"))) => - Class.forName(udtClass).newInstance().asInstanceOf[UserDefinedType[_]] - } - - private def parseStructField(json: JValue): StructField = json match { - case JSortedObject( - ("metadata", metadata: JObject), - ("name", JString(name)), - ("nullable", JBool(nullable)), - ("type", dataType: JValue)) => - StructField(name, parseDataType(dataType), nullable, Metadata.fromJObject(metadata)) - // Support reading schema when 'metadata' is missing. 
- case JSortedObject( - ("name", JString(name)), - ("nullable", JBool(nullable)), - ("type", dataType: JValue)) => - StructField(name, parseDataType(dataType), nullable) - } - - @deprecated("Use DataType.fromJson instead", "1.2.0") - def fromCaseClassString(string: String): DataType = CaseClassStringParser(string) - - private object CaseClassStringParser extends RegexParsers { - protected lazy val primitiveType: Parser[DataType] = - ( "StringType" ^^^ StringType - | "FloatType" ^^^ FloatType - | "IntegerType" ^^^ IntegerType - | "ByteType" ^^^ ByteType - | "ShortType" ^^^ ShortType - | "DoubleType" ^^^ DoubleType - | "LongType" ^^^ LongType - | "BinaryType" ^^^ BinaryType - | "BooleanType" ^^^ BooleanType - | "DateType" ^^^ DateType - | "DecimalType()" ^^^ DecimalType.Unlimited - | fixedDecimalType - | "TimestampType" ^^^ TimestampType - ) - - protected lazy val fixedDecimalType: Parser[DataType] = - ("DecimalType(" ~> "[0-9]+".r) ~ ("," ~> "[0-9]+".r <~ ")") ^^ { - case precision ~ scale => DecimalType(precision.toInt, scale.toInt) - } - - protected lazy val arrayType: Parser[DataType] = - "ArrayType" ~> "(" ~> dataType ~ "," ~ boolVal <~ ")" ^^ { - case tpe ~ _ ~ containsNull => ArrayType(tpe, containsNull) - } - - protected lazy val mapType: Parser[DataType] = - "MapType" ~> "(" ~> dataType ~ "," ~ dataType ~ "," ~ boolVal <~ ")" ^^ { - case t1 ~ _ ~ t2 ~ _ ~ valueContainsNull => MapType(t1, t2, valueContainsNull) - } - - protected lazy val structField: Parser[StructField] = - ("StructField(" ~> "[a-zA-Z0-9_]*".r) ~ ("," ~> dataType) ~ ("," ~> boolVal <~ ")") ^^ { - case name ~ tpe ~ nullable => - StructField(name, tpe, nullable = nullable) - } - - protected lazy val boolVal: Parser[Boolean] = - ( "true" ^^^ true - | "false" ^^^ false - ) - - protected lazy val structType: Parser[DataType] = - "StructType\\([A-zA-z]*\\(".r ~> repsep(structField, ",") <~ "))" ^^ { - case fields => StructType(fields) - } - - protected lazy val dataType: Parser[DataType] = - ( arrayType - | mapType - | structType - | primitiveType - ) - - /** - * Parses a string representation of a DataType. - * - * TODO: Generate parser as pickler... - */ - def apply(asString: String): DataType = parseAll(dataType, asString) match { - case Success(result, _) => result - case failure: NoSuccess => - throw new IllegalArgumentException(s"Unsupported dataType: $asString, $failure") - } - } - - protected[types] def buildFormattedString( - dataType: DataType, - prefix: String, - builder: StringBuilder): Unit = { - dataType match { - case array: ArrayType => - array.buildFormattedString(prefix, builder) - case struct: StructType => - struct.buildFormattedString(prefix, builder) - case map: MapType => - map.buildFormattedString(prefix, builder) - case _ => - } - } - - /** - * Compares two types, ignoring nullability of ArrayType, MapType, StructType. 
- */ - private[types] def equalsIgnoreNullability(left: DataType, right: DataType): Boolean = { - (left, right) match { - case (ArrayType(leftElementType, _), ArrayType(rightElementType, _)) => - equalsIgnoreNullability(leftElementType, rightElementType) - case (MapType(leftKeyType, leftValueType, _), MapType(rightKeyType, rightValueType, _)) => - equalsIgnoreNullability(leftKeyType, rightKeyType) && - equalsIgnoreNullability(leftValueType, rightValueType) - case (StructType(leftFields), StructType(rightFields)) => - leftFields.length == rightFields.length && - leftFields.zip(rightFields).forall { case (l, r) => - l.name == r.name && equalsIgnoreNullability(l.dataType, r.dataType) - } - case (l, r) => l == r - } - } - - /** - * Compares two types, ignoring compatible nullability of ArrayType, MapType, StructType. - * - * Compatible nullability is defined as follows: - * - If `from` and `to` are ArrayTypes, `from` has a compatible nullability with `to` - * if and only if `to.containsNull` is true, or both of `from.containsNull` and - * `to.containsNull` are false. - * - If `from` and `to` are MapTypes, `from` has a compatible nullability with `to` - * if and only if `to.valueContainsNull` is true, or both of `from.valueContainsNull` and - * `to.valueContainsNull` are false. - * - If `from` and `to` are StructTypes, `from` has a compatible nullability with `to` - * if and only if for all every pair of fields, `to.nullable` is true, or both - * of `fromField.nullable` and `toField.nullable` are false. - */ - private[sql] def equalsIgnoreCompatibleNullability(from: DataType, to: DataType): Boolean = { - (from, to) match { - case (ArrayType(fromElement, fn), ArrayType(toElement, tn)) => - (tn || !fn) && equalsIgnoreCompatibleNullability(fromElement, toElement) - - case (MapType(fromKey, fromValue, fn), MapType(toKey, toValue, tn)) => - (tn || !fn) && - equalsIgnoreCompatibleNullability(fromKey, toKey) && - equalsIgnoreCompatibleNullability(fromValue, toValue) - - case (StructType(fromFields), StructType(toFields)) => - fromFields.length == toFields.length && - fromFields.zip(toFields).forall { case (fromField, toField) => - fromField.name == toField.name && - (toField.nullable || !fromField.nullable) && - equalsIgnoreCompatibleNullability(fromField.dataType, toField.dataType) - } - - case (fromDataType, toDataType) => fromDataType == toDataType - } - } -} - - -/** - * :: DeveloperApi :: - * The base type of all Spark SQL data types. - * - * @group dataType - */ -@DeveloperApi -abstract class DataType { - /** Matches any expression that evaluates to this DataType */ - def unapply(a: Expression): Boolean = a match { - case e: Expression if e.dataType == this => true - case _ => false - } - - /** The default size of a value of this data type. */ - def defaultSize: Int - - def typeName: String = this.getClass.getSimpleName.stripSuffix("$").dropRight(4).toLowerCase - - private[sql] def jsonValue: JValue = typeName - - def json: String = compact(render(jsonValue)) - - def prettyJson: String = pretty(render(jsonValue)) - - def simpleString: String = typeName - - /** Check if `this` and `other` are the same data type when ignoring nullability - * (`StructField.nullable`, `ArrayType.containsNull`, and `MapType.valueContainsNull`). - */ - private[spark] def sameType(other: DataType): Boolean = - DataType.equalsIgnoreNullability(this, other) - - /** Returns the same data type but set all nullability fields are true - * (`StructField.nullable`, `ArrayType.containsNull`, and `MapType.valueContainsNull`). 
- */ - private[spark] def asNullable: DataType -} - -/** - * :: DeveloperApi :: - * The data type representing `NULL` values. Please use the singleton [[DataTypes.NullType]]. - * - * @group dataType - */ -@DeveloperApi -class NullType private() extends DataType { - // The companion object and this class is separated so the companion object also subclasses - // this type. Otherwise, the companion object would be of type "NullType$" in byte code. - // Defined with a private constructor so the companion object is the only possible instantiation. - override def defaultSize: Int = 1 - - private[spark] override def asNullable: NullType = this -} - -case object NullType extends NullType - - -/** - * An internal type used to represent everything that is not null, UDTs, arrays, structs, and maps. - */ -protected[sql] abstract class AtomicType extends DataType { - private[sql] type InternalType - @transient private[sql] val tag: TypeTag[InternalType] - private[sql] val ordering: Ordering[InternalType] - - @transient private[sql] val classTag = ScalaReflectionLock.synchronized { - val mirror = runtimeMirror(Utils.getSparkClassLoader) - ClassTag[InternalType](mirror.runtimeClass(tag.tpe)) - } -} - - -/** - * :: DeveloperApi :: - * The data type representing `String` values. Please use the singleton [[DataTypes.StringType]]. - * - * @group dataType - */ -@DeveloperApi -class StringType private() extends AtomicType { - // The companion object and this class is separated so the companion object also subclasses - // this type. Otherwise, the companion object would be of type "StringType$" in byte code. - // Defined with a private constructor so the companion object is the only possible instantiation. - private[sql] type InternalType = UTF8String - @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[InternalType] } - private[sql] val ordering = implicitly[Ordering[InternalType]] - - /** - * The default size of a value of the StringType is 4096 bytes. - */ - override def defaultSize: Int = 4096 - - private[spark] override def asNullable: StringType = this -} - -case object StringType extends StringType - - -/** - * :: DeveloperApi :: - * The data type representing `Array[Byte]` values. - * Please use the singleton [[DataTypes.BinaryType]]. - * - * @group dataType - */ -@DeveloperApi -class BinaryType private() extends AtomicType { - // The companion object and this class is separated so the companion object also subclasses - // this type. Otherwise, the companion object would be of type "BinaryType$" in byte code. - // Defined with a private constructor so the companion object is the only possible instantiation. - private[sql] type InternalType = Array[Byte] - @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[InternalType] } - private[sql] val ordering = new Ordering[InternalType] { - def compare(x: Array[Byte], y: Array[Byte]): Int = { - for (i <- 0 until x.length; if i < y.length) { - val res = x(i).compareTo(y(i)) - if (res != 0) return res - } - x.length - y.length - } - } - - /** - * The default size of a value of the BinaryType is 4096 bytes. - */ - override def defaultSize: Int = 4096 - - private[spark] override def asNullable: BinaryType = this -} - -case object BinaryType extends BinaryType - - -/** - * :: DeveloperApi :: - * The data type representing `Boolean` values. Please use the singleton [[DataTypes.BooleanType]]. 
- * - *@group dataType - */ -@DeveloperApi -class BooleanType private() extends AtomicType { - // The companion object and this class is separated so the companion object also subclasses - // this type. Otherwise, the companion object would be of type "BooleanType$" in byte code. - // Defined with a private constructor so the companion object is the only possible instantiation. - private[sql] type InternalType = Boolean - @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[InternalType] } - private[sql] val ordering = implicitly[Ordering[InternalType]] - - /** - * The default size of a value of the BooleanType is 1 byte. - */ - override def defaultSize: Int = 1 - - private[spark] override def asNullable: BooleanType = this -} - -case object BooleanType extends BooleanType - - -/** - * :: DeveloperApi :: - * The data type representing `java.sql.Timestamp` values. - * Please use the singleton [[DataTypes.TimestampType]]. - * - * @group dataType - */ -@DeveloperApi -class TimestampType private() extends AtomicType { - // The companion object and this class is separated so the companion object also subclasses - // this type. Otherwise, the companion object would be of type "TimestampType$" in byte code. - // Defined with a private constructor so the companion object is the only possible instantiation. - private[sql] type InternalType = Timestamp - - @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[InternalType] } - - private[sql] val ordering = new Ordering[InternalType] { - def compare(x: Timestamp, y: Timestamp): Int = x.compareTo(y) - } - - /** - * The default size of a value of the TimestampType is 12 bytes. - */ - override def defaultSize: Int = 12 - - private[spark] override def asNullable: TimestampType = this -} - -case object TimestampType extends TimestampType - - -/** - * :: DeveloperApi :: - * The data type representing `java.sql.Date` values. - * Please use the singleton [[DataTypes.DateType]]. - * - * @group dataType - */ -@DeveloperApi -class DateType private() extends AtomicType { - // The companion object and this class is separated so the companion object also subclasses - // this type. Otherwise, the companion object would be of type "DateType$" in byte code. - // Defined with a private constructor so the companion object is the only possible instantiation. - private[sql] type InternalType = Int - - @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[InternalType] } - - private[sql] val ordering = implicitly[Ordering[InternalType]] - - /** - * The default size of a value of the DateType is 4 bytes. - */ - override def defaultSize: Int = 4 - - private[spark] override def asNullable: DateType = this -} - -case object DateType extends DateType - - -/** - * :: DeveloperApi :: - * Numeric data types. - * - * @group dataType - */ -abstract class NumericType extends AtomicType { - // Unfortunately we can't get this implicitly as that breaks Spark Serialization. In order for - // implicitly[Numeric[JvmType]] to be valid, we have to change JvmType from a type variable to a - // type parameter and and add a numeric annotation (i.e., [JvmType : Numeric]). This gets - // desugared by the compiler into an argument to the objects constructor. This means there is no - // longer an no argument constructor and thus the JVM cannot serialize the object anymore. 
- private[sql] val numeric: Numeric[InternalType] -} - - -protected[sql] object NumericType { - def unapply(e: Expression): Boolean = e.dataType.isInstanceOf[NumericType] -} - - -/** Matcher for any expressions that evaluate to [[IntegralType]]s */ -protected[sql] object IntegralType { - def unapply(a: Expression): Boolean = a match { - case e: Expression if e.dataType.isInstanceOf[IntegralType] => true - case _ => false - } -} - - -protected[sql] sealed abstract class IntegralType extends NumericType { - private[sql] val integral: Integral[InternalType] -} - - -/** - * :: DeveloperApi :: - * The data type representing `Long` values. Please use the singleton [[DataTypes.LongType]]. - * - * @group dataType - */ -@DeveloperApi -class LongType private() extends IntegralType { - // The companion object and this class is separated so the companion object also subclasses - // this type. Otherwise, the companion object would be of type "LongType$" in byte code. - // Defined with a private constructor so the companion object is the only possible instantiation. - private[sql] type InternalType = Long - @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[InternalType] } - private[sql] val numeric = implicitly[Numeric[Long]] - private[sql] val integral = implicitly[Integral[Long]] - private[sql] val ordering = implicitly[Ordering[InternalType]] - - /** - * The default size of a value of the LongType is 8 bytes. - */ - override def defaultSize: Int = 8 - - override def simpleString: String = "bigint" - - private[spark] override def asNullable: LongType = this -} - -case object LongType extends LongType - - -/** - * :: DeveloperApi :: - * The data type representing `Int` values. Please use the singleton [[DataTypes.IntegerType]]. - * - * @group dataType - */ -@DeveloperApi -class IntegerType private() extends IntegralType { - // The companion object and this class is separated so the companion object also subclasses - // this type. Otherwise, the companion object would be of type "IntegerType$" in byte code. - // Defined with a private constructor so the companion object is the only possible instantiation. - private[sql] type InternalType = Int - @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[InternalType] } - private[sql] val numeric = implicitly[Numeric[Int]] - private[sql] val integral = implicitly[Integral[Int]] - private[sql] val ordering = implicitly[Ordering[InternalType]] - - /** - * The default size of a value of the IntegerType is 4 bytes. - */ - override def defaultSize: Int = 4 - - override def simpleString: String = "int" - - private[spark] override def asNullable: IntegerType = this -} - -case object IntegerType extends IntegerType - - -/** - * :: DeveloperApi :: - * The data type representing `Short` values. Please use the singleton [[DataTypes.ShortType]]. - * - * @group dataType - */ -@DeveloperApi -class ShortType private() extends IntegralType { - // The companion object and this class is separated so the companion object also subclasses - // this type. Otherwise, the companion object would be of type "ShortType$" in byte code. - // Defined with a private constructor so the companion object is the only possible instantiation. 
- private[sql] type InternalType = Short - @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[InternalType] } - private[sql] val numeric = implicitly[Numeric[Short]] - private[sql] val integral = implicitly[Integral[Short]] - private[sql] val ordering = implicitly[Ordering[InternalType]] - - /** - * The default size of a value of the ShortType is 2 bytes. - */ - override def defaultSize: Int = 2 - - override def simpleString: String = "smallint" - - private[spark] override def asNullable: ShortType = this -} - -case object ShortType extends ShortType - - -/** - * :: DeveloperApi :: - * The data type representing `Byte` values. Please use the singleton [[DataTypes.ByteType]]. - * - * @group dataType - */ -@DeveloperApi -class ByteType private() extends IntegralType { - // The companion object and this class is separated so the companion object also subclasses - // this type. Otherwise, the companion object would be of type "ByteType$" in byte code. - // Defined with a private constructor so the companion object is the only possible instantiation. - private[sql] type InternalType = Byte - @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[InternalType] } - private[sql] val numeric = implicitly[Numeric[Byte]] - private[sql] val integral = implicitly[Integral[Byte]] - private[sql] val ordering = implicitly[Ordering[InternalType]] - - /** - * The default size of a value of the ByteType is 1 byte. - */ - override def defaultSize: Int = 1 - - override def simpleString: String = "tinyint" - - private[spark] override def asNullable: ByteType = this -} - -case object ByteType extends ByteType - - -/** Matcher for any expressions that evaluate to [[FractionalType]]s */ -protected[sql] object FractionalType { - def unapply(a: Expression): Boolean = a match { - case e: Expression if e.dataType.isInstanceOf[FractionalType] => true - case _ => false - } -} - - -protected[sql] sealed abstract class FractionalType extends NumericType { - private[sql] val fractional: Fractional[InternalType] - private[sql] val asIntegral: Integral[InternalType] -} - - -/** Precision parameters for a Decimal */ -case class PrecisionInfo(precision: Int, scale: Int) - - -/** - * :: DeveloperApi :: - * The data type representing `java.math.BigDecimal` values. - * A Decimal that might have fixed precision and scale, or unlimited values for these. - * - * Please use [[DataTypes.createDecimalType()]] to create a specific instance. - * - * @group dataType - */ -@DeveloperApi -case class DecimalType(precisionInfo: Option[PrecisionInfo]) extends FractionalType { - - /** No-arg constructor for kryo. 
*/ - protected def this() = this(null) - - private[sql] type InternalType = Decimal - @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[InternalType] } - private[sql] val numeric = Decimal.DecimalIsFractional - private[sql] val fractional = Decimal.DecimalIsFractional - private[sql] val ordering = Decimal.DecimalIsFractional - private[sql] val asIntegral = Decimal.DecimalAsIfIntegral - - def precision: Int = precisionInfo.map(_.precision).getOrElse(-1) - - def scale: Int = precisionInfo.map(_.scale).getOrElse(-1) - - override def typeName: String = precisionInfo match { - case Some(PrecisionInfo(precision, scale)) => s"decimal($precision,$scale)" - case None => "decimal" - } - - override def toString: String = precisionInfo match { - case Some(PrecisionInfo(precision, scale)) => s"DecimalType($precision,$scale)" - case None => "DecimalType()" - } - - /** - * The default size of a value of the DecimalType is 4096 bytes. - */ - override def defaultSize: Int = 4096 - - override def simpleString: String = precisionInfo match { - case Some(PrecisionInfo(precision, scale)) => s"decimal($precision,$scale)" - case None => "decimal(10,0)" - } - - private[spark] override def asNullable: DecimalType = this -} - - -/** Extra factory methods and pattern matchers for Decimals */ -object DecimalType { - val Unlimited: DecimalType = DecimalType(None) - - object Fixed { - def unapply(t: DecimalType): Option[(Int, Int)] = - t.precisionInfo.map(p => (p.precision, p.scale)) - } - - object Expression { - def unapply(e: Expression): Option[(Int, Int)] = e.dataType match { - case t: DecimalType => t.precisionInfo.map(p => (p.precision, p.scale)) - case _ => None - } - } - - def apply(): DecimalType = Unlimited - - def apply(precision: Int, scale: Int): DecimalType = - DecimalType(Some(PrecisionInfo(precision, scale))) - - def unapply(t: DataType): Boolean = t.isInstanceOf[DecimalType] - - def unapply(e: Expression): Boolean = e.dataType.isInstanceOf[DecimalType] - - def isFixed(dataType: DataType): Boolean = dataType match { - case DecimalType.Fixed(_, _) => true - case _ => false - } -} - - -/** - * :: DeveloperApi :: - * The data type representing `Double` values. Please use the singleton [[DataTypes.DoubleType]]. - * - * @group dataType - */ -@DeveloperApi -class DoubleType private() extends FractionalType { - // The companion object and this class is separated so the companion object also subclasses - // this type. Otherwise, the companion object would be of type "DoubleType$" in byte code. - // Defined with a private constructor so the companion object is the only possible instantiation. - private[sql] type InternalType = Double - @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[InternalType] } - private[sql] val numeric = implicitly[Numeric[Double]] - private[sql] val fractional = implicitly[Fractional[Double]] - private[sql] val ordering = implicitly[Ordering[InternalType]] - private[sql] val asIntegral = DoubleAsIfIntegral - - /** - * The default size of a value of the DoubleType is 8 bytes. - */ - override def defaultSize: Int = 8 - - private[spark] override def asNullable: DoubleType = this -} - -case object DoubleType extends DoubleType - - -/** - * :: DeveloperApi :: - * The data type representing `Float` values. Please use the singleton [[DataTypes.FloatType]]. 
- * - * @group dataType - */ -@DeveloperApi -class FloatType private() extends FractionalType { - // The companion object and this class is separated so the companion object also subclasses - // this type. Otherwise, the companion object would be of type "FloatType$" in byte code. - // Defined with a private constructor so the companion object is the only possible instantiation. - private[sql] type InternalType = Float - @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[InternalType] } - private[sql] val numeric = implicitly[Numeric[Float]] - private[sql] val fractional = implicitly[Fractional[Float]] - private[sql] val ordering = implicitly[Ordering[InternalType]] - private[sql] val asIntegral = FloatAsIfIntegral - - /** - * The default size of a value of the FloatType is 4 bytes. - */ - override def defaultSize: Int = 4 - - private[spark] override def asNullable: FloatType = this -} - -case object FloatType extends FloatType - - -object ArrayType { - /** Construct a [[ArrayType]] object with the given element type. The `containsNull` is true. */ - def apply(elementType: DataType): ArrayType = ArrayType(elementType, true) -} - - -/** - * :: DeveloperApi :: - * The data type for collections of multiple values. - * Internally these are represented as columns that contain a ``scala.collection.Seq``. - * - * Please use [[DataTypes.createArrayType()]] to create a specific instance. - * - * An [[ArrayType]] object comprises two fields, `elementType: [[DataType]]` and - * `containsNull: Boolean`. The field of `elementType` is used to specify the type of - * array elements. The field of `containsNull` is used to specify if the array has `null` values. - * - * @param elementType The data type of values. - * @param containsNull Indicates if values have `null` values - * - * @group dataType - */ -@DeveloperApi -case class ArrayType(elementType: DataType, containsNull: Boolean) extends DataType { - - /** No-arg constructor for kryo. */ - protected def this() = this(null, false) - - private[sql] def buildFormattedString(prefix: String, builder: StringBuilder): Unit = { - builder.append( - s"$prefix-- element: ${elementType.typeName} (containsNull = $containsNull)\n") - DataType.buildFormattedString(elementType, s"$prefix |", builder) - } - - override private[sql] def jsonValue = - ("type" -> typeName) ~ - ("elementType" -> elementType.jsonValue) ~ - ("containsNull" -> containsNull) - - /** - * The default size of a value of the ArrayType is 100 * the default size of the element type. - * (We assume that there are 100 elements). - */ - override def defaultSize: Int = 100 * elementType.defaultSize - - override def simpleString: String = s"array<${elementType.simpleString}>" - - private[spark] override def asNullable: ArrayType = - ArrayType(elementType.asNullable, containsNull = true) -} - - -/** - * A field inside a StructType. - * @param name The name of this field. - * @param dataType The data type of this field. - * @param nullable Indicates if values of this field can be `null` values. - * @param metadata The metadata of this field. The metadata should be preserved during - * transformation if the content of the column is not modified, e.g, in selection. - */ -case class StructField( - name: String, - dataType: DataType, - nullable: Boolean = true, - metadata: Metadata = Metadata.empty) { - - /** No-arg constructor for kryo. 
*/ - protected def this() = this(null, null) - - private[sql] def buildFormattedString(prefix: String, builder: StringBuilder): Unit = { - builder.append(s"$prefix-- $name: ${dataType.typeName} (nullable = $nullable)\n") - DataType.buildFormattedString(dataType, s"$prefix |", builder) - } - - // override the default toString to be compatible with legacy parquet files. - override def toString: String = s"StructField($name,$dataType,$nullable)" - - private[sql] def jsonValue: JValue = { - ("name" -> name) ~ - ("type" -> dataType.jsonValue) ~ - ("nullable" -> nullable) ~ - ("metadata" -> metadata.jsonValue) - } -} - - -object StructType { - protected[sql] def fromAttributes(attributes: Seq[Attribute]): StructType = - StructType(attributes.map(a => StructField(a.name, a.dataType, a.nullable, a.metadata))) - - def apply(fields: Seq[StructField]): StructType = StructType(fields.toArray) - - def apply(fields: java.util.List[StructField]): StructType = { - StructType(fields.toArray.asInstanceOf[Array[StructField]]) - } - - private[sql] def merge(left: DataType, right: DataType): DataType = - (left, right) match { - case (ArrayType(leftElementType, leftContainsNull), - ArrayType(rightElementType, rightContainsNull)) => - ArrayType( - merge(leftElementType, rightElementType), - leftContainsNull || rightContainsNull) - - case (MapType(leftKeyType, leftValueType, leftContainsNull), - MapType(rightKeyType, rightValueType, rightContainsNull)) => - MapType( - merge(leftKeyType, rightKeyType), - merge(leftValueType, rightValueType), - leftContainsNull || rightContainsNull) - - case (StructType(leftFields), StructType(rightFields)) => - val newFields = ArrayBuffer.empty[StructField] - - leftFields.foreach { - case leftField @ StructField(leftName, leftType, leftNullable, _) => - rightFields - .find(_.name == leftName) - .map { case rightField @ StructField(_, rightType, rightNullable, _) => - leftField.copy( - dataType = merge(leftType, rightType), - nullable = leftNullable || rightNullable) - } - .orElse(Some(leftField)) - .foreach(newFields += _) - } - - rightFields - .filterNot(f => leftFields.map(_.name).contains(f.name)) - .foreach(newFields += _) - - StructType(newFields) - - case (DecimalType.Fixed(leftPrecision, leftScale), - DecimalType.Fixed(rightPrecision, rightScale)) => - DecimalType( - max(leftScale, rightScale) + max(leftPrecision - leftScale, rightPrecision - rightScale), - max(leftScale, rightScale)) - - case (leftUdt: UserDefinedType[_], rightUdt: UserDefinedType[_]) - if leftUdt.userClass == rightUdt.userClass => leftUdt - - case (leftType, rightType) if leftType == rightType => - leftType - - case _ => - throw new SparkException(s"Failed to merge incompatible data types $left and $right") - } -} - - -/** - * :: DeveloperApi :: - * A [[StructType]] object can be constructed by - * {{{ - * StructType(fields: Seq[StructField]) - * }}} - * For a [[StructType]] object, one or multiple [[StructField]]s can be extracted by names. - * If multiple [[StructField]]s are extracted, a [[StructType]] object will be returned. - * If a provided name does not have a matching field, it will be ignored. For the case - * of extracting a single StructField, a `null` will be returned. - * Example: - * {{{ - * import org.apache.spark.sql._ - * - * val struct = - * StructType( - * StructField("a", IntegerType, true) :: - * StructField("b", LongType, false) :: - * StructField("c", BooleanType, false) :: Nil) - * - * // Extract a single StructField. 
- * val singleField = struct("b") - * // singleField: StructField = StructField(b,LongType,false) - * - * // This struct does not have a field called "d". null will be returned. - * val nonExisting = struct("d") - * // nonExisting: StructField = null - * - * // Extract multiple StructFields. Field names are provided in a set. - * // A StructType object will be returned. - * val twoFields = struct(Set("b", "c")) - * // twoFields: StructType = - * // StructType(List(StructField(b,LongType,false), StructField(c,BooleanType,false))) - * - * // Any names without matching fields will be ignored. - * // For the case shown below, "d" will be ignored and - * // it is treated as struct(Set("b", "c")). - * val ignoreNonExisting = struct(Set("b", "c", "d")) - * // ignoreNonExisting: StructType = - * // StructType(List(StructField(b,LongType,false), StructField(c,BooleanType,false))) - * }}} - * - * A [[org.apache.spark.sql.Row]] object is used as a value of the StructType. - * Example: - * {{{ - * import org.apache.spark.sql._ - * - * val innerStruct = - * StructType( - * StructField("f1", IntegerType, true) :: - * StructField("f2", LongType, false) :: - * StructField("f3", BooleanType, false) :: Nil) - * - * val struct = StructType( - * StructField("a", innerStruct, true) :: Nil) - * - * // Create a Row with the schema defined by struct - * val row = Row(Row(1, 2, true)) - * // row: Row = [[1,2,true]] - * }}} - * - * @group dataType - */ -@DeveloperApi -case class StructType(fields: Array[StructField]) extends DataType with Seq[StructField] { - - /** No-arg constructor for kryo. */ - protected def this() = this(null) - - /** Returns all field names in an array. */ - def fieldNames: Array[String] = fields.map(_.name) - - private lazy val fieldNamesSet: Set[String] = fieldNames.toSet - private lazy val nameToField: Map[String, StructField] = fields.map(f => f.name -> f).toMap - private lazy val nameToIndex: Map[String, Int] = fieldNames.zipWithIndex.toMap - - /** - * Extracts a [[StructField]] of the given name. If the [[StructType]] object does not - * have a name matching the given name, `null` will be returned. - */ - def apply(name: String): StructField = { - nameToField.getOrElse(name, - throw new IllegalArgumentException(s"""Field "$name" does not exist.""")) - } - - /** - * Returns a [[StructType]] containing [[StructField]]s of the given names, preserving the - * original order of fields. Those names which do not have matching fields will be ignored. - */ - def apply(names: Set[String]): StructType = { - val nonExistFields = names -- fieldNamesSet - if (nonExistFields.nonEmpty) { - throw new IllegalArgumentException( - s"Field ${nonExistFields.mkString(",")} does not exist.") - } - // Preserve the original order of fields. 
- StructType(fields.filter(f => names.contains(f.name))) - } - - /** - * Returns index of a given field - */ - def fieldIndex(name: String): Int = { - nameToIndex.getOrElse(name, - throw new IllegalArgumentException(s"""Field "$name" does not exist.""")) - } - - protected[sql] def toAttributes: Seq[AttributeReference] = - map(f => AttributeReference(f.name, f.dataType, f.nullable, f.metadata)()) - - def treeString: String = { - val builder = new StringBuilder - builder.append("root\n") - val prefix = " |" - fields.foreach(field => field.buildFormattedString(prefix, builder)) - - builder.toString() - } - - def printTreeString(): Unit = println(treeString) - - private[sql] def buildFormattedString(prefix: String, builder: StringBuilder): Unit = { - fields.foreach(field => field.buildFormattedString(prefix, builder)) - } - - override private[sql] def jsonValue = - ("type" -> typeName) ~ - ("fields" -> map(_.jsonValue)) - - override def apply(fieldIndex: Int): StructField = fields(fieldIndex) - - override def length: Int = fields.length - - override def iterator: Iterator[StructField] = fields.iterator - - /** - * The default size of a value of the StructType is the total default sizes of all field types. - */ - override def defaultSize: Int = fields.map(_.dataType.defaultSize).sum - - override def simpleString: String = { - val fieldTypes = fields.map(field => s"${field.name}:${field.dataType.simpleString}") - s"struct<${fieldTypes.mkString(",")}>" - } - - /** - * Merges with another schema (`StructType`). For a struct field A from `this` and a struct field - * B from `that`, - * - * 1. If A and B have the same name and data type, they are merged to a field C with the same name - * and data type. C is nullable if and only if either A or B is nullable. - * 2. If A doesn't exist in `that`, it's included in the result schema. - * 3. If B doesn't exist in `this`, it's also included in the result schema. - * 4. Otherwise, `this` and `that` are considered as conflicting schemas and an exception would be - * thrown. - */ - private[sql] def merge(that: StructType): StructType = - StructType.merge(this, that).asInstanceOf[StructType] - - private[spark] override def asNullable: StructType = { - val newFields = fields.map { - case StructField(name, dataType, nullable, metadata) => - StructField(name, dataType.asNullable, nullable = true, metadata) - } - - StructType(newFields) - } -} - - -object MapType { - /** - * Construct a [[MapType]] object with the given key type and value type. - * The `valueContainsNull` is true. - */ - def apply(keyType: DataType, valueType: DataType): MapType = - MapType(keyType: DataType, valueType: DataType, valueContainsNull = true) -} - - -/** - * :: DeveloperApi :: - * The data type for Maps. Keys in a map are not allowed to have `null` values. - * - * Please use [[DataTypes.createMapType()]] to create a specific instance. - * - * @param keyType The data type of map keys. - * @param valueType The data type of map values. - * @param valueContainsNull Indicates if map values have `null` values. - * - * @group dataType - */ -case class MapType( - keyType: DataType, - valueType: DataType, - valueContainsNull: Boolean) extends DataType { - - /** No-arg constructor for kryo. 
*/ - def this() = this(null, null, false) - - private[sql] def buildFormattedString(prefix: String, builder: StringBuilder): Unit = { - builder.append(s"$prefix-- key: ${keyType.typeName}\n") - builder.append(s"$prefix-- value: ${valueType.typeName} " + - s"(valueContainsNull = $valueContainsNull)\n") - DataType.buildFormattedString(keyType, s"$prefix |", builder) - DataType.buildFormattedString(valueType, s"$prefix |", builder) - } - - override private[sql] def jsonValue: JValue = - ("type" -> typeName) ~ - ("keyType" -> keyType.jsonValue) ~ - ("valueType" -> valueType.jsonValue) ~ - ("valueContainsNull" -> valueContainsNull) - - /** - * The default size of a value of the MapType is - * 100 * (the default size of the key type + the default size of the value type). - * (We assume that there are 100 elements). - */ - override def defaultSize: Int = 100 * (keyType.defaultSize + valueType.defaultSize) - - override def simpleString: String = s"map<${keyType.simpleString},${valueType.simpleString}>" - - private[spark] override def asNullable: MapType = - MapType(keyType.asNullable, valueType.asNullable, valueContainsNull = true) -} - - -/** - * ::DeveloperApi:: - * The data type for User Defined Types (UDTs). - * - * This interface allows a user to make their own classes more interoperable with SparkSQL; - * e.g., by creating a [[UserDefinedType]] for a class X, it becomes possible to create - * a `DataFrame` which has class X in the schema. - * - * For SparkSQL to recognize UDTs, the UDT must be annotated with - * [[SQLUserDefinedType]]. - * - * The conversion via `serialize` occurs when instantiating a `DataFrame` from another RDD. - * The conversion via `deserialize` occurs when reading from a `DataFrame`. - */ -@DeveloperApi -abstract class UserDefinedType[UserType] extends DataType with Serializable { - - /** Underlying storage type for this UDT */ - def sqlType: DataType - - /** Paired Python UDT class, if exists. */ - def pyUDT: String = null - - /** - * Convert the user type to a SQL datum - * - * TODO: Can we make this take obj: UserType? The issue is in - * CatalystTypeConverters.convertToCatalyst, where we need to convert Any to UserType. - */ - def serialize(obj: Any): Any - - /** Convert a SQL datum to the user type */ - def deserialize(datum: Any): UserType - - override private[sql] def jsonValue: JValue = { - ("type" -> "udt") ~ - ("class" -> this.getClass.getName) ~ - ("pyClass" -> pyUDT) ~ - ("sqlType" -> sqlType.jsonValue) - } - - /** - * Class object for the UserType - */ - def userClass: java.lang.Class[UserType] - - /** - * The default size of a value of the UserDefinedType is 4096 bytes. - */ - override def defaultSize: Int = 4096 - - /** - * For UDT, asNullable will not change the nullability of its internal sqlType and just returns - * itself. - */ - private[spark] override def asNullable: UserDefinedType[UserType] = this -} From 73db132bf503341c7a5cf9409351c282a8464175 Mon Sep 17 00:00:00 2001 From: Sun Rui Date: Thu, 23 Apr 2015 16:08:14 -0700 Subject: [PATCH 052/110] [SPARK-6818] [SPARKR] Support column deletion in SparkR DataFrame API. Author: Sun Rui Closes #5655 from sun-rui/SPARK-6818 and squashes the following commits: 7c66570 [Sun Rui] [SPARK-6818][SPARKR] Support column deletion in SparkR DataFrame API. 
--- R/pkg/R/DataFrame.R | 8 +++++++- R/pkg/inst/tests/test_sparkSQL.R | 5 +++++ 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index 861fe1c78b0db..b59b700af5dc9 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -790,9 +790,12 @@ setMethod("$", signature(x = "DataFrame"), setMethod("$<-", signature(x = "DataFrame"), function(x, name, value) { - stopifnot(class(value) == "Column") + stopifnot(class(value) == "Column" || is.null(value)) cols <- columns(x) if (name %in% cols) { + if (is.null(value)) { + cols <- Filter(function(c) { c != name }, cols) + } cols <- lapply(cols, function(c) { if (c == name) { alias(value, name) @@ -802,6 +805,9 @@ setMethod("$<-", signature(x = "DataFrame"), }) nx <- select(x, cols) } else { + if (is.null(value)) { + return(x) + } nx <- withColumn(x, name, value) } x@sdf <- nx@sdf diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R index 25831ae2d9e18..af7a6c582047a 100644 --- a/R/pkg/inst/tests/test_sparkSQL.R +++ b/R/pkg/inst/tests/test_sparkSQL.R @@ -449,6 +449,11 @@ test_that("select operators", { df$age2 <- df$age * 2 expect_equal(columns(df), c("name", "age", "age2")) expect_equal(count(where(df, df$age2 == df$age * 2)), 2) + + df$age2 <- NULL + expect_equal(columns(df), c("name", "age")) + df$age3 <- NULL + expect_equal(columns(df), c("name", "age")) }) test_that("select with column", { From 336f7f5373e5f6960ecd9967d3703c8507e329ec Mon Sep 17 00:00:00 2001 From: Cheolsoo Park Date: Thu, 23 Apr 2015 20:10:55 -0400 Subject: [PATCH 053/110] [SPARK-7037] [CORE] Inconsistent behavior for non-spark config properties in spark-shell and spark-submit When specifying non-spark properties (i.e. names don't start with spark.) in the command line and config file, spark-submit and spark-shell behave differently, causing confusion to users. Here is the summary- * spark-submit * --conf k=v => silently ignored * spark-defaults.conf => applied * spark-shell * --conf k=v => show a warning message and ignored * spark-defaults.conf => show a warning message and ignored I assume that ignoring non-spark properties is intentional. If so, it should always be ignored with a warning message in all cases. 
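The unified behavior can be sketched as a small standalone Scala snippet (illustrative names only, not the real spark-submit code; the actual change lives in SparkSubmitArguments.ignoreNonSparkProperties in the diff below):

object NonSparkPropsSketch {
  // Keep only "spark."-prefixed keys, warning about each dropped key, regardless of
  // whether it came from --conf on the command line or from spark-defaults.conf.
  def keepSparkProps(props: Map[String, String]): Map[String, String] =
    props.filter { case (k, v) =>
      val isSpark = k.startsWith("spark.")
      if (!isSpark) {
        Console.err.println(s"Warning: Ignoring non-spark config property: $k=$v")
      }
      isSpark
    }

  def main(args: Array[String]): Unit = {
    val merged = Map("spark.master" -> "local[2]", "hive.metastore.uris" -> "thrift://example:9083")
    println(keepSparkProps(merged))  // Map(spark.master -> local[2]), with a warning for the hive key
  }
}
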
Author: Cheolsoo Park Closes #5617 from piaozhexiu/SPARK-7037 and squashes the following commits: 8957950 [Cheolsoo Park] Add IgnoreNonSparkProperties method fedd01c [Cheolsoo Park] Ignore non-spark properties with a warning message in all cases --- .../spark/deploy/SparkSubmitArguments.scala | 22 ++++++++++++++----- 1 file changed, 16 insertions(+), 6 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala index faa8780288ea3..c896842943f2b 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala @@ -77,12 +77,8 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S if (verbose) SparkSubmit.printStream.println(s"Using properties file: $propertiesFile") Option(propertiesFile).foreach { filename => Utils.getPropertiesFromFile(filename).foreach { case (k, v) => - if (k.startsWith("spark.")) { - defaultProperties(k) = v - if (verbose) SparkSubmit.printStream.println(s"Adding default property: $k=$v") - } else { - SparkSubmit.printWarning(s"Ignoring non-spark config property: $k=$v") - } + defaultProperties(k) = v + if (verbose) SparkSubmit.printStream.println(s"Adding default property: $k=$v") } } defaultProperties @@ -97,6 +93,8 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S } // Populate `sparkProperties` map from properties file mergeDefaultSparkProperties() + // Remove keys that don't start with "spark." from `sparkProperties`. + ignoreNonSparkProperties() // Use `sparkProperties` map along with env vars to fill in any missing parameters loadEnvironmentArguments() @@ -117,6 +115,18 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S } } + /** + * Remove keys that don't start with "spark." from `sparkProperties`. + */ + private def ignoreNonSparkProperties(): Unit = { + sparkProperties.foreach { case (k, v) => + if (!k.startsWith("spark.")) { + sparkProperties -= k + SparkSubmit.printWarning(s"Ignoring non-spark config property: $k=$v") + } + } + } + /** * Load arguments from environment variables, Spark properties etc. */ From 2d010f7afe6ac8e67e07da6bea700e9e8c9e6cc2 Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Thu, 23 Apr 2015 18:52:55 -0700 Subject: [PATCH 054/110] [SPARK-7060][SQL] Add alias function to python dataframe This pr tries to provide a way to let python users workaround https://issues.apache.org/jira/browse/SPARK-6231. Author: Yin Huai Closes #5634 from yhuai/pythonDFAlias and squashes the following commits: 8465acd [Yin Huai] Add an alias to a Python DF. --- python/pyspark/sql/dataframe.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index c8c30ce4022c8..4759f5fe783ad 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -452,6 +452,20 @@ def columns(self): """ return [f.name for f in self.schema.fields] + @ignore_unicode_prefix + def alias(self, alias): + """Returns a new :class:`DataFrame` with an alias set. 
+ + >>> from pyspark.sql.functions import * + >>> df_as1 = df.alias("df_as1") + >>> df_as2 = df.alias("df_as2") + >>> joined_df = df_as1.join(df_as2, col("df_as1.name") == col("df_as2.name"), 'inner') + >>> joined_df.select(col("df_as1.name"), col("df_as2.name"), col("df_as2.age")).collect() + [Row(name=u'Alice', name=u'Alice', age=2), Row(name=u'Bob', name=u'Bob', age=5)] + """ + assert isinstance(alias, basestring), "alias should be a string" + return DataFrame(getattr(self._jdf, "as")(alias), self.sql_ctx) + @ignore_unicode_prefix def join(self, other, joinExprs=None, joinType=None): """Joins with another :class:`DataFrame`, using the given join expression. From 67bccbda1e3ed7db2753daa7e6ae8b1441356177 Mon Sep 17 00:00:00 2001 From: Ken Geis Date: Thu, 23 Apr 2015 20:45:33 -0700 Subject: [PATCH 055/110] Update sql-programming-guide.md fix typo Author: Ken Geis Closes #5674 from kgeis/patch-1 and squashes the following commits: 5ae67de [Ken Geis] Update sql-programming-guide.md --- docs/sql-programming-guide.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index b2022546268a7..49b1e69f0e9db 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1364,7 +1364,7 @@ the Data Sources API. The following options are supported: driver - The class name of the JDBC driver needed to connect to this URL. This class with be loaded + The class name of the JDBC driver needed to connect to this URL. This class will be loaded on the master and workers before running an JDBC commands to allow the driver to register itself with the JDBC subsystem. From d3a302defc45768492dec9da4c40d78d28997a65 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Thu, 23 Apr 2015 21:21:03 -0700 Subject: [PATCH 056/110] [SQL] Fixed expression data type matching. Also took the chance to improve documentation for various types. Author: Reynold Xin Closes #5675 from rxin/data-type-matching-expr and squashes the following commits: 0f31856 [Reynold Xin] One more function documentation. 27c1973 [Reynold Xin] Added more documentation. 336a36d [Reynold Xin] [SQL] Fixed expression data type matching. 
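The old codegen pattern `case EqualTo(e1: BinaryType, e2: BinaryType)` asks whether the expression objects themselves are instances of `BinaryType`, rather than whether they evaluate to binary values, so that branch never fires; the fix matches through the `unapply` defined on `DataType`, which compares the expression's `dataType` instead. A self-contained sketch of the extractor idea, using toy stand-ins rather than catalyst's real `Expression`/`DataType` hierarchy:

```scala
// Toy stand-ins for catalyst's DataType / Expression hierarchy.
sealed trait ToyDataType {
  // Matches any expression whose dataType is this type (as in the patch).
  def unapply(e: ToyExpr): Boolean = e.dataType == this
}
case object BinaryType extends ToyDataType
case object StringType extends ToyDataType

case class ToyExpr(name: String, dataType: ToyDataType)
case class EqualTo(left: ToyExpr, right: ToyExpr)

object DataTypeMatchingSketch {
  def main(args: Array[String]): Unit = {
    val cmp = EqualTo(ToyExpr("a", BinaryType), ToyExpr("b", BinaryType))

    cmp match {
      // BinaryType() delegates to the unapply above, i.e. checks e1.dataType.
      case EqualTo(e1 @ BinaryType(), e2 @ BinaryType()) =>
        println(s"binary comparison of ${e1.name} and ${e2.name}")
      case _ =>
        println("not a binary comparison")
    }
  }
}
```
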
--- .../expressions/codegen/CodeGenerator.scala | 2 +- .../org/apache/spark/sql/types/DataType.scala | 50 +++++++++++++++---- 2 files changed, 42 insertions(+), 10 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index cbe520347385d..dbc92fb93e95e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -279,7 +279,7 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin org.apache.spark.sql.types.UTF8String(${eval.primitiveTerm}.toString) """.children - case EqualTo(e1: BinaryType, e2: BinaryType) => + case EqualTo(e1 @ BinaryType(), e2 @ BinaryType()) => (e1, e2).evaluateAs (BooleanType) { case (eval1, eval2) => q""" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala index e6bfcd9adfeb1..06bff7d70edbc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala @@ -40,32 +40,46 @@ import org.apache.spark.util.Utils */ @DeveloperApi abstract class DataType { - /** Matches any expression that evaluates to this DataType */ - def unapply(a: Expression): Boolean = a match { + /** + * Enables matching against NumericType for expressions: + * {{{ + * case Cast(child @ BinaryType(), StringType) => + * ... + * }}} + */ + private[sql] def unapply(a: Expression): Boolean = a match { case e: Expression if e.dataType == this => true case _ => false } - /** The default size of a value of this data type. */ + /** + * The default size of a value of this data type, used internally for size estimation. + */ def defaultSize: Int + /** Name of the type used in JSON serialization. */ def typeName: String = this.getClass.getSimpleName.stripSuffix("$").dropRight(4).toLowerCase private[sql] def jsonValue: JValue = typeName + /** The compact JSON representation of this data type. */ def json: String = compact(render(jsonValue)) + /** The pretty (i.e. indented) JSON representation of this data type. */ def prettyJson: String = pretty(render(jsonValue)) + /** Readable string representation for the type. */ def simpleString: String = typeName - /** Check if `this` and `other` are the same data type when ignoring nullability - * (`StructField.nullable`, `ArrayType.containsNull`, and `MapType.valueContainsNull`). + /** + * Check if `this` and `other` are the same data type when ignoring nullability + * (`StructField.nullable`, `ArrayType.containsNull`, and `MapType.valueContainsNull`). */ private[spark] def sameType(other: DataType): Boolean = DataType.equalsIgnoreNullability(this, other) - /** Returns the same data type but set all nullability fields are true + /** + * Returns the same data type but set all nullability fields are true * (`StructField.nullable`, `ArrayType.containsNull`, and `MapType.valueContainsNull`). */ private[spark] def asNullable: DataType @@ -104,12 +118,25 @@ abstract class NumericType extends AtomicType { private[sql] object NumericType { + /** + * Enables matching against NumericType for expressions: + * {{{ + * case Cast(child @ NumericType(), StringType) => + * ... 
+ * }}} + */ def unapply(e: Expression): Boolean = e.dataType.isInstanceOf[NumericType] } -/** Matcher for any expressions that evaluate to [[IntegralType]]s */ private[sql] object IntegralType { + /** + * Enables matching against IntegralType for expressions: + * {{{ + * case Cast(child @ IntegralType(), StringType) => + * ... + * }}} + */ def unapply(a: Expression): Boolean = a match { case e: Expression if e.dataType.isInstanceOf[IntegralType] => true case _ => false @@ -122,9 +149,14 @@ private[sql] abstract class IntegralType extends NumericType { } - -/** Matcher for any expressions that evaluate to [[FractionalType]]s */ private[sql] object FractionalType { + /** + * Enables matching against FractionalType for expressions: + * {{{ + * case Cast(child @ FractionalType(), StringType) => + * ... + * }}} + */ def unapply(a: Expression): Boolean = a match { case e: Expression if e.dataType.isInstanceOf[FractionalType] => true case _ => false From 4c722d77ae7e77eeaa7531687fa9bd6050344d18 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Thu, 23 Apr 2015 22:39:00 -0700 Subject: [PATCH 057/110] Fixed a typo from the previous commit. --- .../src/main/scala/org/apache/spark/sql/types/DataType.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala index 06bff7d70edbc..0992a7c311ee2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala @@ -41,7 +41,7 @@ import org.apache.spark.util.Utils @DeveloperApi abstract class DataType { /** - * Enables matching against NumericType for expressions: + * Enables matching against DataType for expressions: * {{{ * case Cast(child @ BinaryType(), StringType) => * ... From 8509519d8bcf99e2d1b5e21da514d51357f9116d Mon Sep 17 00:00:00 2001 From: Xusen Yin Date: Fri, 24 Apr 2015 00:39:29 -0700 Subject: [PATCH 058/110] [SPARK-5894] [ML] Add polynomial mapper See [SPARK-5894](https://issues.apache.org/jira/browse/SPARK-5894). 
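PolynomialExpansion is added below as a `UnaryTransformer` that maps a feature-vector column to the vector of its monomials up to the configured degree. A usage sketch for `spark-shell`, assuming an existing `SparkContext` named `sc` and modeled on the new test suite rather than on any published example:

```scala
import org.apache.spark.ml.feature.PolynomialExpansion
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.sql.SQLContext

val sqlContext = new SQLContext(sc)
val df = sqlContext.createDataFrame(Seq(
  (0, Vectors.dense(-2.0, 2.3)),
  (1, Vectors.dense(0.6, -1.1, -3.0))
)).toDF("id", "features")

val polyExpansion = new PolynomialExpansion()
  .setInputCol("features")
  .setOutputCol("polyFeatures")
  .setDegree(3)

// Each row's polyFeatures column holds the degree-3 expansion of its features vector.
polyExpansion.transform(df).select("features", "polyFeatures").show()
```
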
Author: Xusen Yin Author: Xiangrui Meng Closes #5245 from yinxusen/SPARK-5894 and squashes the following commits: dc461a6 [Xusen Yin] merge polynomial expansion v2 6d0c3cc [Xusen Yin] Merge branch 'SPARK-5894' of https://github.com/mengxr/spark into mengxr-SPARK-5894 57bfdd5 [Xusen Yin] Merge branch 'master' into SPARK-5894 3d02a7d [Xusen Yin] Merge branch 'master' into SPARK-5894 a067da2 [Xiangrui Meng] a new approach for poly expansion 0789d81 [Xiangrui Meng] Merge remote-tracking branch 'apache/master' into SPARK-5894 4e9aed0 [Xusen Yin] fix test suite 95d8fb9 [Xusen Yin] fix sparse vector indices 8d39674 [Xusen Yin] fix sparse vector expansion error 5998dd6 [Xusen Yin] fix dense vector fillin fa3ade3 [Xusen Yin] change the functional code into imperative one to speedup b70e7e1 [Xusen Yin] remove useless case class 6fa236f [Xusen Yin] fix vector slice error daff601 [Xusen Yin] fix index error of sparse vector 6bd0a10 [Xusen Yin] merge repeated features 419f8a2 [Xusen Yin] need to merge same columns 4ebf34e [Xusen Yin] add test suite of polynomial expansion 372227c [Xusen Yin] add polynomial expansion --- .../ml/feature/PolynomialExpansion.scala | 167 ++++++++++++++++++ .../ml/feature/PolynomialExpansionSuite.scala | 104 +++++++++++ 2 files changed, 271 insertions(+) create mode 100644 mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala create mode 100644 mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala new file mode 100644 index 0000000000000..c3a59a361d0e2 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala @@ -0,0 +1,167 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import scala.collection.mutable + +import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.ml.UnaryTransformer +import org.apache.spark.ml.param.{IntParam, ParamMap} +import org.apache.spark.mllib.linalg._ +import org.apache.spark.sql.types.DataType + +/** + * :: AlphaComponent :: + * Perform feature expansion in a polynomial space. As said in wikipedia of Polynomial Expansion, + * which is available at [[http://en.wikipedia.org/wiki/Polynomial_expansion]], "In mathematics, an + * expansion of a product of sums expresses it as a sum of products by using the fact that + * multiplication distributes over addition". Take a 2-variable feature vector as an example: + * `(x, y)`, if we want to expand it with degree 2, then we get `(x, y, x * x, x * y, y * y)`. 
+ */ +@AlphaComponent +class PolynomialExpansion extends UnaryTransformer[Vector, Vector, PolynomialExpansion] { + + /** + * The polynomial degree to expand, which should be larger than 1. + * @group param + */ + val degree = new IntParam(this, "degree", "the polynomial degree to expand") + setDefault(degree -> 2) + + /** @group getParam */ + def getDegree: Int = getOrDefault(degree) + + /** @group setParam */ + def setDegree(value: Int): this.type = set(degree, value) + + override protected def createTransformFunc(paramMap: ParamMap): Vector => Vector = { v => + val d = paramMap(degree) + PolynomialExpansion.expand(v, d) + } + + override protected def outputDataType: DataType = new VectorUDT() +} + +/** + * The expansion is done via recursion. Given n features and degree d, the size after expansion is + * (n + d choose d) (including 1 and first-order values). For example, let f([a, b, c], 3) be the + * function that expands [a, b, c] to their monomials of degree 3. We have the following recursion: + * + * {{{ + * f([a, b, c], 3) = f([a, b], 3) ++ f([a, b], 2) * c ++ f([a, b], 1) * c^2 ++ [c^3] + * }}} + * + * To handle sparsity, if c is zero, we can skip all monomials that contain it. We remember the + * current index and increment it properly for sparse input. + */ +object PolynomialExpansion { + + private def choose(n: Int, k: Int): Int = { + Range(n, n - k, -1).product / Range(k, 1, -1).product + } + + private def getPolySize(numFeatures: Int, degree: Int): Int = choose(numFeatures + degree, degree) + + private def expandDense( + values: Array[Double], + lastIdx: Int, + degree: Int, + multiplier: Double, + polyValues: Array[Double], + curPolyIdx: Int): Int = { + if (multiplier == 0.0) { + // do nothing + } else if (degree == 0 || lastIdx < 0) { + polyValues(curPolyIdx) = multiplier + } else { + val v = values(lastIdx) + val lastIdx1 = lastIdx - 1 + var alpha = multiplier + var i = 0 + var curStart = curPolyIdx + while (i <= degree && alpha != 0.0) { + curStart = expandDense(values, lastIdx1, degree - i, alpha, polyValues, curStart) + i += 1 + alpha *= v + } + } + curPolyIdx + getPolySize(lastIdx + 1, degree) + } + + private def expandSparse( + indices: Array[Int], + values: Array[Double], + lastIdx: Int, + lastFeatureIdx: Int, + degree: Int, + multiplier: Double, + polyIndices: mutable.ArrayBuilder[Int], + polyValues: mutable.ArrayBuilder[Double], + curPolyIdx: Int): Int = { + if (multiplier == 0.0) { + // do nothing + } else if (degree == 0 || lastIdx < 0) { + polyIndices += curPolyIdx + polyValues += multiplier + } else { + // Skip all zeros at the tail. 
+ val v = values(lastIdx) + val lastIdx1 = lastIdx - 1 + val lastFeatureIdx1 = indices(lastIdx) - 1 + var alpha = multiplier + var curStart = curPolyIdx + var i = 0 + while (i <= degree && alpha != 0.0) { + curStart = expandSparse(indices, values, lastIdx1, lastFeatureIdx1, degree - i, alpha, + polyIndices, polyValues, curStart) + i += 1 + alpha *= v + } + } + curPolyIdx + getPolySize(lastFeatureIdx + 1, degree) + } + + private def expand(dv: DenseVector, degree: Int): DenseVector = { + val n = dv.size + val polySize = getPolySize(n, degree) + val polyValues = new Array[Double](polySize) + expandDense(dv.values, n - 1, degree, 1.0, polyValues, 0) + new DenseVector(polyValues) + } + + private def expand(sv: SparseVector, degree: Int): SparseVector = { + val polySize = getPolySize(sv.size, degree) + val nnz = sv.values.length + val nnzPolySize = getPolySize(nnz, degree) + val polyIndices = mutable.ArrayBuilder.make[Int] + polyIndices.sizeHint(nnzPolySize) + val polyValues = mutable.ArrayBuilder.make[Double] + polyValues.sizeHint(nnzPolySize) + expandSparse( + sv.indices, sv.values, nnz - 1, sv.size - 1, degree, 1.0, polyIndices, polyValues, 0) + new SparseVector(polySize, polyIndices.result(), polyValues.result()) + } + + def expand(v: Vector, degree: Int): Vector = { + v match { + case dv: DenseVector => expand(dv, degree) + case sv: SparseVector => expand(sv, degree) + case _ => throw new IllegalArgumentException + } + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala new file mode 100644 index 0000000000000..b0a537be42dfd --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala @@ -0,0 +1,104 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.ml.feature + +import org.scalatest.FunSuite + +import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors} +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.mllib.util.TestingUtils._ +import org.apache.spark.sql.{Row, SQLContext} +import org.scalatest.exceptions.TestFailedException + +class PolynomialExpansionSuite extends FunSuite with MLlibTestSparkContext { + + @transient var sqlContext: SQLContext = _ + + override def beforeAll(): Unit = { + super.beforeAll() + sqlContext = new SQLContext(sc) + } + + test("Polynomial expansion with default parameter") { + val data = Array( + Vectors.sparse(3, Seq((0, -2.0), (1, 2.3))), + Vectors.dense(-2.0, 2.3), + Vectors.dense(0.0, 0.0, 0.0), + Vectors.dense(0.6, -1.1, -3.0), + Vectors.sparse(3, Seq()) + ) + + val twoDegreeExpansion: Array[Vector] = Array( + Vectors.sparse(10, Array(0, 1, 2, 3, 4, 5), Array(1.0, -2.0, 4.0, 2.3, -4.6, 5.29)), + Vectors.dense(1.0, -2.0, 4.0, 2.3, -4.6, 5.29), + Vectors.dense(Array(1.0) ++ Array.fill[Double](9)(0.0)), + Vectors.dense(1.0, 0.6, 0.36, -1.1, -0.66, 1.21, -3.0, -1.8, 3.3, 9.0), + Vectors.sparse(10, Array(0), Array(1.0))) + + val df = sqlContext.createDataFrame(data.zip(twoDegreeExpansion)).toDF("features", "expected") + + val polynomialExpansion = new PolynomialExpansion() + .setInputCol("features") + .setOutputCol("polyFeatures") + + polynomialExpansion.transform(df).select("polyFeatures", "expected").collect().foreach { + case Row(expanded: DenseVector, expected: DenseVector) => + assert(expanded ~== expected absTol 1e-1) + case Row(expanded: SparseVector, expected: SparseVector) => + assert(expanded ~== expected absTol 1e-1) + case _ => + throw new TestFailedException("Unmatched data types after polynomial expansion", 0) + } + } + + test("Polynomial expansion with setter") { + val data = Array( + Vectors.sparse(3, Seq((0, -2.0), (1, 2.3))), + Vectors.dense(-2.0, 2.3), + Vectors.dense(0.0, 0.0, 0.0), + Vectors.dense(0.6, -1.1, -3.0), + Vectors.sparse(3, Seq()) + ) + + val threeDegreeExpansion: Array[Vector] = Array( + Vectors.sparse(20, Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9), + Array(1.0, -2.0, 4.0, -8.0, 2.3, -4.6, 9.2, 5.29, -10.58, 12.17)), + Vectors.dense(1.0, -2.0, 4.0, -8.0, 2.3, -4.6, 9.2, 5.29, -10.58, 12.17), + Vectors.dense(Array(1.0) ++ Array.fill[Double](19)(0.0)), + Vectors.dense(1.0, 0.6, 0.36, 0.216, -1.1, -0.66, -0.396, 1.21, 0.726, -1.331, -3.0, -1.8, + -1.08, 3.3, 1.98, -3.63, 9.0, 5.4, -9.9, -27.0), + Vectors.sparse(20, Array(0), Array(1.0))) + + val df = sqlContext.createDataFrame(data.zip(threeDegreeExpansion)).toDF("features", "expected") + + val polynomialExpansion = new PolynomialExpansion() + .setInputCol("features") + .setOutputCol("polyFeatures") + .setDegree(3) + + polynomialExpansion.transform(df).select("polyFeatures", "expected").collect().foreach { + case Row(expanded: DenseVector, expected: DenseVector) => + assert(expanded ~== expected absTol 1e-1) + case Row(expanded: SparseVector, expected: SparseVector) => + assert(expanded ~== expected absTol 1e-1) + case _ => + throw new TestFailedException("Unmatched data types after polynomial expansion", 0) + } + } +} + From 78b39c7e0de8c9dc748cfbf8f78578a9524b6a94 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Fri, 24 Apr 2015 08:27:48 -0700 Subject: [PATCH 059/110] [SPARK-7115] [MLLIB] skip the very first 1 in poly expansion yinxusen Author: Xiangrui Meng Closes #5681 from mengxr/SPARK-7115 and squashes the following commits: 9ac27cd 
[Xiangrui Meng] skip the very first 1 in poly expansion --- .../ml/feature/PolynomialExpansion.scala | 22 +++++++++++-------- .../ml/feature/PolynomialExpansionSuite.scala | 22 +++++++++---------- 2 files changed, 24 insertions(+), 20 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala index c3a59a361d0e2..d855f04799ae7 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala @@ -87,7 +87,9 @@ object PolynomialExpansion { if (multiplier == 0.0) { // do nothing } else if (degree == 0 || lastIdx < 0) { - polyValues(curPolyIdx) = multiplier + if (curPolyIdx >= 0) { // skip the very first 1 + polyValues(curPolyIdx) = multiplier + } } else { val v = values(lastIdx) val lastIdx1 = lastIdx - 1 @@ -116,8 +118,10 @@ object PolynomialExpansion { if (multiplier == 0.0) { // do nothing } else if (degree == 0 || lastIdx < 0) { - polyIndices += curPolyIdx - polyValues += multiplier + if (curPolyIdx >= 0) { // skip the very first 1 + polyIndices += curPolyIdx + polyValues += multiplier + } } else { // Skip all zeros at the tail. val v = values(lastIdx) @@ -139,8 +143,8 @@ object PolynomialExpansion { private def expand(dv: DenseVector, degree: Int): DenseVector = { val n = dv.size val polySize = getPolySize(n, degree) - val polyValues = new Array[Double](polySize) - expandDense(dv.values, n - 1, degree, 1.0, polyValues, 0) + val polyValues = new Array[Double](polySize - 1) + expandDense(dv.values, n - 1, degree, 1.0, polyValues, -1) new DenseVector(polyValues) } @@ -149,12 +153,12 @@ object PolynomialExpansion { val nnz = sv.values.length val nnzPolySize = getPolySize(nnz, degree) val polyIndices = mutable.ArrayBuilder.make[Int] - polyIndices.sizeHint(nnzPolySize) + polyIndices.sizeHint(nnzPolySize - 1) val polyValues = mutable.ArrayBuilder.make[Double] - polyValues.sizeHint(nnzPolySize) + polyValues.sizeHint(nnzPolySize - 1) expandSparse( - sv.indices, sv.values, nnz - 1, sv.size - 1, degree, 1.0, polyIndices, polyValues, 0) - new SparseVector(polySize, polyIndices.result(), polyValues.result()) + sv.indices, sv.values, nnz - 1, sv.size - 1, degree, 1.0, polyIndices, polyValues, -1) + new SparseVector(polySize - 1, polyIndices.result(), polyValues.result()) } def expand(v: Vector, degree: Int): Vector = { diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala index b0a537be42dfd..c1d64fba0aa8f 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala @@ -44,11 +44,11 @@ class PolynomialExpansionSuite extends FunSuite with MLlibTestSparkContext { ) val twoDegreeExpansion: Array[Vector] = Array( - Vectors.sparse(10, Array(0, 1, 2, 3, 4, 5), Array(1.0, -2.0, 4.0, 2.3, -4.6, 5.29)), - Vectors.dense(1.0, -2.0, 4.0, 2.3, -4.6, 5.29), - Vectors.dense(Array(1.0) ++ Array.fill[Double](9)(0.0)), - Vectors.dense(1.0, 0.6, 0.36, -1.1, -0.66, 1.21, -3.0, -1.8, 3.3, 9.0), - Vectors.sparse(10, Array(0), Array(1.0))) + Vectors.sparse(9, Array(0, 1, 2, 3, 4), Array(-2.0, 4.0, 2.3, -4.6, 5.29)), + Vectors.dense(-2.0, 4.0, 2.3, -4.6, 5.29), + Vectors.dense(new Array[Double](9)), + Vectors.dense(0.6, 0.36, -1.1, -0.66, 1.21, -3.0, -1.8, 3.3, 9.0), + 
Vectors.sparse(9, Array.empty, Array.empty)) val df = sqlContext.createDataFrame(data.zip(twoDegreeExpansion)).toDF("features", "expected") @@ -76,13 +76,13 @@ class PolynomialExpansionSuite extends FunSuite with MLlibTestSparkContext { ) val threeDegreeExpansion: Array[Vector] = Array( - Vectors.sparse(20, Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9), - Array(1.0, -2.0, 4.0, -8.0, 2.3, -4.6, 9.2, 5.29, -10.58, 12.17)), - Vectors.dense(1.0, -2.0, 4.0, -8.0, 2.3, -4.6, 9.2, 5.29, -10.58, 12.17), - Vectors.dense(Array(1.0) ++ Array.fill[Double](19)(0.0)), - Vectors.dense(1.0, 0.6, 0.36, 0.216, -1.1, -0.66, -0.396, 1.21, 0.726, -1.331, -3.0, -1.8, + Vectors.sparse(19, Array(0, 1, 2, 3, 4, 5, 6, 7, 8), + Array(-2.0, 4.0, -8.0, 2.3, -4.6, 9.2, 5.29, -10.58, 12.17)), + Vectors.dense(-2.0, 4.0, -8.0, 2.3, -4.6, 9.2, 5.29, -10.58, 12.17), + Vectors.dense(new Array[Double](19)), + Vectors.dense(0.6, 0.36, 0.216, -1.1, -0.66, -0.396, 1.21, 0.726, -1.331, -3.0, -1.8, -1.08, 3.3, 1.98, -3.63, 9.0, 5.4, -9.9, -27.0), - Vectors.sparse(20, Array(0), Array(1.0))) + Vectors.sparse(19, Array.empty, Array.empty)) val df = sqlContext.createDataFrame(data.zip(threeDegreeExpansion)).toDF("features", "expected") From 6e57d57b32ba2aa0514692074897b5edd34e0dd6 Mon Sep 17 00:00:00 2001 From: Xusen Yin Date: Fri, 24 Apr 2015 08:29:49 -0700 Subject: [PATCH 060/110] [SPARK-6528] [ML] Add IDF transformer See [SPARK-6528](https://issues.apache.org/jira/browse/SPARK-6528). Add IDF transformer in ML package. Author: Xusen Yin Closes #5266 from yinxusen/SPARK-6528 and squashes the following commits: 741db31 [Xusen Yin] get param from new paramMap d169967 [Xusen Yin] add final to param and IDF class c9c3759 [Xusen Yin] simplify test suite 5867c09 [Xusen Yin] refine IDF transformer with new interfaces 7727cae [Xusen Yin] Merge branch 'master' into SPARK-6528 4338a37 [Xusen Yin] Merge branch 'master' into SPARK-6528 aef2cdf [Xusen Yin] add doc and group for param 5760b49 [Xusen Yin] fix code style 2add691 [Xusen Yin] fix code style and test 03fbecb [Xusen Yin] remove duplicated code 2aa4be0 [Xusen Yin] clean test suite 4802c67 [Xusen Yin] add IDF transformer and test suite --- .../org/apache/spark/ml/feature/IDF.scala | 116 ++++++++++++++++++ .../apache/spark/ml/feature/IDFSuite.scala | 101 +++++++++++++++ 2 files changed, 217 insertions(+) create mode 100644 mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala create mode 100644 mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala new file mode 100644 index 0000000000000..e6a62d998bb97 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala @@ -0,0 +1,116 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.ml._ +import org.apache.spark.ml.param._ +import org.apache.spark.ml.param.shared._ +import org.apache.spark.ml.util.SchemaUtils +import org.apache.spark.mllib.feature +import org.apache.spark.mllib.linalg.{Vector, VectorUDT} +import org.apache.spark.sql._ +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types.StructType + +/** + * Params for [[IDF]] and [[IDFModel]]. + */ +private[feature] trait IDFBase extends Params with HasInputCol with HasOutputCol { + + /** + * The minimum of documents in which a term should appear. + * @group param + */ + final val minDocFreq = new IntParam( + this, "minDocFreq", "minimum of documents in which a term should appear for filtering") + + setDefault(minDocFreq -> 0) + + /** @group getParam */ + def getMinDocFreq: Int = getOrDefault(minDocFreq) + + /** @group setParam */ + def setMinDocFreq(value: Int): this.type = set(minDocFreq, value) + + /** + * Validate and transform the input schema. + */ + protected def validateAndTransformSchema(schema: StructType, paramMap: ParamMap): StructType = { + val map = extractParamMap(paramMap) + SchemaUtils.checkColumnType(schema, map(inputCol), new VectorUDT) + SchemaUtils.appendColumn(schema, map(outputCol), new VectorUDT) + } +} + +/** + * :: AlphaComponent :: + * Compute the Inverse Document Frequency (IDF) given a collection of documents. + */ +@AlphaComponent +final class IDF extends Estimator[IDFModel] with IDFBase { + + /** @group setParam */ + def setInputCol(value: String): this.type = set(inputCol, value) + + /** @group setParam */ + def setOutputCol(value: String): this.type = set(outputCol, value) + + override def fit(dataset: DataFrame, paramMap: ParamMap): IDFModel = { + transformSchema(dataset.schema, paramMap, logging = true) + val map = extractParamMap(paramMap) + val input = dataset.select(map(inputCol)).map { case Row(v: Vector) => v } + val idf = new feature.IDF(map(minDocFreq)).fit(input) + val model = new IDFModel(this, map, idf) + Params.inheritValues(map, this, model) + model + } + + override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { + validateAndTransformSchema(schema, paramMap) + } +} + +/** + * :: AlphaComponent :: + * Model fitted by [[IDF]]. 
+ */ +@AlphaComponent +class IDFModel private[ml] ( + override val parent: IDF, + override val fittingParamMap: ParamMap, + idfModel: feature.IDFModel) + extends Model[IDFModel] with IDFBase { + + /** @group setParam */ + def setInputCol(value: String): this.type = set(inputCol, value) + + /** @group setParam */ + def setOutputCol(value: String): this.type = set(outputCol, value) + + override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = { + transformSchema(dataset.schema, paramMap, logging = true) + val map = extractParamMap(paramMap) + val idf = udf { vec: Vector => idfModel.transform(vec) } + dataset.withColumn(map(outputCol), idf(col(map(inputCol)))) + } + + override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { + validateAndTransformSchema(schema, paramMap) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala new file mode 100644 index 0000000000000..eaee3443c1f23 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala @@ -0,0 +1,101 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.ml.feature + +import org.scalatest.FunSuite + +import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors} +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.mllib.util.TestingUtils._ +import org.apache.spark.sql.{Row, SQLContext} + +class IDFSuite extends FunSuite with MLlibTestSparkContext { + + @transient var sqlContext: SQLContext = _ + + override def beforeAll(): Unit = { + super.beforeAll() + sqlContext = new SQLContext(sc) + } + + def scaleDataWithIDF(dataSet: Array[Vector], model: Vector): Array[Vector] = { + dataSet.map { + case data: DenseVector => + val res = data.toArray.zip(model.toArray).map { case (x, y) => x * y } + Vectors.dense(res) + case data: SparseVector => + val res = data.indices.zip(data.values).map { case (id, value) => + (id, value * model(id)) + } + Vectors.sparse(data.size, res) + } + } + + test("compute IDF with default parameter") { + val numOfFeatures = 4 + val data = Array( + Vectors.sparse(numOfFeatures, Array(1, 3), Array(1.0, 2.0)), + Vectors.dense(0.0, 1.0, 2.0, 3.0), + Vectors.sparse(numOfFeatures, Array(1), Array(1.0)) + ) + val numOfData = data.size + val idf = Vectors.dense(Array(0, 3, 1, 2).map { x => + math.log((numOfData + 1.0) / (x + 1.0)) + }) + val expected = scaleDataWithIDF(data, idf) + + val df = sqlContext.createDataFrame(data.zip(expected)).toDF("features", "expected") + + val idfModel = new IDF() + .setInputCol("features") + .setOutputCol("idfValue") + .fit(df) + + idfModel.transform(df).select("idfValue", "expected").collect().foreach { + case Row(x: Vector, y: Vector) => + assert(x ~== y absTol 1e-5, "Transformed vector is different with expected vector.") + } + } + + test("compute IDF with setter") { + val numOfFeatures = 4 + val data = Array( + Vectors.sparse(numOfFeatures, Array(1, 3), Array(1.0, 2.0)), + Vectors.dense(0.0, 1.0, 2.0, 3.0), + Vectors.sparse(numOfFeatures, Array(1), Array(1.0)) + ) + val numOfData = data.size + val idf = Vectors.dense(Array(0, 3, 1, 2).map { x => + if (x > 0) math.log((numOfData + 1.0) / (x + 1.0)) else 0 + }) + val expected = scaleDataWithIDF(data, idf) + + val df = sqlContext.createDataFrame(data.zip(expected)).toDF("features", "expected") + + val idfModel = new IDF() + .setInputCol("features") + .setOutputCol("idfValue") + .setMinDocFreq(1) + .fit(df) + + idfModel.transform(df).select("idfValue", "expected").collect().foreach { + case Row(x: Vector, y: Vector) => + assert(x ~== y absTol 1e-5, "Transformed vector is different with expected vector.") + } + } +} From ebb77b2aff085e71906b5de9d266ded89051af82 Mon Sep 17 00:00:00 2001 From: Sun Rui Date: Fri, 24 Apr 2015 11:00:19 -0700 Subject: [PATCH 061/110] [SPARK-7033] [SPARKR] Clean usage of split. Use partition instead where applicable. Author: Sun Rui Closes #5628 from sun-rui/SPARK-7033 and squashes the following commits: 046bc9e [Sun Rui] Clean split usage in tests. d531c86 [Sun Rui] [SPARK-7033][SPARKR] Clean usage of split. Use partition instead where applicable. 
--- R/pkg/R/RDD.R | 36 ++++++++++++++++++------------------ R/pkg/R/context.R | 20 ++++++++++---------- R/pkg/R/pairRDD.R | 8 ++++---- R/pkg/R/utils.R | 2 +- R/pkg/inst/tests/test_rdd.R | 12 ++++++------ 5 files changed, 39 insertions(+), 39 deletions(-) diff --git a/R/pkg/R/RDD.R b/R/pkg/R/RDD.R index 128431334ca52..cc09efb1e5418 100644 --- a/R/pkg/R/RDD.R +++ b/R/pkg/R/RDD.R @@ -91,8 +91,8 @@ setMethod("initialize", "PipelinedRDD", function(.Object, prev, func, jrdd_val) # NOTE: We use prev_serializedMode to track the serialization mode of prev_JRDD # prev_serializedMode is used during the delayed computation of JRDD in getJRDD } else { - pipelinedFunc <- function(split, iterator) { - func(split, prev@func(split, iterator)) + pipelinedFunc <- function(partIndex, part) { + func(partIndex, prev@func(partIndex, part)) } .Object@func <- cleanClosure(pipelinedFunc) .Object@prev_jrdd <- prev@prev_jrdd # maintain the pipeline @@ -306,7 +306,7 @@ setMethod("numPartitions", signature(x = "RDD"), function(x) { jrdd <- getJRDD(x) - partitions <- callJMethod(jrdd, "splits") + partitions <- callJMethod(jrdd, "partitions") callJMethod(partitions, "size") }) @@ -452,8 +452,8 @@ setMethod("countByValue", setMethod("lapply", signature(X = "RDD", FUN = "function"), function(X, FUN) { - func <- function(split, iterator) { - lapply(iterator, FUN) + func <- function(partIndex, part) { + lapply(part, FUN) } lapplyPartitionsWithIndex(X, func) }) @@ -538,8 +538,8 @@ setMethod("mapPartitions", #'\dontrun{ #' sc <- sparkR.init() #' rdd <- parallelize(sc, 1:10, 5L) -#' prod <- lapplyPartitionsWithIndex(rdd, function(split, part) { -#' split * Reduce("+", part) }) +#' prod <- lapplyPartitionsWithIndex(rdd, function(partIndex, part) { +#' partIndex * Reduce("+", part) }) #' collect(prod, flatten = FALSE) # 0, 7, 22, 45, 76 #'} #' @rdname lapplyPartitionsWithIndex @@ -813,7 +813,7 @@ setMethod("distinct", #' @examples #'\dontrun{ #' sc <- sparkR.init() -#' rdd <- parallelize(sc, 1:10) # ensure each num is in its own split +#' rdd <- parallelize(sc, 1:10) #' collect(sampleRDD(rdd, FALSE, 0.5, 1618L)) # ~5 distinct elements #' collect(sampleRDD(rdd, TRUE, 0.5, 9L)) # ~5 elements possibly with duplicates #'} @@ -825,14 +825,14 @@ setMethod("sampleRDD", function(x, withReplacement, fraction, seed) { # The sampler: takes a partition and returns its sampled version. - samplingFunc <- function(split, part) { + samplingFunc <- function(partIndex, part) { set.seed(seed) res <- vector("list", length(part)) len <- 0 # Discards some random values to ensure each partition has a # different random seed. - runif(split) + runif(partIndex) for (elem in part) { if (withReplacement) { @@ -989,8 +989,8 @@ setMethod("coalesce", function(x, numPartitions, shuffle = FALSE) { numPartitions <- numToInt(numPartitions) if (shuffle || numPartitions > SparkR::numPartitions(x)) { - func <- function(s, part) { - set.seed(s) # split as seed + func <- function(partIndex, part) { + set.seed(partIndex) # partIndex as seed start <- as.integer(sample(numPartitions, 1) - 1) lapply(seq_along(part), function(i) { @@ -1035,7 +1035,7 @@ setMethod("saveAsObjectFile", #' Save this RDD as a text file, using string representations of elements. 
#' #' @param x The RDD to save -#' @param path The directory where the splits of the text file are saved +#' @param path The directory where the partitions of the text file are saved #' @examples #'\dontrun{ #' sc <- sparkR.init() @@ -1335,10 +1335,10 @@ setMethod("zipWithUniqueId", function(x) { n <- numPartitions(x) - partitionFunc <- function(split, part) { + partitionFunc <- function(partIndex, part) { mapply( function(item, index) { - list(item, (index - 1) * n + split) + list(item, (index - 1) * n + partIndex) }, part, seq_along(part), @@ -1382,11 +1382,11 @@ setMethod("zipWithIndex", startIndices <- Reduce("+", nums, accumulate = TRUE) } - partitionFunc <- function(split, part) { - if (split == 0) { + partitionFunc <- function(partIndex, part) { + if (partIndex == 0) { startIndex <- 0 } else { - startIndex <- startIndices[[split]] + startIndex <- startIndices[[partIndex]] } mapply( diff --git a/R/pkg/R/context.R b/R/pkg/R/context.R index ebbb8fba1052d..b4845b6948997 100644 --- a/R/pkg/R/context.R +++ b/R/pkg/R/context.R @@ -17,12 +17,12 @@ # context.R: SparkContext driven functions -getMinSplits <- function(sc, minSplits) { - if (is.null(minSplits)) { +getMinPartitions <- function(sc, minPartitions) { + if (is.null(minPartitions)) { defaultParallelism <- callJMethod(sc, "defaultParallelism") - minSplits <- min(defaultParallelism, 2) + minPartitions <- min(defaultParallelism, 2) } - as.integer(minSplits) + as.integer(minPartitions) } #' Create an RDD from a text file. @@ -33,7 +33,7 @@ getMinSplits <- function(sc, minSplits) { #' #' @param sc SparkContext to use #' @param path Path of file to read. A vector of multiple paths is allowed. -#' @param minSplits Minimum number of splits to be created. If NULL, the default +#' @param minPartitions Minimum number of partitions to be created. If NULL, the default #' value is chosen based on available parallelism. #' @return RDD where each item is of type \code{character} #' @export @@ -42,13 +42,13 @@ getMinSplits <- function(sc, minSplits) { #' sc <- sparkR.init() #' lines <- textFile(sc, "myfile.txt") #'} -textFile <- function(sc, path, minSplits = NULL) { +textFile <- function(sc, path, minPartitions = NULL) { # Allow the user to have a more flexible definiton of the text file path path <- suppressWarnings(normalizePath(path)) #' Convert a string vector of paths to a string containing comma separated paths path <- paste(path, collapse = ",") - jrdd <- callJMethod(sc, "textFile", path, getMinSplits(sc, minSplits)) + jrdd <- callJMethod(sc, "textFile", path, getMinPartitions(sc, minPartitions)) # jrdd is of type JavaRDD[String] RDD(jrdd, "string") } @@ -60,7 +60,7 @@ textFile <- function(sc, path, minSplits = NULL) { #' #' @param sc SparkContext to use #' @param path Path of file to read. A vector of multiple paths is allowed. -#' @param minSplits Minimum number of splits to be created. If NULL, the default +#' @param minPartitions Minimum number of partitions to be created. If NULL, the default #' value is chosen based on available parallelism. #' @return RDD containing serialized R objects. 
#' @seealso saveAsObjectFile @@ -70,13 +70,13 @@ textFile <- function(sc, path, minSplits = NULL) { #' sc <- sparkR.init() #' rdd <- objectFile(sc, "myfile") #'} -objectFile <- function(sc, path, minSplits = NULL) { +objectFile <- function(sc, path, minPartitions = NULL) { # Allow the user to have a more flexible definiton of the text file path path <- suppressWarnings(normalizePath(path)) #' Convert a string vector of paths to a string containing comma separated paths path <- paste(path, collapse = ",") - jrdd <- callJMethod(sc, "objectFile", path, getMinSplits(sc, minSplits)) + jrdd <- callJMethod(sc, "objectFile", path, getMinPartitions(sc, minPartitions)) # Assume the RDD contains serialized R objects. RDD(jrdd, "byte") } diff --git a/R/pkg/R/pairRDD.R b/R/pkg/R/pairRDD.R index 13efebc11c46e..f99b474ff8f2a 100644 --- a/R/pkg/R/pairRDD.R +++ b/R/pkg/R/pairRDD.R @@ -206,8 +206,8 @@ setMethod("partitionBy", get(name, .broadcastNames) }) jrdd <- getJRDD(x) - # We create a PairwiseRRDD that extends RDD[(Array[Byte], - # Array[Byte])], where the key is the hashed split, the value is + # We create a PairwiseRRDD that extends RDD[(Int, Array[Byte])], + # where the key is the target partition number, the value is # the content (key-val pairs). pairwiseRRDD <- newJObject("org.apache.spark.api.r.PairwiseRRDD", callJMethod(jrdd, "rdd"), @@ -866,8 +866,8 @@ setMethod("sampleByKey", } # The sampler: takes a partition and returns its sampled version. - samplingFunc <- function(split, part) { - set.seed(bitwXor(seed, split)) + samplingFunc <- function(partIndex, part) { + set.seed(bitwXor(seed, partIndex)) res <- vector("list", length(part)) len <- 0 diff --git a/R/pkg/R/utils.R b/R/pkg/R/utils.R index 23305d3c67074..0e7b7bd5a5b34 100644 --- a/R/pkg/R/utils.R +++ b/R/pkg/R/utils.R @@ -501,7 +501,7 @@ appendPartitionLengths <- function(x, other) { # A result RDD. 
mergePartitions <- function(rdd, zip) { serializerMode <- getSerializedMode(rdd) - partitionFunc <- function(split, part) { + partitionFunc <- function(partIndex, part) { len <- length(part) if (len > 0) { if (serializerMode == "byte") { diff --git a/R/pkg/inst/tests/test_rdd.R b/R/pkg/inst/tests/test_rdd.R index 3ba7d1716302a..d55af93e3e50a 100644 --- a/R/pkg/inst/tests/test_rdd.R +++ b/R/pkg/inst/tests/test_rdd.R @@ -105,8 +105,8 @@ test_that("several transformations on RDD (a benchmark on PipelinedRDD)", { rdd2 <- rdd for (i in 1:12) rdd2 <- lapplyPartitionsWithIndex( - rdd2, function(split, part) { - part <- as.list(unlist(part) * split + i) + rdd2, function(partIndex, part) { + part <- as.list(unlist(part) * partIndex + i) }) rdd2 <- lapply(rdd2, function(x) x + x) actual <- collect(rdd2) @@ -121,8 +121,8 @@ test_that("PipelinedRDD support actions: cache(), persist(), unpersist(), checkp # PipelinedRDD rdd2 <- lapplyPartitionsWithIndex( rdd2, - function(split, part) { - part <- as.list(unlist(part) * split) + function(partIndex, part) { + part <- as.list(unlist(part) * partIndex) }) cache(rdd2) @@ -174,13 +174,13 @@ test_that("lapply with dependency", { }) test_that("lapplyPartitionsWithIndex on RDDs", { - func <- function(splitIndex, part) { list(splitIndex, Reduce("+", part)) } + func <- function(partIndex, part) { list(partIndex, Reduce("+", part)) } actual <- collect(lapplyPartitionsWithIndex(rdd, func), flatten = FALSE) expect_equal(actual, list(list(0, 15), list(1, 40))) pairsRDD <- parallelize(sc, list(list(1, 2), list(3, 4), list(4, 8)), 1L) partitionByParity <- function(key) { if (key %% 2 == 1) 0 else 1 } - mkTup <- function(splitIndex, part) { list(splitIndex, part) } + mkTup <- function(partIndex, part) { list(partIndex, part) } actual <- collect(lapplyPartitionsWithIndex( partitionBy(pairsRDD, 2L, partitionByParity), mkTup), From caf0136ec5838cf5bf61f39a5b3474a505a6ae11 Mon Sep 17 00:00:00 2001 From: Sun Rui Date: Fri, 24 Apr 2015 12:52:07 -0700 Subject: [PATCH 062/110] [SPARK-6852] [SPARKR] Accept numeric as numPartitions in SparkR. Author: Sun Rui Closes #5613 from sun-rui/SPARK-6852 and squashes the following commits: abaf02e [Sun Rui] Change the type of default numPartitions from integer to numeric in generics.R. 29d67c1 [Sun Rui] [SPARK-6852][SPARKR] Accept numeric as numPartitions in SparkR. --- R/pkg/R/RDD.R | 2 +- R/pkg/R/generics.R | 12 ++++++------ R/pkg/R/pairRDD.R | 24 ++++++++++++------------ 3 files changed, 19 insertions(+), 19 deletions(-) diff --git a/R/pkg/R/RDD.R b/R/pkg/R/RDD.R index cc09efb1e5418..1662d6bb3b1ac 100644 --- a/R/pkg/R/RDD.R +++ b/R/pkg/R/RDD.R @@ -967,7 +967,7 @@ setMethod("keyBy", setMethod("repartition", signature(x = "RDD", numPartitions = "numeric"), function(x, numPartitions) { - coalesce(x, numToInt(numPartitions), TRUE) + coalesce(x, numPartitions, TRUE) }) #' Return a new RDD that is reduced into numPartitions partitions. 
diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 6c6233390134c..34dbe84051c50 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -60,7 +60,7 @@ setGeneric("countByValue", function(x) { standardGeneric("countByValue") }) #' @rdname distinct #' @export -setGeneric("distinct", function(x, numPartitions = 1L) { standardGeneric("distinct") }) +setGeneric("distinct", function(x, numPartitions = 1) { standardGeneric("distinct") }) #' @rdname filterRDD #' @export @@ -182,7 +182,7 @@ setGeneric("setName", function(x, name) { standardGeneric("setName") }) #' @rdname sortBy #' @export setGeneric("sortBy", - function(x, func, ascending = TRUE, numPartitions = 1L) { + function(x, func, ascending = TRUE, numPartitions = 1) { standardGeneric("sortBy") }) @@ -244,7 +244,7 @@ setGeneric("flatMapValues", function(X, FUN) { standardGeneric("flatMapValues") #' @rdname intersection #' @export -setGeneric("intersection", function(x, other, numPartitions = 1L) { +setGeneric("intersection", function(x, other, numPartitions = 1) { standardGeneric("intersection") }) #' @rdname keys @@ -346,21 +346,21 @@ setGeneric("rightOuterJoin", function(x, y, numPartitions) { standardGeneric("ri #' @rdname sortByKey #' @export setGeneric("sortByKey", - function(x, ascending = TRUE, numPartitions = 1L) { + function(x, ascending = TRUE, numPartitions = 1) { standardGeneric("sortByKey") }) #' @rdname subtract #' @export setGeneric("subtract", - function(x, other, numPartitions = 1L) { + function(x, other, numPartitions = 1) { standardGeneric("subtract") }) #' @rdname subtractByKey #' @export setGeneric("subtractByKey", - function(x, other, numPartitions = 1L) { + function(x, other, numPartitions = 1) { standardGeneric("subtractByKey") }) diff --git a/R/pkg/R/pairRDD.R b/R/pkg/R/pairRDD.R index f99b474ff8f2a..9791e55791bae 100644 --- a/R/pkg/R/pairRDD.R +++ b/R/pkg/R/pairRDD.R @@ -190,7 +190,7 @@ setMethod("flatMapValues", #' @rdname partitionBy #' @aliases partitionBy,RDD,integer-method setMethod("partitionBy", - signature(x = "RDD", numPartitions = "integer"), + signature(x = "RDD", numPartitions = "numeric"), function(x, numPartitions, partitionFunc = hashCode) { #if (missing(partitionFunc)) { @@ -211,7 +211,7 @@ setMethod("partitionBy", # the content (key-val pairs). pairwiseRRDD <- newJObject("org.apache.spark.api.r.PairwiseRRDD", callJMethod(jrdd, "rdd"), - as.integer(numPartitions), + numToInt(numPartitions), serializedHashFuncBytes, getSerializedMode(x), packageNamesArr, @@ -221,7 +221,7 @@ setMethod("partitionBy", # Create a corresponding partitioner. rPartitioner <- newJObject("org.apache.spark.HashPartitioner", - as.integer(numPartitions)) + numToInt(numPartitions)) # Call partitionBy on the obtained PairwiseRDD. 
javaPairRDD <- callJMethod(pairwiseRRDD, "asJavaPairRDD") @@ -256,7 +256,7 @@ setMethod("partitionBy", #' @rdname groupByKey #' @aliases groupByKey,RDD,integer-method setMethod("groupByKey", - signature(x = "RDD", numPartitions = "integer"), + signature(x = "RDD", numPartitions = "numeric"), function(x, numPartitions) { shuffled <- partitionBy(x, numPartitions) groupVals <- function(part) { @@ -315,7 +315,7 @@ setMethod("groupByKey", #' @rdname reduceByKey #' @aliases reduceByKey,RDD,integer-method setMethod("reduceByKey", - signature(x = "RDD", combineFunc = "ANY", numPartitions = "integer"), + signature(x = "RDD", combineFunc = "ANY", numPartitions = "numeric"), function(x, combineFunc, numPartitions) { reduceVals <- function(part) { vals <- new.env() @@ -422,7 +422,7 @@ setMethod("reduceByKeyLocally", #' @aliases combineByKey,RDD,ANY,ANY,ANY,integer-method setMethod("combineByKey", signature(x = "RDD", createCombiner = "ANY", mergeValue = "ANY", - mergeCombiners = "ANY", numPartitions = "integer"), + mergeCombiners = "ANY", numPartitions = "numeric"), function(x, createCombiner, mergeValue, mergeCombiners, numPartitions) { combineLocally <- function(part) { combiners <- new.env() @@ -483,7 +483,7 @@ setMethod("combineByKey", #' @aliases aggregateByKey,RDD,ANY,ANY,ANY,integer-method setMethod("aggregateByKey", signature(x = "RDD", zeroValue = "ANY", seqOp = "ANY", - combOp = "ANY", numPartitions = "integer"), + combOp = "ANY", numPartitions = "numeric"), function(x, zeroValue, seqOp, combOp, numPartitions) { createCombiner <- function(v) { do.call(seqOp, list(zeroValue, v)) @@ -514,7 +514,7 @@ setMethod("aggregateByKey", #' @aliases foldByKey,RDD,ANY,ANY,integer-method setMethod("foldByKey", signature(x = "RDD", zeroValue = "ANY", - func = "ANY", numPartitions = "integer"), + func = "ANY", numPartitions = "numeric"), function(x, zeroValue, func, numPartitions) { aggregateByKey(x, zeroValue, func, func, numPartitions) }) @@ -553,7 +553,7 @@ setMethod("join", joinTaggedList(v, list(FALSE, FALSE)) } - joined <- flatMapValues(groupByKey(unionRDD(xTagged, yTagged), numToInt(numPartitions)), + joined <- flatMapValues(groupByKey(unionRDD(xTagged, yTagged), numPartitions), doJoin) }) @@ -582,7 +582,7 @@ setMethod("join", #' @rdname join-methods #' @aliases leftOuterJoin,RDD,RDD-method setMethod("leftOuterJoin", - signature(x = "RDD", y = "RDD", numPartitions = "integer"), + signature(x = "RDD", y = "RDD", numPartitions = "numeric"), function(x, y, numPartitions) { xTagged <- lapply(x, function(i) { list(i[[1]], list(1L, i[[2]])) }) yTagged <- lapply(y, function(i) { list(i[[1]], list(2L, i[[2]])) }) @@ -619,7 +619,7 @@ setMethod("leftOuterJoin", #' @rdname join-methods #' @aliases rightOuterJoin,RDD,RDD-method setMethod("rightOuterJoin", - signature(x = "RDD", y = "RDD", numPartitions = "integer"), + signature(x = "RDD", y = "RDD", numPartitions = "numeric"), function(x, y, numPartitions) { xTagged <- lapply(x, function(i) { list(i[[1]], list(1L, i[[2]])) }) yTagged <- lapply(y, function(i) { list(i[[1]], list(2L, i[[2]])) }) @@ -659,7 +659,7 @@ setMethod("rightOuterJoin", #' @rdname join-methods #' @aliases fullOuterJoin,RDD,RDD-method setMethod("fullOuterJoin", - signature(x = "RDD", y = "RDD", numPartitions = "integer"), + signature(x = "RDD", y = "RDD", numPartitions = "numeric"), function(x, y, numPartitions) { xTagged <- lapply(x, function(i) { list(i[[1]], list(1L, i[[2]])) }) yTagged <- lapply(y, function(i) { list(i[[1]], list(2L, i[[2]])) }) From 
438859eb7c4e605bb4041d9a547a16be9c827c75 Mon Sep 17 00:00:00 2001 From: Calvin Jia Date: Fri, 24 Apr 2015 17:57:41 -0400 Subject: [PATCH 063/110] [SPARK-6122] [CORE] Upgrade tachyon-client version to 0.6.3 This is a reopening of #4867. A short summary of the issues resolved from the previous PR: 1. HTTPClient version mismatch: Selenium (used for UI tests) requires version 4.3.x, and Tachyon included 4.2.5 through a transitive dependency of its shaded thrift jar. To address this, Tachyon 0.6.3 will promote the transitive dependencies of the shaded jar so they can be excluded in spark. 2. Jackson-Mapper-ASL version mismatch: In lower versions of hadoop-client (ie. 1.0.4), version 1.0.1 is included. The parquet library used in spark sql requires version 1.8+. Its unclear to me why upgrading tachyon-client would cause this dependency to break. The solution was to exclude jackson-mapper-asl from hadoop-client. It seems that the dependency management in spark-parent will not work on transitive dependencies, one way to make sure jackson-mapper-asl is included with the correct version is to add it as a top level dependency. The best solution would be to exclude the dependency in the modules which require a higher version, but that did not fix the unit tests. Any suggestions on the best way to solve this would be appreciated! Author: Calvin Jia Closes #5354 from calvinjia/upgrade_tachyon_0.6.3 and squashes the following commits: 0eefe4d [Calvin Jia] Handle httpclient version in maven dependency management. Remove httpclient version setting from profiles. 7c00dfa [Calvin Jia] Set httpclient version to 4.3.2 for selenium. Specify version of httpclient for sql/hive (previously 4.2.5 transitive dependency of libthrift). 9263097 [Calvin Jia] Merge master to test latest changes dbfc1bd [Calvin Jia] Use Tachyon 0.6.4 for cleaner dependencies. e2ff80a [Calvin Jia] Exclude the jetty and curator promoted dependencies from tachyon-client. a3a29da [Calvin Jia] Update tachyon-client exclusions. 0ae6c97 [Calvin Jia] Change tachyon version to 0.6.3 a204df9 [Calvin Jia] Update make distribution tachyon version. a93c94f [Calvin Jia] Exclude jackson-mapper-asl from hadoop client since it has a lower version than spark's expected version. a8a923c [Calvin Jia] Exclude httpcomponents from Tachyon 910fabd [Calvin Jia] Update to master eed9230 [Calvin Jia] Update tachyon version to 0.6.1. 11907b3 [Calvin Jia] Use TachyonURI for tachyon paths instead of strings. 71bf441 [Calvin Jia] Upgrade Tachyon client version to 0.6.0. 
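Besides the dependency cleanup, tachyon-client 0.6.x switches from plain string paths to `tachyon.TachyonURI`, which is what the changes to `TachyonBlockManager` and `Utils` below are about. A small sketch of the updated client calls, limited to the methods that appear in this diff; the master address and paths are placeholders:

```scala
import tachyon.TachyonURI
import tachyon.client.TachyonFS

// Placeholder master address; Spark derives the real one from its configuration.
val client = TachyonFS.get(new TachyonURI("tachyon://localhost:19998"))

val dirPath = "/tmp/spark-tachyon-example"
val dir = new TachyonURI(dirPath)
if (!client.exist(dir)) {
  client.mkdir(dir)
}

val file = new TachyonURI(s"$dirPath/block-0")
if (!client.exist(file)) {
  client.createFile(file)
}

// Recursive delete, as in Utils.deleteRecursively after this change.
client.delete(dir, true)
```
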
--- assembly/pom.xml | 10 ---------- core/pom.xml | 6 +++++- .../spark/storage/TachyonBlockManager.scala | 16 ++++++++-------- .../main/scala/org/apache/spark/util/Utils.scala | 4 +++- examples/pom.xml | 5 ----- launcher/pom.xml | 6 ++++++ make-distribution.sh | 2 +- pom.xml | 12 +++++++++++- sql/hive/pom.xml | 5 +++++ 9 files changed, 39 insertions(+), 27 deletions(-) diff --git a/assembly/pom.xml b/assembly/pom.xml index f1f8b0d3682e2..20593e710dedb 100644 --- a/assembly/pom.xml +++ b/assembly/pom.xml @@ -213,16 +213,6 @@ - - kinesis-asl - - - org.apache.httpcomponents - httpclient - ${commons.httpclient.version} - - - diff --git a/core/pom.xml b/core/pom.xml index e80829b7a7f3d..5e89d548cd47f 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -74,6 +74,10 @@ javax.servlet servlet-api + + org.codehaus.jackson + jackson-mapper-asl + @@ -275,7 +279,7 @@ org.tachyonproject tachyon-client - 0.5.0 + 0.6.4 org.apache.hadoop diff --git a/core/src/main/scala/org/apache/spark/storage/TachyonBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/TachyonBlockManager.scala index 951897cead996..583f1fdf0475b 100644 --- a/core/src/main/scala/org/apache/spark/storage/TachyonBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/TachyonBlockManager.scala @@ -20,8 +20,8 @@ package org.apache.spark.storage import java.text.SimpleDateFormat import java.util.{Date, Random} -import tachyon.client.TachyonFS -import tachyon.client.TachyonFile +import tachyon.TachyonURI +import tachyon.client.{TachyonFile, TachyonFS} import org.apache.spark.Logging import org.apache.spark.executor.ExecutorExitCode @@ -40,7 +40,7 @@ private[spark] class TachyonBlockManager( val master: String) extends Logging { - val client = if (master != null && master != "") TachyonFS.get(master) else null + val client = if (master != null && master != "") TachyonFS.get(new TachyonURI(master)) else null if (client == null) { logError("Failed to connect to the Tachyon as the master address is not configured") @@ -60,11 +60,11 @@ private[spark] class TachyonBlockManager( addShutdownHook() def removeFile(file: TachyonFile): Boolean = { - client.delete(file.getPath(), false) + client.delete(new TachyonURI(file.getPath()), false) } def fileExists(file: TachyonFile): Boolean = { - client.exist(file.getPath()) + client.exist(new TachyonURI(file.getPath())) } def getFile(filename: String): TachyonFile = { @@ -81,7 +81,7 @@ private[spark] class TachyonBlockManager( if (old != null) { old } else { - val path = tachyonDirs(dirId) + "/" + "%02x".format(subDirId) + val path = new TachyonURI(s"${tachyonDirs(dirId)}/${"%02x".format(subDirId)}") client.mkdir(path) val newDir = client.getFile(path) subDirs(dirId)(subDirId) = newDir @@ -89,7 +89,7 @@ private[spark] class TachyonBlockManager( } } } - val filePath = subDir + "/" + filename + val filePath = new TachyonURI(s"$subDir/$filename") if(!client.exist(filePath)) { client.createFile(filePath) } @@ -113,7 +113,7 @@ private[spark] class TachyonBlockManager( tries += 1 try { tachyonDirId = "%s-%04x".format(dateFormat.format(new Date), rand.nextInt(65536)) - val path = rootDir + "/" + "spark-tachyon-" + tachyonDirId + val path = new TachyonURI(s"$rootDir/spark-tachyon-$tachyonDirId") if (!client.exist(path)) { foundLocalDir = client.mkdir(path) tachyonDir = client.getFile(path) diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 2feb7341b159b..667aa168e7ef3 100644 --- 
a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -42,6 +42,8 @@ import org.apache.hadoop.security.UserGroupInformation import org.apache.log4j.PropertyConfigurator import org.eclipse.jetty.util.MultiException import org.json4s._ + +import tachyon.TachyonURI import tachyon.client.{TachyonFS, TachyonFile} import org.apache.spark._ @@ -955,7 +957,7 @@ private[spark] object Utils extends Logging { * Delete a file or directory and its contents recursively. */ def deleteRecursively(dir: TachyonFile, client: TachyonFS) { - if (!client.delete(dir.getPath(), true)) { + if (!client.delete(new TachyonURI(dir.getPath()), true)) { throw new IOException("Failed to delete the tachyon dir: " + dir) } } diff --git a/examples/pom.xml b/examples/pom.xml index afd7c6d52f0dd..df1717403b673 100644 --- a/examples/pom.xml +++ b/examples/pom.xml @@ -390,11 +390,6 @@ spark-streaming-kinesis-asl_${scala.binary.version} ${project.version} - - org.apache.httpcomponents - httpclient - ${commons.httpclient.version} - diff --git a/launcher/pom.xml b/launcher/pom.xml index 182e5f60218db..ebfa7685eaa18 100644 --- a/launcher/pom.xml +++ b/launcher/pom.xml @@ -68,6 +68,12 @@ org.apache.hadoop hadoop-client test + + + org.codehaus.jackson + jackson-mapper-asl + + diff --git a/make-distribution.sh b/make-distribution.sh index 738a9c4d69601..cb65932b4abc0 100755 --- a/make-distribution.sh +++ b/make-distribution.sh @@ -32,7 +32,7 @@ SPARK_HOME="$(cd "`dirname "$0"`"; pwd)" DISTDIR="$SPARK_HOME/dist" SPARK_TACHYON=false -TACHYON_VERSION="0.5.0" +TACHYON_VERSION="0.6.4" TACHYON_TGZ="tachyon-${TACHYON_VERSION}-bin.tar.gz" TACHYON_URL="https://github.com/amplab/tachyon/releases/download/v${TACHYON_VERSION}/${TACHYON_TGZ}" diff --git a/pom.xml b/pom.xml index bcc2f57f1af5d..4b0b0c85eff21 100644 --- a/pom.xml +++ b/pom.xml @@ -146,7 +146,7 @@ 0.7.1 1.8.3 1.1.0 - 4.2.6 + 4.3.2 3.4.1 ${project.build.directory}/spark-test-classpath.txt 2.10.4 @@ -420,6 +420,16 @@ jsr305 1.3.9 + + org.apache.httpcomponents + httpclient + ${commons.httpclient.version} + + + org.apache.httpcomponents + httpcore + ${commons.httpclient.version} + org.seleniumhq.selenium selenium-java diff --git a/sql/hive/pom.xml b/sql/hive/pom.xml index 04440076a26a3..21dce8d8a565a 100644 --- a/sql/hive/pom.xml +++ b/sql/hive/pom.xml @@ -59,6 +59,11 @@ ${hive.group} hive-exec + + org.apache.httpcomponents + httpclient + ${commons.httpclient.version} + org.codehaus.jackson jackson-mapper-asl From d874f8b546d8fae95bc92d8461b8189e51cb731b Mon Sep 17 00:00:00 2001 From: linweizhong Date: Fri, 24 Apr 2015 20:23:19 -0700 Subject: [PATCH 064/110] [PySpark][Minor] Update sql example, so that can read file correctly To run Spark, default will read file from HDFS if we don't set the schema. Author: linweizhong Closes #5684 from Sephiroth-Lin/pyspark_example_minor and squashes the following commits: 19fe145 [linweizhong] Update example sql.py, so that can read file correctly --- examples/src/main/python/sql.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/examples/src/main/python/sql.py b/examples/src/main/python/sql.py index 87d7b088f077b..2c188759328f2 100644 --- a/examples/src/main/python/sql.py +++ b/examples/src/main/python/sql.py @@ -18,6 +18,7 @@ from __future__ import print_function import os +import sys from pyspark import SparkContext from pyspark.sql import SQLContext @@ -50,7 +51,11 @@ # A JSON dataset is pointed to by path. 
# The path can be either a single text file or a directory storing text files. - path = os.path.join(os.environ['SPARK_HOME'], "examples/src/main/resources/people.json") + if len(sys.argv) < 2: + path = "file://" + \ + os.path.join(os.environ['SPARK_HOME'], "examples/src/main/resources/people.json") + else: + path = sys.argv[1] # Create a DataFrame from the file(s) pointed to by path people = sqlContext.jsonFile(path) # root From 59b7cfc41b2c06fbfbf6aca16c1619496a8d1d00 Mon Sep 17 00:00:00 2001 From: Deborah Siegel Date: Fri, 24 Apr 2015 20:25:07 -0700 Subject: [PATCH 065/110] [SPARK-7136][Docs] Spark SQL and DataFrame Guide fix example file and paths Changes example file for Generic Load/Save Functions to users.parquet rather than people.parquet which doesn't exist unless a later example has already been executed. Also adds filepaths. Author: Deborah Siegel Author: DEBORAH SIEGEL Author: DEBORAH SIEGEL Author: DEBORAH SIEGEL Closes #5693 from d3borah/master and squashes the following commits: 4d5e43b [Deborah Siegel] sparkSQL doc change b15a497 [Deborah Siegel] Revert "sparkSQL doc change" 5a2863c [DEBORAH SIEGEL] Merge remote-tracking branch 'upstream/master' 91972fc [DEBORAH SIEGEL] sparkSQL doc change f000e59 [DEBORAH SIEGEL] Merge remote-tracking branch 'upstream/master' db54173 [DEBORAH SIEGEL] fixed aggregateMessages example in graphX doc --- docs/sql-programming-guide.md | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 49b1e69f0e9db..b8233ae06fdf3 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -681,8 +681,8 @@ In the simplest form, the default data source (`parquet` unless otherwise config
{% highlight scala %} -val df = sqlContext.load("people.parquet") -df.select("name", "age").save("namesAndAges.parquet") +val df = sqlContext.load("examples/src/main/resources/users.parquet") +df.select("name", "favorite_color").save("namesAndFavColors.parquet") {% endhighlight %}
@@ -691,8 +691,8 @@ df.select("name", "age").save("namesAndAges.parquet") {% highlight java %} -DataFrame df = sqlContext.load("people.parquet"); -df.select("name", "age").save("namesAndAges.parquet"); +DataFrame df = sqlContext.load("examples/src/main/resources/users.parquet"); +df.select("name", "favorite_color").save("namesAndFavColors.parquet"); {% endhighlight %} @@ -702,8 +702,8 @@ df.select("name", "age").save("namesAndAges.parquet"); {% highlight python %} -df = sqlContext.load("people.parquet") -df.select("name", "age").save("namesAndAges.parquet") +df = sqlContext.load("examples/src/main/resources/users.parquet") +df.select("name", "favorite_color").save("namesAndFavColors.parquet") {% endhighlight %} @@ -722,7 +722,7 @@ using this syntax.
{% highlight scala %} -val df = sqlContext.load("people.json", "json") +val df = sqlContext.load("examples/src/main/resources/people.json", "json") df.select("name", "age").save("namesAndAges.parquet", "parquet") {% endhighlight %} @@ -732,7 +732,7 @@ df.select("name", "age").save("namesAndAges.parquet", "parquet") {% highlight java %} -DataFrame df = sqlContext.load("people.json", "json"); +DataFrame df = sqlContext.load("examples/src/main/resources/people.json", "json"); df.select("name", "age").save("namesAndAges.parquet", "parquet"); {% endhighlight %} @@ -743,7 +743,7 @@ df.select("name", "age").save("namesAndAges.parquet", "parquet"); {% highlight python %} -df = sqlContext.load("people.json", "json") +df = sqlContext.load("examples/src/main/resources/people.json", "json") df.select("name", "age").save("namesAndAges.parquet", "parquet") {% endhighlight %} From cca9905b93483614b330b09b36c6526b551e17dc Mon Sep 17 00:00:00 2001 From: KeheCAI Date: Sat, 25 Apr 2015 08:42:38 -0400 Subject: [PATCH 066/110] update the deprecated CountMinSketchMonoid function to TopPctCMS function http://twitter.github.io/algebird/index.html#com.twitter.algebird.legacy.CountMinSketchMonoid$ The CountMinSketchMonoid has been deprecated since 0.8.1. Newer code should use TopPctCMS.monoid(). ![image](https://cloud.githubusercontent.com/assets/1327396/7269619/d8b48b92-e8d5-11e4-8902-087f630e6308.png) Author: KeheCAI Closes #5629 from caikehe/master and squashes the following commits: e8aa06f [KeheCAI] update algebird-core to version 0.9.0 from 0.8.1 5653351 [KeheCAI] change scala code style 4c0dfd1 [KeheCAI] update the deprecated CountMinSketchMonoid function to TopPctCMS function --- examples/pom.xml | 2 +- .../apache/spark/examples/streaming/TwitterAlgebirdCMS.scala | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/examples/pom.xml b/examples/pom.xml index df1717403b673..5b04b4f8d6ca0 100644 --- a/examples/pom.xml +++ b/examples/pom.xml @@ -245,7 +245,7 @@ com.twitter algebird-core_${scala.binary.version} - 0.8.1 + 0.9.0 org.scalacheck diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/TwitterAlgebirdCMS.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/TwitterAlgebirdCMS.scala index 62f49530edb12..c10de84a80ffe 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/TwitterAlgebirdCMS.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/TwitterAlgebirdCMS.scala @@ -18,6 +18,7 @@ package org.apache.spark.examples.streaming import com.twitter.algebird._ +import com.twitter.algebird.CMSHasherImplicits._ import org.apache.spark.SparkConf import org.apache.spark.SparkContext._ @@ -67,7 +68,8 @@ object TwitterAlgebirdCMS { val users = stream.map(status => status.getUser.getId) - val cms = new CountMinSketchMonoid(EPS, DELTA, SEED, PERC) + // val cms = new CountMinSketchMonoid(EPS, DELTA, SEED, PERC) + val cms = TopPctCMS.monoid[Long](EPS, DELTA, SEED, PERC) var globalCMS = cms.zero val mm = new MapMonoid[Long, Int]() var globalExact = Map[Long, Int]() From a61d65fc8b97c01be0fa756b52afdc91c46a8561 Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Sat, 25 Apr 2015 10:37:34 -0700 Subject: [PATCH 067/110] Revert "[SPARK-6752][Streaming] Allow StreamingContext to be recreated from checkpoint and existing SparkContext" This reverts commit 534f2a43625fbf1a3a65d09550a19875cd1dce43. 
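After this revert, checkpoint recovery still goes through the factory-style getOrCreate that takes a zero-argument creating function. A minimal sketch of that remaining pattern (the checkpoint directory, host, and port are placeholders):

{% highlight scala %}
import org.apache.spark.SparkConf
import org.apache.spark.streaming.{Seconds, StreamingContext}

object CheckpointRecoverySketch {
  val checkpointDir = "/tmp/streaming-checkpoint"  // placeholder path

  // All DStreams and output operations must be set up inside the creating
  // function so they can be rebuilt from checkpoint data.
  def createContext(): StreamingContext = {
    val conf = new SparkConf().setMaster("local[2]").setAppName("CheckpointRecoverySketch")
    val ssc = new StreamingContext(conf, Seconds(1))
    ssc.checkpoint(checkpointDir)
    ssc.socketTextStream("localhost", 9999).print()  // placeholder source
    ssc
  }

  def main(args: Array[String]): Unit = {
    // Recover from existing checkpoint data if present, otherwise build a new context.
    val ssc = StreamingContext.getOrCreate(checkpointDir, createContext _)
    ssc.start()
    ssc.awaitTermination()
  }
}
{% endhighlight %}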
--- .../spark/api/java/function/Function0.java | 27 --- .../apache/spark/streaming/Checkpoint.scala | 26 +-- .../spark/streaming/StreamingContext.scala | 85 ++-------- .../api/java/JavaStreamingContext.scala | 119 +------------ .../apache/spark/streaming/JavaAPISuite.java | 145 ++++------------ .../spark/streaming/CheckpointSuite.scala | 3 +- .../streaming/StreamingContextSuite.scala | 159 ------------------ 7 files changed, 61 insertions(+), 503 deletions(-) delete mode 100644 core/src/main/java/org/apache/spark/api/java/function/Function0.java diff --git a/core/src/main/java/org/apache/spark/api/java/function/Function0.java b/core/src/main/java/org/apache/spark/api/java/function/Function0.java deleted file mode 100644 index 38e410c5debe6..0000000000000 --- a/core/src/main/java/org/apache/spark/api/java/function/Function0.java +++ /dev/null @@ -1,27 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.api.java.function; - -import java.io.Serializable; - -/** - * A zero-argument function that returns an R. 
- */ -public interface Function0 extends Serializable { - public R call() throws Exception; -} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala index 7bfae253c3a0c..0a50485118588 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala @@ -77,8 +77,7 @@ object Checkpoint extends Logging { } /** Get checkpoint files present in the give directory, ordered by oldest-first */ - def getCheckpointFiles(checkpointDir: String, fsOption: Option[FileSystem] = None): Seq[Path] = { - + def getCheckpointFiles(checkpointDir: String, fs: FileSystem): Seq[Path] = { def sortFunc(path1: Path, path2: Path): Boolean = { val (time1, bk1) = path1.getName match { case REGEX(x, y) => (x.toLong, !y.isEmpty) } val (time2, bk2) = path2.getName match { case REGEX(x, y) => (x.toLong, !y.isEmpty) } @@ -86,7 +85,6 @@ object Checkpoint extends Logging { } val path = new Path(checkpointDir) - val fs = fsOption.getOrElse(path.getFileSystem(new Configuration())) if (fs.exists(path)) { val statuses = fs.listStatus(path) if (statuses != null) { @@ -162,7 +160,7 @@ class CheckpointWriter( } // Delete old checkpoint files - val allCheckpointFiles = Checkpoint.getCheckpointFiles(checkpointDir, Some(fs)) + val allCheckpointFiles = Checkpoint.getCheckpointFiles(checkpointDir, fs) if (allCheckpointFiles.size > 10) { allCheckpointFiles.take(allCheckpointFiles.size - 10).foreach(file => { logInfo("Deleting " + file) @@ -236,24 +234,15 @@ class CheckpointWriter( private[streaming] object CheckpointReader extends Logging { - /** - * Read checkpoint files present in the given checkpoint directory. If there are no checkpoint - * files, then return None, else try to return the latest valid checkpoint object. If no - * checkpoint files could be read correctly, then return None (if ignoreReadError = true), - * or throw exception (if ignoreReadError = false). - */ - def read( - checkpointDir: String, - conf: SparkConf, - hadoopConf: Configuration, - ignoreReadError: Boolean = false): Option[Checkpoint] = { + def read(checkpointDir: String, conf: SparkConf, hadoopConf: Configuration): Option[Checkpoint] = + { val checkpointPath = new Path(checkpointDir) // TODO(rxin): Why is this a def?! 
def fs: FileSystem = checkpointPath.getFileSystem(hadoopConf) // Try to find the checkpoint files - val checkpointFiles = Checkpoint.getCheckpointFiles(checkpointDir, Some(fs)).reverse + val checkpointFiles = Checkpoint.getCheckpointFiles(checkpointDir, fs).reverse if (checkpointFiles.isEmpty) { return None } @@ -293,10 +282,7 @@ object CheckpointReader extends Logging { }) // If none of checkpoint files could be read, then throw exception - if (!ignoreReadError) { - throw new SparkException(s"Failed to read checkpoint from directory $checkpointPath") - } - None + throw new SparkException("Failed to read checkpoint from directory " + checkpointPath) } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala index 90c8b47aebce0..f57f295874645 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala @@ -107,19 +107,6 @@ class StreamingContext private[streaming] ( */ def this(path: String) = this(path, new Configuration) - /** - * Recreate a StreamingContext from a checkpoint file using an existing SparkContext. - * @param path Path to the directory that was specified as the checkpoint directory - * @param sparkContext Existing SparkContext - */ - def this(path: String, sparkContext: SparkContext) = { - this( - sparkContext, - CheckpointReader.read(path, sparkContext.conf, sparkContext.hadoopConfiguration).get, - null) - } - - if (sc_ == null && cp_ == null) { throw new Exception("Spark Streaming cannot be initialized with " + "both SparkContext and checkpoint as null") @@ -128,12 +115,10 @@ class StreamingContext private[streaming] ( private[streaming] val isCheckpointPresent = (cp_ != null) private[streaming] val sc: SparkContext = { - if (sc_ != null) { - sc_ - } else if (isCheckpointPresent) { + if (isCheckpointPresent) { new SparkContext(cp_.createSparkConf()) } else { - throw new SparkException("Cannot create StreamingContext without a SparkContext") + sc_ } } @@ -144,7 +129,7 @@ class StreamingContext private[streaming] ( private[streaming] val conf = sc.conf - private[streaming] val env = sc.env + private[streaming] val env = SparkEnv.get private[streaming] val graph: DStreamGraph = { if (isCheckpointPresent) { @@ -189,9 +174,7 @@ class StreamingContext private[streaming] ( /** Register streaming source to metrics system */ private val streamingSource = new StreamingSource(this) - assert(env != null) - assert(env.metricsSystem != null) - env.metricsSystem.registerSource(streamingSource) + SparkEnv.get.metricsSystem.registerSource(streamingSource) /** Enumeration to identify current state of the StreamingContext */ private[streaming] object StreamingContextState extends Enumeration { @@ -638,59 +621,19 @@ object StreamingContext extends Logging { hadoopConf: Configuration = new Configuration(), createOnError: Boolean = false ): StreamingContext = { - val checkpointOption = CheckpointReader.read( - checkpointPath, new SparkConf(), hadoopConf, createOnError) + val checkpointOption = try { + CheckpointReader.read(checkpointPath, new SparkConf(), hadoopConf) + } catch { + case e: Exception => + if (createOnError) { + None + } else { + throw e + } + } checkpointOption.map(new StreamingContext(null, _, null)).getOrElse(creatingFunc()) } - - /** - * Either recreate a StreamingContext from checkpoint data or create a new StreamingContext. 
- * If checkpoint data exists in the provided `checkpointPath`, then StreamingContext will be - * recreated from the checkpoint data. If the data does not exist, then the StreamingContext - * will be created by called the provided `creatingFunc` on the provided `sparkContext`. Note - * that the SparkConf configuration in the checkpoint data will not be restored as the - * SparkContext has already been created. - * - * @param checkpointPath Checkpoint directory used in an earlier StreamingContext program - * @param creatingFunc Function to create a new StreamingContext using the given SparkContext - * @param sparkContext SparkContext using which the StreamingContext will be created - */ - def getOrCreate( - checkpointPath: String, - creatingFunc: SparkContext => StreamingContext, - sparkContext: SparkContext - ): StreamingContext = { - getOrCreate(checkpointPath, creatingFunc, sparkContext, createOnError = false) - } - - /** - * Either recreate a StreamingContext from checkpoint data or create a new StreamingContext. - * If checkpoint data exists in the provided `checkpointPath`, then StreamingContext will be - * recreated from the checkpoint data. If the data does not exist, then the StreamingContext - * will be created by called the provided `creatingFunc` on the provided `sparkContext`. Note - * that the SparkConf configuration in the checkpoint data will not be restored as the - * SparkContext has already been created. - * - * @param checkpointPath Checkpoint directory used in an earlier StreamingContext program - * @param creatingFunc Function to create a new StreamingContext using the given SparkContext - * @param sparkContext SparkContext using which the StreamingContext will be created - * @param createOnError Whether to create a new StreamingContext if there is an - * error in reading checkpoint data. By default, an exception will be - * thrown on error. - */ - def getOrCreate( - checkpointPath: String, - creatingFunc: SparkContext => StreamingContext, - sparkContext: SparkContext, - createOnError: Boolean - ): StreamingContext = { - val checkpointOption = CheckpointReader.read( - checkpointPath, sparkContext.conf, sparkContext.hadoopConfiguration, createOnError) - checkpointOption.map(new StreamingContext(sparkContext, _, null)) - .getOrElse(creatingFunc(sparkContext)) - } - /** * Find the JAR from which a given class was loaded, to make it easy for users to pass * their JARs to StreamingContext. 
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala index 572d7d8e8753d..4095a7cc84946 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala @@ -32,14 +32,13 @@ import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.{JavaPairRDD, JavaRDD, JavaSparkContext} import org.apache.spark.api.java.function.{Function => JFunction, Function2 => JFunction2} -import org.apache.spark.api.java.function.{Function0 => JFunction0} import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming._ import org.apache.spark.streaming.scheduler.StreamingListener -import org.apache.spark.streaming.dstream.DStream -import org.apache.spark.streaming.receiver.Receiver import org.apache.hadoop.conf.Configuration +import org.apache.spark.streaming.dstream.{PluggableInputDStream, ReceiverInputDStream, DStream} +import org.apache.spark.streaming.receiver.Receiver /** * A Java-friendly version of [[org.apache.spark.streaming.StreamingContext]] which is the main @@ -656,7 +655,6 @@ object JavaStreamingContext { * @param checkpointPath Checkpoint directory used in an earlier JavaStreamingContext program * @param factory JavaStreamingContextFactory object to create a new JavaStreamingContext */ - @deprecated("use getOrCreate without JavaStreamingContextFactor", "1.4.0") def getOrCreate( checkpointPath: String, factory: JavaStreamingContextFactory @@ -678,7 +676,6 @@ object JavaStreamingContext { * @param hadoopConf Hadoop configuration if necessary for reading from any HDFS compatible * file system */ - @deprecated("use getOrCreate without JavaStreamingContextFactory", "1.4.0") def getOrCreate( checkpointPath: String, hadoopConf: Configuration, @@ -703,7 +700,6 @@ object JavaStreamingContext { * @param createOnError Whether to create a new JavaStreamingContext if there is an * error in reading checkpoint data. */ - @deprecated("use getOrCreate without JavaStreamingContextFactory", "1.4.0") def getOrCreate( checkpointPath: String, hadoopConf: Configuration, @@ -716,117 +712,6 @@ object JavaStreamingContext { new JavaStreamingContext(ssc) } - /** - * Either recreate a StreamingContext from checkpoint data or create a new StreamingContext. - * If checkpoint data exists in the provided `checkpointPath`, then StreamingContext will be - * recreated from the checkpoint data. If the data does not exist, then the provided factory - * will be used to create a JavaStreamingContext. - * - * @param checkpointPath Checkpoint directory used in an earlier JavaStreamingContext program - * @param creatingFunc Function to create a new JavaStreamingContext - */ - def getOrCreate( - checkpointPath: String, - creatingFunc: JFunction0[JavaStreamingContext] - ): JavaStreamingContext = { - val ssc = StreamingContext.getOrCreate(checkpointPath, () => { - creatingFunc.call().ssc - }) - new JavaStreamingContext(ssc) - } - - /** - * Either recreate a StreamingContext from checkpoint data or create a new StreamingContext. - * If checkpoint data exists in the provided `checkpointPath`, then StreamingContext will be - * recreated from the checkpoint data. If the data does not exist, then the provided factory - * will be used to create a JavaStreamingContext. 
- * - * @param checkpointPath Checkpoint directory used in an earlier StreamingContext program - * @param creatingFunc Function to create a new JavaStreamingContext - * @param hadoopConf Hadoop configuration if necessary for reading from any HDFS compatible - * file system - */ - def getOrCreate( - checkpointPath: String, - creatingFunc: JFunction0[JavaStreamingContext], - hadoopConf: Configuration - ): JavaStreamingContext = { - val ssc = StreamingContext.getOrCreate(checkpointPath, () => { - creatingFunc.call().ssc - }, hadoopConf) - new JavaStreamingContext(ssc) - } - - /** - * Either recreate a StreamingContext from checkpoint data or create a new StreamingContext. - * If checkpoint data exists in the provided `checkpointPath`, then StreamingContext will be - * recreated from the checkpoint data. If the data does not exist, then the provided factory - * will be used to create a JavaStreamingContext. - * - * @param checkpointPath Checkpoint directory used in an earlier StreamingContext program - * @param creatingFunc Function to create a new JavaStreamingContext - * @param hadoopConf Hadoop configuration if necessary for reading from any HDFS compatible - * file system - * @param createOnError Whether to create a new JavaStreamingContext if there is an - * error in reading checkpoint data. - */ - def getOrCreate( - checkpointPath: String, - creatingFunc: JFunction0[JavaStreamingContext], - hadoopConf: Configuration, - createOnError: Boolean - ): JavaStreamingContext = { - val ssc = StreamingContext.getOrCreate(checkpointPath, () => { - creatingFunc.call().ssc - }, hadoopConf, createOnError) - new JavaStreamingContext(ssc) - } - - /** - * Either recreate a StreamingContext from checkpoint data or create a new StreamingContext. - * If checkpoint data exists in the provided `checkpointPath`, then StreamingContext will be - * recreated from the checkpoint data. If the data does not exist, then the provided factory - * will be used to create a JavaStreamingContext. - * - * @param checkpointPath Checkpoint directory used in an earlier StreamingContext program - * @param creatingFunc Function to create a new JavaStreamingContext - * @param sparkContext SparkContext using which the StreamingContext will be created - */ - def getOrCreate( - checkpointPath: String, - creatingFunc: JFunction[JavaSparkContext, JavaStreamingContext], - sparkContext: JavaSparkContext - ): JavaStreamingContext = { - val ssc = StreamingContext.getOrCreate(checkpointPath, (sparkContext: SparkContext) => { - creatingFunc.call(new JavaSparkContext(sparkContext)).ssc - }, sparkContext.sc) - new JavaStreamingContext(ssc) - } - - /** - * Either recreate a StreamingContext from checkpoint data or create a new StreamingContext. - * If checkpoint data exists in the provided `checkpointPath`, then StreamingContext will be - * recreated from the checkpoint data. If the data does not exist, then the provided factory - * will be used to create a JavaStreamingContext. - * - * @param checkpointPath Checkpoint directory used in an earlier StreamingContext program - * @param creatingFunc Function to create a new JavaStreamingContext - * @param sparkContext SparkContext using which the StreamingContext will be created - * @param createOnError Whether to create a new JavaStreamingContext if there is an - * error in reading checkpoint data. 
- */ - def getOrCreate( - checkpointPath: String, - creatingFunc: JFunction[JavaSparkContext, JavaStreamingContext], - sparkContext: JavaSparkContext, - createOnError: Boolean - ): JavaStreamingContext = { - val ssc = StreamingContext.getOrCreate(checkpointPath, (sparkContext: SparkContext) => { - creatingFunc.call(new JavaSparkContext(sparkContext)).ssc - }, sparkContext.sc, createOnError) - new JavaStreamingContext(ssc) - } - /** * Find the JAR from which a given class was loaded, to make it easy for users to pass * their JARs to StreamingContext. diff --git a/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java b/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java index cb2e8380b4933..90340753a4eed 100644 --- a/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java +++ b/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java @@ -22,12 +22,10 @@ import java.nio.charset.Charset; import java.util.*; -import org.apache.commons.lang.mutable.MutableBoolean; import org.apache.hadoop.fs.Path; import org.apache.hadoop.io.LongWritable; import org.apache.hadoop.io.Text; import org.apache.hadoop.mapreduce.lib.input.TextInputFormat; - import scala.Tuple2; import org.junit.Assert; @@ -47,7 +45,6 @@ import org.apache.spark.storage.StorageLevel; import org.apache.spark.streaming.api.java.*; import org.apache.spark.util.Utils; -import org.apache.spark.SparkConf; // The test suite itself is Serializable so that anonymous Function implementations can be // serialized, as an alternative to converting these anonymous classes to static inner classes; @@ -932,7 +929,7 @@ public void testPairMap() { // Maps pair -> pair of different type public Tuple2 call(Tuple2 in) throws Exception { return in.swap(); } - }); + }); JavaTestUtils.attachTestOutputStream(reversed); List>> result = JavaTestUtils.runStreams(ssc, 2, 2); @@ -990,12 +987,12 @@ public void testPairMap2() { // Maps pair -> single JavaDStream> stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); JavaDStream reversed = pairStream.map( - new Function, Integer>() { - @Override - public Integer call(Tuple2 in) throws Exception { - return in._2(); - } - }); + new Function, Integer>() { + @Override + public Integer call(Tuple2 in) throws Exception { + return in._2(); + } + }); JavaTestUtils.attachTestOutputStream(reversed); List> result = JavaTestUtils.runStreams(ssc, 2, 2); @@ -1126,7 +1123,7 @@ public void testCombineByKey() { JavaPairDStream combined = pairStream.combineByKey( new Function() { - @Override + @Override public Integer call(Integer i) throws Exception { return i; } @@ -1147,14 +1144,14 @@ public void testCountByValue() { Arrays.asList("hello")); List>> expected = Arrays.asList( - Arrays.asList( - new Tuple2("hello", 1L), - new Tuple2("world", 1L)), - Arrays.asList( - new Tuple2("hello", 1L), - new Tuple2("moon", 1L)), - Arrays.asList( - new Tuple2("hello", 1L))); + Arrays.asList( + new Tuple2("hello", 1L), + new Tuple2("world", 1L)), + Arrays.asList( + new Tuple2("hello", 1L), + new Tuple2("moon", 1L)), + Arrays.asList( + new Tuple2("hello", 1L))); JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); JavaPairDStream counted = stream.countByValue(); @@ -1252,17 +1249,17 @@ public void testUpdateStateByKey() { JavaPairDStream updated = pairStream.updateStateByKey( new Function2, Optional, Optional>() { - @Override - public Optional call(List values, Optional state) { - int out = 
0; - if (state.isPresent()) { - out = out + state.get(); - } - for (Integer v : values) { - out = out + v; - } - return Optional.of(out); + @Override + public Optional call(List values, Optional state) { + int out = 0; + if (state.isPresent()) { + out = out + state.get(); + } + for (Integer v: values) { + out = out + v; } + return Optional.of(out); + } }); JavaTestUtils.attachTestOutputStream(updated); List>> result = JavaTestUtils.runStreams(ssc, 3, 3); @@ -1295,17 +1292,17 @@ public void testUpdateStateByKeyWithInitial() { JavaPairDStream updated = pairStream.updateStateByKey( new Function2, Optional, Optional>() { - @Override - public Optional call(List values, Optional state) { - int out = 0; - if (state.isPresent()) { - out = out + state.get(); - } - for (Integer v : values) { - out = out + v; - } - return Optional.of(out); + @Override + public Optional call(List values, Optional state) { + int out = 0; + if (state.isPresent()) { + out = out + state.get(); + } + for (Integer v: values) { + out = out + v; } + return Optional.of(out); + } }, new HashPartitioner(1), initialRDD); JavaTestUtils.attachTestOutputStream(updated); List>> result = JavaTestUtils.runStreams(ssc, 3, 3); @@ -1331,7 +1328,7 @@ public void testReduceByKeyAndWindowWithInverse() { JavaPairDStream reduceWindowed = pairStream.reduceByKeyAndWindow(new IntegerSum(), new IntegerDifference(), - new Duration(2000), new Duration(1000)); + new Duration(2000), new Duration(1000)); JavaTestUtils.attachTestOutputStream(reduceWindowed); List>> result = JavaTestUtils.runStreams(ssc, 3, 3); @@ -1710,74 +1707,6 @@ public Integer call(String s) throws Exception { Utils.deleteRecursively(tempDir); } - @SuppressWarnings("unchecked") - @Test - public void testContextGetOrCreate() throws InterruptedException { - - final SparkConf conf = new SparkConf() - .setMaster("local[2]") - .setAppName("test") - .set("newContext", "true"); - - File emptyDir = Files.createTempDir(); - emptyDir.deleteOnExit(); - StreamingContextSuite contextSuite = new StreamingContextSuite(); - String corruptedCheckpointDir = contextSuite.createCorruptedCheckpoint(); - String checkpointDir = contextSuite.createValidCheckpoint(); - - // Function to create JavaStreamingContext without any output operations - // (used to detect the new context) - final MutableBoolean newContextCreated = new MutableBoolean(false); - Function0 creatingFunc = new Function0() { - public JavaStreamingContext call() { - newContextCreated.setValue(true); - return new JavaStreamingContext(conf, Seconds.apply(1)); - } - }; - - newContextCreated.setValue(false); - ssc = JavaStreamingContext.getOrCreate(emptyDir.getAbsolutePath(), creatingFunc); - Assert.assertTrue("new context not created", newContextCreated.isTrue()); - ssc.stop(); - - newContextCreated.setValue(false); - ssc = JavaStreamingContext.getOrCreate(corruptedCheckpointDir, creatingFunc, - new org.apache.hadoop.conf.Configuration(), true); - Assert.assertTrue("new context not created", newContextCreated.isTrue()); - ssc.stop(); - - newContextCreated.setValue(false); - ssc = JavaStreamingContext.getOrCreate(checkpointDir, creatingFunc, - new org.apache.hadoop.conf.Configuration()); - Assert.assertTrue("old context not recovered", newContextCreated.isFalse()); - ssc.stop(); - - // Function to create JavaStreamingContext using existing JavaSparkContext - // without any output operations (used to detect the new context) - Function creatingFunc2 = - new Function() { - public JavaStreamingContext call(JavaSparkContext context) { - 
newContextCreated.setValue(true); - return new JavaStreamingContext(context, Seconds.apply(1)); - } - }; - - JavaSparkContext sc = new JavaSparkContext(conf); - newContextCreated.setValue(false); - ssc = JavaStreamingContext.getOrCreate(emptyDir.getAbsolutePath(), creatingFunc2, sc); - Assert.assertTrue("new context not created", newContextCreated.isTrue()); - ssc.stop(false); - - newContextCreated.setValue(false); - ssc = JavaStreamingContext.getOrCreate(corruptedCheckpointDir, creatingFunc2, sc, true); - Assert.assertTrue("new context not created", newContextCreated.isTrue()); - ssc.stop(false); - - newContextCreated.setValue(false); - ssc = JavaStreamingContext.getOrCreate(checkpointDir, creatingFunc2, sc); - Assert.assertTrue("old context not recovered", newContextCreated.isFalse()); - ssc.stop(); - } /* TEST DISABLED: Pending a discussion about checkpoint() semantics with TD @SuppressWarnings("unchecked") diff --git a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala index 6b0a3f91d4d06..54c30440a6e8d 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala @@ -430,8 +430,9 @@ class CheckpointSuite extends TestSuiteBase { assert(recordedFiles(ssc) === Seq(1, 2, 3) && batchCounter.getNumStartedBatches === 3) } // Wait for a checkpoint to be written + val fs = new Path(checkpointDir).getFileSystem(ssc.sc.hadoopConfiguration) eventually(eventuallyTimeout) { - assert(Checkpoint.getCheckpointFiles(checkpointDir).size === 6) + assert(Checkpoint.getCheckpointFiles(checkpointDir, fs).size === 6) } ssc.stop() // Check that we shut down while the third batch was being processed diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala index 4f193322ad33e..58353a5f97c8a 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala @@ -17,10 +17,8 @@ package org.apache.spark.streaming -import java.io.File import java.util.concurrent.atomic.AtomicInteger -import org.apache.commons.io.FileUtils import org.scalatest.{Assertions, BeforeAndAfter, FunSuite} import org.scalatest.concurrent.Timeouts import org.scalatest.concurrent.Eventually._ @@ -332,139 +330,6 @@ class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts w } } - test("getOrCreate") { - val conf = new SparkConf().setMaster(master).setAppName(appName) - - // Function to create StreamingContext that has a config to identify it to be new context - var newContextCreated = false - def creatingFunction(): StreamingContext = { - newContextCreated = true - new StreamingContext(conf, batchDuration) - } - - // Call ssc.stop after a body of code - def testGetOrCreate(body: => Unit): Unit = { - newContextCreated = false - try { - body - } finally { - if (ssc != null) { - ssc.stop() - } - ssc = null - } - } - - val emptyPath = Utils.createTempDir().getAbsolutePath() - - // getOrCreate should create new context with empty path - testGetOrCreate { - ssc = StreamingContext.getOrCreate(emptyPath, creatingFunction _) - assert(ssc != null, "no context created") - assert(newContextCreated, "new context not created") - } - - val corrutedCheckpointPath = createCorruptedCheckpoint() - - // getOrCreate 
should throw exception with fake checkpoint file and createOnError = false - intercept[Exception] { - ssc = StreamingContext.getOrCreate(corrutedCheckpointPath, creatingFunction _) - } - - // getOrCreate should throw exception with fake checkpoint file - intercept[Exception] { - ssc = StreamingContext.getOrCreate( - corrutedCheckpointPath, creatingFunction _, createOnError = false) - } - - // getOrCreate should create new context with fake checkpoint file and createOnError = true - testGetOrCreate { - ssc = StreamingContext.getOrCreate( - corrutedCheckpointPath, creatingFunction _, createOnError = true) - assert(ssc != null, "no context created") - assert(newContextCreated, "new context not created") - } - - val checkpointPath = createValidCheckpoint() - - // getOrCreate should recover context with checkpoint path, and recover old configuration - testGetOrCreate { - ssc = StreamingContext.getOrCreate(checkpointPath, creatingFunction _) - assert(ssc != null, "no context created") - assert(!newContextCreated, "old context not recovered") - assert(ssc.conf.get("someKey") === "someValue") - } - } - - test("getOrCreate with existing SparkContext") { - val conf = new SparkConf().setMaster(master).setAppName(appName) - sc = new SparkContext(conf) - - // Function to create StreamingContext that has a config to identify it to be new context - var newContextCreated = false - def creatingFunction(sparkContext: SparkContext): StreamingContext = { - newContextCreated = true - new StreamingContext(sparkContext, batchDuration) - } - - // Call ssc.stop(stopSparkContext = false) after a body of cody - def testGetOrCreate(body: => Unit): Unit = { - newContextCreated = false - try { - body - } finally { - if (ssc != null) { - ssc.stop(stopSparkContext = false) - } - ssc = null - } - } - - val emptyPath = Utils.createTempDir().getAbsolutePath() - - // getOrCreate should create new context with empty path - testGetOrCreate { - ssc = StreamingContext.getOrCreate(emptyPath, creatingFunction _, sc, createOnError = true) - assert(ssc != null, "no context created") - assert(newContextCreated, "new context not created") - assert(ssc.sparkContext === sc, "new StreamingContext does not use existing SparkContext") - } - - val corrutedCheckpointPath = createCorruptedCheckpoint() - - // getOrCreate should throw exception with fake checkpoint file and createOnError = false - intercept[Exception] { - ssc = StreamingContext.getOrCreate(corrutedCheckpointPath, creatingFunction _, sc) - } - - // getOrCreate should throw exception with fake checkpoint file - intercept[Exception] { - ssc = StreamingContext.getOrCreate( - corrutedCheckpointPath, creatingFunction _, sc, createOnError = false) - } - - // getOrCreate should create new context with fake checkpoint file and createOnError = true - testGetOrCreate { - ssc = StreamingContext.getOrCreate( - corrutedCheckpointPath, creatingFunction _, sc, createOnError = true) - assert(ssc != null, "no context created") - assert(newContextCreated, "new context not created") - assert(ssc.sparkContext === sc, "new StreamingContext does not use existing SparkContext") - } - - val checkpointPath = createValidCheckpoint() - - // StreamingContext.getOrCreate should recover context with checkpoint path - testGetOrCreate { - ssc = StreamingContext.getOrCreate(checkpointPath, creatingFunction _, sc) - assert(ssc != null, "no context created") - assert(!newContextCreated, "old context not recovered") - assert(ssc.sparkContext === sc, "new StreamingContext does not use existing SparkContext") - 
assert(!ssc.conf.contains("someKey"), - "recovered StreamingContext unexpectedly has old config") - } - } - test("DStream and generated RDD creation sites") { testPackage.test() } @@ -474,30 +339,6 @@ class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts w val inputStream = new TestInputStream(s, input, 1) inputStream } - - def createValidCheckpoint(): String = { - val testDirectory = Utils.createTempDir().getAbsolutePath() - val checkpointDirectory = Utils.createTempDir().getAbsolutePath() - val conf = new SparkConf().setMaster(master).setAppName(appName) - conf.set("someKey", "someValue") - ssc = new StreamingContext(conf, batchDuration) - ssc.checkpoint(checkpointDirectory) - ssc.textFileStream(testDirectory).foreachRDD { rdd => rdd.count() } - ssc.start() - eventually(timeout(10000 millis)) { - assert(Checkpoint.getCheckpointFiles(checkpointDirectory).size > 1) - } - ssc.stop() - checkpointDirectory - } - - def createCorruptedCheckpoint(): String = { - val checkpointDirectory = Utils.createTempDir().getAbsolutePath() - val fakeCheckpointFile = Checkpoint.checkpointFile(checkpointDirectory, Time(1000)) - FileUtils.write(new File(fakeCheckpointFile.toString()), "blablabla") - assert(Checkpoint.getCheckpointFiles(checkpointDirectory).nonEmpty) - checkpointDirectory - } } class TestException(msg: String) extends Exception(msg) From a7160c4e3aae22600d05e257d0b4d2428754b8ea Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Sat, 25 Apr 2015 12:27:19 -0700 Subject: [PATCH 068/110] [SPARK-6113] [ML] Tree ensembles for Pipelines API This is a continuation of [https://github.com/apache/spark/pull/5530] (which was for Decision Trees), but for ensembles: Random Forests and Gradient-Boosted Trees. Please refer to the JIRA [https://issues.apache.org/jira/browse/SPARK-6113], the design doc linked from the JIRA, and the previous PR linked above for design discussions. This PR follows the example set by the previous PR for Decision Trees. It includes a few cleanups to Decision Trees. Note: There is one issue which will be addressed in a separate PR: Ensembles' component Models have no parent or fittingParamMap. I plan to submit a separate PR which makes those values in Model be Options. It does not matter much which PR gets merged first. CC: mengxr manishamde codedeft chouqin Author: Joseph K. Bradley Closes #5626 from jkbradley/dt-api-ensembles and squashes the following commits: 729167a [Joseph K. Bradley] small cleanups based on code review bbae2a2 [Joseph K. Bradley] Updated per all comments in code review 855aa9a [Joseph K. Bradley] scala style fix ea3d901 [Joseph K. Bradley] Added GBT to spark.ml, with tests and examples c0f30c1 [Joseph K. Bradley] Added random forests and test suites to spark.ml. Not tested yet. Need to add example as well d045ebd [Joseph K. Bradley] some more updates, but far from done ee1a10b [Joseph K. Bradley] Added files from old PR and did some initial updates. 
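For reference, a condensed sketch of how the new tree-ensemble estimators plug into the Pipelines API, loosely based on the GBTExample added below; the dataset path and hyperparameter values are placeholders, and regression on the default "label" column is used to keep the sketch short.

{% highlight scala %}
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.ml.{Pipeline, PipelineStage}
import org.apache.spark.ml.feature.VectorIndexer
import org.apache.spark.ml.regression.GBTRegressor
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.sql.SQLContext

object GBTPipelineSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("GBTPipelineSketch"))
    val sqlContext = new SQLContext(sc)
    import sqlContext.implicits._

    // Load a libsvm-format dataset and split it before converting to DataFrames
    // with the default "label" and "features" columns.
    val splits = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt")
      .randomSplit(Array(0.8, 0.2))
    val (training, test) = (splits(0).toDF(), splits(1).toDF())

    // Index categorical features, then fit a gradient-boosted tree regressor.
    val featureIndexer = new VectorIndexer()
      .setInputCol("features")
      .setOutputCol("indexedFeatures")
      .setMaxCategories(10)
    val gbt = new GBTRegressor()
      .setFeaturesCol("indexedFeatures")
      .setLabelCol("label")
      .setMaxDepth(5)
      .setMaxIter(10)
    val pipeline = new Pipeline().setStages(Array[PipelineStage](featureIndexer, gbt))

    // Fit the whole pipeline, then score the held-out split.
    val model = pipeline.fit(training)
    model.transform(test).select("prediction", "label").show()

    sc.stop()
  }
}
{% endhighlight %}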
--- .../examples/ml/DecisionTreeExample.scala | 139 ++++++---- .../apache/spark/examples/ml/GBTExample.scala | 238 +++++++++++++++++ .../examples/ml/RandomForestExample.scala | 248 +++++++++++++++++ .../mllib/GradientBoostedTreesRunner.scala | 1 + .../scala/org/apache/spark/ml/Model.scala | 2 + .../DecisionTreeClassifier.scala | 24 +- .../ml/classification/GBTClassifier.scala | 228 ++++++++++++++++ .../RandomForestClassifier.scala | 185 +++++++++++++ .../spark/ml/impl/tree/treeParams.scala | 249 +++++++++++++++--- .../ml/param/shared/SharedParamsCodeGen.scala | 4 +- .../spark/ml/param/shared/sharedParams.scala | 20 ++ .../ml/regression/DecisionTreeRegressor.scala | 14 +- .../spark/ml/regression/GBTRegressor.scala | 218 +++++++++++++++ .../ml/regression/RandomForestRegressor.scala | 167 ++++++++++++ .../scala/org/apache/spark/ml/tree/Node.scala | 6 +- .../org/apache/spark/ml/tree/Split.scala | 22 +- .../org/apache/spark/ml/tree/treeModels.scala | 46 +++- .../JavaDecisionTreeClassifierSuite.java | 10 +- .../JavaGBTClassifierSuite.java | 100 +++++++ .../JavaRandomForestClassifierSuite.java | 103 ++++++++ .../JavaDecisionTreeRegressorSuite.java | 26 +- .../ml/regression/JavaGBTRegressorSuite.java | 99 +++++++ .../JavaRandomForestRegressorSuite.java | 102 +++++++ .../DecisionTreeClassifierSuite.scala | 2 +- .../classification/GBTClassifierSuite.scala | 136 ++++++++++ .../RandomForestClassifierSuite.scala | 166 ++++++++++++ .../org/apache/spark/ml/impl/TreeTests.scala | 10 +- .../DecisionTreeRegressorSuite.scala | 2 +- .../ml/regression/GBTRegressorSuite.scala | 137 ++++++++++ .../RandomForestRegressorSuite.scala | 122 +++++++++ .../spark/mllib/tree/DecisionTreeSuite.scala | 6 +- 31 files changed, 2658 insertions(+), 174 deletions(-) create mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/GBTExample.scala create mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/RandomForestExample.scala create mode 100644 mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala create mode 100644 mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala create mode 100644 mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala create mode 100644 mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala create mode 100644 mllib/src/test/java/org/apache/spark/ml/classification/JavaGBTClassifierSuite.java create mode 100644 mllib/src/test/java/org/apache/spark/ml/classification/JavaRandomForestClassifierSuite.java create mode 100644 mllib/src/test/java/org/apache/spark/ml/regression/JavaGBTRegressorSuite.java create mode 100644 mllib/src/test/java/org/apache/spark/ml/regression/JavaRandomForestRegressorSuite.java create mode 100644 mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala create mode 100644 mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala create mode 100644 mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala create mode 100644 mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala index 2cd515c89d3d2..9002e99d82ad3 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala @@ -22,10 
+22,9 @@ import scala.language.reflectiveCalls import scopt.OptionParser -import org.apache.spark.ml.tree.DecisionTreeModel import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.examples.mllib.AbstractParams -import org.apache.spark.ml.{Pipeline, PipelineStage} +import org.apache.spark.ml.{Pipeline, PipelineStage, Transformer} import org.apache.spark.ml.classification.{DecisionTreeClassificationModel, DecisionTreeClassifier} import org.apache.spark.ml.feature.{VectorIndexer, StringIndexer} import org.apache.spark.ml.regression.{DecisionTreeRegressionModel, DecisionTreeRegressor} @@ -64,8 +63,6 @@ object DecisionTreeExample { maxBins: Int = 32, minInstancesPerNode: Int = 1, minInfoGain: Double = 0.0, - numTrees: Int = 1, - featureSubsetStrategy: String = "auto", fracTest: Double = 0.2, cacheNodeIds: Boolean = false, checkpointDir: Option[String] = None, @@ -123,8 +120,8 @@ object DecisionTreeExample { .required() .action((x, c) => c.copy(input = x)) checkConfig { params => - if (params.fracTest < 0 || params.fracTest > 1) { - failure(s"fracTest ${params.fracTest} value incorrect; should be in [0,1].") + if (params.fracTest < 0 || params.fracTest >= 1) { + failure(s"fracTest ${params.fracTest} value incorrect; should be in [0,1).") } else { success } @@ -200,9 +197,18 @@ object DecisionTreeExample { throw new IllegalArgumentException("Algo ${params.algo} not supported.") } } - val dataframes = splits.map(_.toDF()).map(labelsToStrings).map(_.cache()) + val dataframes = splits.map(_.toDF()).map(labelsToStrings) + val training = dataframes(0).cache() + val test = dataframes(1).cache() - (dataframes(0), dataframes(1)) + val numTraining = training.count() + val numTest = test.count() + val numFeatures = training.select("features").first().getAs[Vector](0).size + println("Loaded data:") + println(s" numTraining = $numTraining, numTest = $numTest") + println(s" numFeatures = $numFeatures") + + (training, test) } def run(params: Params) { @@ -217,13 +223,6 @@ object DecisionTreeExample { val (training: DataFrame, test: DataFrame) = loadDatasets(sc, params.input, params.dataFormat, params.testInput, algo, params.fracTest) - val numTraining = training.count() - val numTest = test.count() - val numFeatures = training.select("features").first().getAs[Vector](0).size - println("Loaded data:") - println(s" numTraining = $numTraining, numTest = $numTest") - println(s" numFeatures = $numFeatures") - // Set up Pipeline val stages = new mutable.ArrayBuffer[PipelineStage]() // (1) For classification, re-index classes. @@ -241,7 +240,7 @@ object DecisionTreeExample { .setOutputCol("indexedFeatures") .setMaxCategories(10) stages += featuresIndexer - // (3) Learn DecisionTree + // (3) Learn Decision Tree val dt = algo match { case "classification" => new DecisionTreeClassifier() @@ -275,62 +274,86 @@ object DecisionTreeExample { println(s"Training time: $elapsedTime seconds") // Get the trained Decision Tree from the fitted PipelineModel - val treeModel: DecisionTreeModel = algo match { + algo match { case "classification" => - pipelineModel.getModel[DecisionTreeClassificationModel]( + val treeModel = pipelineModel.getModel[DecisionTreeClassificationModel]( dt.asInstanceOf[DecisionTreeClassifier]) + if (treeModel.numNodes < 20) { + println(treeModel.toDebugString) // Print full model. + } else { + println(treeModel) // Print model summary. 
+ } case "regression" => - pipelineModel.getModel[DecisionTreeRegressionModel](dt.asInstanceOf[DecisionTreeRegressor]) - case _ => throw new IllegalArgumentException("Algo ${params.algo} not supported.") - } - if (treeModel.numNodes < 20) { - println(treeModel.toDebugString) // Print full model. - } else { - println(treeModel) // Print model summary. - } - - // Predict on training - val trainingFullPredictions = pipelineModel.transform(training).cache() - val trainingPredictions = trainingFullPredictions.select("prediction") - .map(_.getDouble(0)) - val trainingLabels = trainingFullPredictions.select(labelColName).map(_.getDouble(0)) - // Predict on test data - val testFullPredictions = pipelineModel.transform(test).cache() - val testPredictions = testFullPredictions.select("prediction") - .map(_.getDouble(0)) - val testLabels = testFullPredictions.select(labelColName).map(_.getDouble(0)) - - // For classification, print number of classes for reference. - if (algo == "classification") { - val numClasses = - MetadataUtils.getNumClasses(trainingFullPredictions.schema(labelColName)) match { - case Some(n) => n - case None => throw new RuntimeException( - "DecisionTreeExample had unknown failure when indexing labels for classification.") + val treeModel = pipelineModel.getModel[DecisionTreeRegressionModel]( + dt.asInstanceOf[DecisionTreeRegressor]) + if (treeModel.numNodes < 20) { + println(treeModel.toDebugString) // Print full model. + } else { + println(treeModel) // Print model summary. } - println(s"numClasses = $numClasses.") + case _ => throw new IllegalArgumentException("Algo ${params.algo} not supported.") } // Evaluate model on training, test data algo match { case "classification" => - val trainingAccuracy = - new MulticlassMetrics(trainingPredictions.zip(trainingLabels)).precision - println(s"Train accuracy = $trainingAccuracy") - val testAccuracy = - new MulticlassMetrics(testPredictions.zip(testLabels)).precision - println(s"Test accuracy = $testAccuracy") + println("Training data results:") + evaluateClassificationModel(pipelineModel, training, labelColName) + println("Test data results:") + evaluateClassificationModel(pipelineModel, test, labelColName) case "regression" => - val trainingRMSE = - new RegressionMetrics(trainingPredictions.zip(trainingLabels)).rootMeanSquaredError - println(s"Training root mean squared error (RMSE) = $trainingRMSE") - val testRMSE = - new RegressionMetrics(testPredictions.zip(testLabels)).rootMeanSquaredError - println(s"Test root mean squared error (RMSE) = $testRMSE") + println("Training data results:") + evaluateRegressionModel(pipelineModel, training, labelColName) + println("Test data results:") + evaluateRegressionModel(pipelineModel, test, labelColName) case _ => throw new IllegalArgumentException("Algo ${params.algo} not supported.") } sc.stop() } + + /** + * Evaluate the given ClassificationModel on data. Print the results. + * @param model Must fit ClassificationModel abstraction + * @param data DataFrame with "prediction" and labelColName columns + * @param labelColName Name of the labelCol parameter for the model + * + * TODO: Change model type to ClassificationModel once that API is public. 
SPARK-5995 + */ + private[ml] def evaluateClassificationModel( + model: Transformer, + data: DataFrame, + labelColName: String): Unit = { + val fullPredictions = model.transform(data).cache() + val predictions = fullPredictions.select("prediction").map(_.getDouble(0)) + val labels = fullPredictions.select(labelColName).map(_.getDouble(0)) + // Print number of classes for reference + val numClasses = MetadataUtils.getNumClasses(fullPredictions.schema(labelColName)) match { + case Some(n) => n + case None => throw new RuntimeException( + "Unknown failure when indexing labels for classification.") + } + val accuracy = new MulticlassMetrics(predictions.zip(labels)).precision + println(s" Accuracy ($numClasses classes): $accuracy") + } + + /** + * Evaluate the given RegressionModel on data. Print the results. + * @param model Must fit RegressionModel abstraction + * @param data DataFrame with "prediction" and labelColName columns + * @param labelColName Name of the labelCol parameter for the model + * + * TODO: Change model type to RegressionModel once that API is public. SPARK-5995 + */ + private[ml] def evaluateRegressionModel( + model: Transformer, + data: DataFrame, + labelColName: String): Unit = { + val fullPredictions = model.transform(data).cache() + val predictions = fullPredictions.select("prediction").map(_.getDouble(0)) + val labels = fullPredictions.select(labelColName).map(_.getDouble(0)) + val RMSE = new RegressionMetrics(predictions.zip(labels)).rootMeanSquaredError + println(s" Root mean squared error (RMSE): $RMSE") + } } diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/GBTExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/GBTExample.scala new file mode 100644 index 0000000000000..5fccb142d4c3d --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/GBTExample.scala @@ -0,0 +1,238 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml + +import scala.collection.mutable +import scala.language.reflectiveCalls + +import scopt.OptionParser + +import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.examples.mllib.AbstractParams +import org.apache.spark.ml.{Pipeline, PipelineStage} +import org.apache.spark.ml.classification.{GBTClassificationModel, GBTClassifier} +import org.apache.spark.ml.feature.{StringIndexer, VectorIndexer} +import org.apache.spark.ml.regression.{GBTRegressionModel, GBTRegressor} +import org.apache.spark.sql.DataFrame + + +/** + * An example runner for decision trees. Run with + * {{{ + * ./bin/run-example ml.GBTExample [options] + * }}} + * Decision Trees and ensembles can take a large amount of memory. 
If the run-example command + * above fails, try running via spark-submit and specifying the amount of memory as at least 1g. + * For local mode, run + * {{{ + * ./bin/spark-submit --class org.apache.spark.examples.ml.GBTExample --driver-memory 1g + * [examples JAR path] [options] + * }}} + * If you use it as a template to create your own app, please use `spark-submit` to submit your app. + */ +object GBTExample { + + case class Params( + input: String = null, + testInput: String = "", + dataFormat: String = "libsvm", + algo: String = "classification", + maxDepth: Int = 5, + maxBins: Int = 32, + minInstancesPerNode: Int = 1, + minInfoGain: Double = 0.0, + maxIter: Int = 10, + fracTest: Double = 0.2, + cacheNodeIds: Boolean = false, + checkpointDir: Option[String] = None, + checkpointInterval: Int = 10) extends AbstractParams[Params] + + def main(args: Array[String]) { + val defaultParams = Params() + + val parser = new OptionParser[Params]("GBTExample") { + head("GBTExample: an example Gradient-Boosted Trees app.") + opt[String]("algo") + .text(s"algorithm (classification, regression), default: ${defaultParams.algo}") + .action((x, c) => c.copy(algo = x)) + opt[Int]("maxDepth") + .text(s"max depth of the tree, default: ${defaultParams.maxDepth}") + .action((x, c) => c.copy(maxDepth = x)) + opt[Int]("maxBins") + .text(s"max number of bins, default: ${defaultParams.maxBins}") + .action((x, c) => c.copy(maxBins = x)) + opt[Int]("minInstancesPerNode") + .text(s"min number of instances required at child nodes to create the parent split," + + s" default: ${defaultParams.minInstancesPerNode}") + .action((x, c) => c.copy(minInstancesPerNode = x)) + opt[Double]("minInfoGain") + .text(s"min info gain required to create a split, default: ${defaultParams.minInfoGain}") + .action((x, c) => c.copy(minInfoGain = x)) + opt[Int]("maxIter") + .text(s"number of trees in ensemble, default: ${defaultParams.maxIter}") + .action((x, c) => c.copy(maxIter = x)) + opt[Double]("fracTest") + .text(s"fraction of data to hold out for testing. If given option testInput, " + + s"this option is ignored. default: ${defaultParams.fracTest}") + .action((x, c) => c.copy(fracTest = x)) + opt[Boolean]("cacheNodeIds") + .text(s"whether to use node Id cache during training, " + + s"default: ${defaultParams.cacheNodeIds}") + .action((x, c) => c.copy(cacheNodeIds = x)) + opt[String]("checkpointDir") + .text(s"checkpoint directory where intermediate node Id caches will be stored, " + + s"default: ${ + defaultParams.checkpointDir match { + case Some(strVal) => strVal + case None => "None" + } + }") + .action((x, c) => c.copy(checkpointDir = Some(x))) + opt[Int]("checkpointInterval") + .text(s"how often to checkpoint the node Id cache, " + + s"default: ${defaultParams.checkpointInterval}") + .action((x, c) => c.copy(checkpointInterval = x)) + opt[String]("testInput") + .text(s"input path to test dataset. If given, option fracTest is ignored." 
+ + s" default: ${defaultParams.testInput}") + .action((x, c) => c.copy(testInput = x)) + opt[String]("") + .text("data format: libsvm (default), dense (deprecated in Spark v1.1)") + .action((x, c) => c.copy(dataFormat = x)) + arg[String]("") + .text("input path to labeled examples") + .required() + .action((x, c) => c.copy(input = x)) + checkConfig { params => + if (params.fracTest < 0 || params.fracTest >= 1) { + failure(s"fracTest ${params.fracTest} value incorrect; should be in [0,1).") + } else { + success + } + } + } + + parser.parse(args, defaultParams).map { params => + run(params) + }.getOrElse { + sys.exit(1) + } + } + + def run(params: Params) { + val conf = new SparkConf().setAppName(s"GBTExample with $params") + val sc = new SparkContext(conf) + params.checkpointDir.foreach(sc.setCheckpointDir) + val algo = params.algo.toLowerCase + + println(s"GBTExample with parameters:\n$params") + + // Load training and test data and cache it. + val (training: DataFrame, test: DataFrame) = DecisionTreeExample.loadDatasets(sc, params.input, + params.dataFormat, params.testInput, algo, params.fracTest) + + // Set up Pipeline + val stages = new mutable.ArrayBuffer[PipelineStage]() + // (1) For classification, re-index classes. + val labelColName = if (algo == "classification") "indexedLabel" else "label" + if (algo == "classification") { + val labelIndexer = new StringIndexer() + .setInputCol("labelString") + .setOutputCol(labelColName) + stages += labelIndexer + } + // (2) Identify categorical features using VectorIndexer. + // Features with more than maxCategories values will be treated as continuous. + val featuresIndexer = new VectorIndexer() + .setInputCol("features") + .setOutputCol("indexedFeatures") + .setMaxCategories(10) + stages += featuresIndexer + // (3) Learn GBT + val dt = algo match { + case "classification" => + new GBTClassifier() + .setFeaturesCol("indexedFeatures") + .setLabelCol(labelColName) + .setMaxDepth(params.maxDepth) + .setMaxBins(params.maxBins) + .setMinInstancesPerNode(params.minInstancesPerNode) + .setMinInfoGain(params.minInfoGain) + .setCacheNodeIds(params.cacheNodeIds) + .setCheckpointInterval(params.checkpointInterval) + .setMaxIter(params.maxIter) + case "regression" => + new GBTRegressor() + .setFeaturesCol("indexedFeatures") + .setLabelCol(labelColName) + .setMaxDepth(params.maxDepth) + .setMaxBins(params.maxBins) + .setMinInstancesPerNode(params.minInstancesPerNode) + .setMinInfoGain(params.minInfoGain) + .setCacheNodeIds(params.cacheNodeIds) + .setCheckpointInterval(params.checkpointInterval) + .setMaxIter(params.maxIter) + case _ => throw new IllegalArgumentException("Algo ${params.algo} not supported.") + } + stages += dt + val pipeline = new Pipeline().setStages(stages.toArray) + + // Fit the Pipeline + val startTime = System.nanoTime() + val pipelineModel = pipeline.fit(training) + val elapsedTime = (System.nanoTime() - startTime) / 1e9 + println(s"Training time: $elapsedTime seconds") + + // Get the trained GBT from the fitted PipelineModel + algo match { + case "classification" => + val rfModel = pipelineModel.getModel[GBTClassificationModel](dt.asInstanceOf[GBTClassifier]) + if (rfModel.totalNumNodes < 30) { + println(rfModel.toDebugString) // Print full model. + } else { + println(rfModel) // Print model summary. + } + case "regression" => + val rfModel = pipelineModel.getModel[GBTRegressionModel](dt.asInstanceOf[GBTRegressor]) + if (rfModel.totalNumNodes < 30) { + println(rfModel.toDebugString) // Print full model. 
+ } else { + println(rfModel) // Print model summary. + } + case _ => throw new IllegalArgumentException("Algo ${params.algo} not supported.") + } + + // Evaluate model on training, test data + algo match { + case "classification" => + println("Training data results:") + DecisionTreeExample.evaluateClassificationModel(pipelineModel, training, labelColName) + println("Test data results:") + DecisionTreeExample.evaluateClassificationModel(pipelineModel, test, labelColName) + case "regression" => + println("Training data results:") + DecisionTreeExample.evaluateRegressionModel(pipelineModel, training, labelColName) + println("Test data results:") + DecisionTreeExample.evaluateRegressionModel(pipelineModel, test, labelColName) + case _ => + throw new IllegalArgumentException("Algo ${params.algo} not supported.") + } + + sc.stop() + } +} diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestExample.scala new file mode 100644 index 0000000000000..9b909324ec82a --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestExample.scala @@ -0,0 +1,248 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml + +import scala.collection.mutable +import scala.language.reflectiveCalls + +import scopt.OptionParser + +import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.examples.mllib.AbstractParams +import org.apache.spark.ml.{Pipeline, PipelineStage} +import org.apache.spark.ml.classification.{RandomForestClassificationModel, RandomForestClassifier} +import org.apache.spark.ml.feature.{StringIndexer, VectorIndexer} +import org.apache.spark.ml.regression.{RandomForestRegressionModel, RandomForestRegressor} +import org.apache.spark.sql.DataFrame + + +/** + * An example runner for decision trees. Run with + * {{{ + * ./bin/run-example ml.RandomForestExample [options] + * }}} + * Decision Trees and ensembles can take a large amount of memory. If the run-example command + * above fails, try running via spark-submit and specifying the amount of memory as at least 1g. + * For local mode, run + * {{{ + * ./bin/spark-submit --class org.apache.spark.examples.ml.RandomForestExample --driver-memory 1g + * [examples JAR path] [options] + * }}} + * If you use it as a template to create your own app, please use `spark-submit` to submit your app. 
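+ * For example, a classification run with a few options set explicitly (the option values
+ * here are illustrative only; see the OptionParser below for the full list and defaults):
+ * {{{
+ *  ./bin/run-example ml.RandomForestExample --algo classification --numTrees 50 \
+ *    --maxDepth 10 --fracTest 0.3 [input path]
+ * }}}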
+ */ +object RandomForestExample { + + case class Params( + input: String = null, + testInput: String = "", + dataFormat: String = "libsvm", + algo: String = "classification", + maxDepth: Int = 5, + maxBins: Int = 32, + minInstancesPerNode: Int = 1, + minInfoGain: Double = 0.0, + numTrees: Int = 10, + featureSubsetStrategy: String = "auto", + fracTest: Double = 0.2, + cacheNodeIds: Boolean = false, + checkpointDir: Option[String] = None, + checkpointInterval: Int = 10) extends AbstractParams[Params] + + def main(args: Array[String]) { + val defaultParams = Params() + + val parser = new OptionParser[Params]("RandomForestExample") { + head("RandomForestExample: an example random forest app.") + opt[String]("algo") + .text(s"algorithm (classification, regression), default: ${defaultParams.algo}") + .action((x, c) => c.copy(algo = x)) + opt[Int]("maxDepth") + .text(s"max depth of the tree, default: ${defaultParams.maxDepth}") + .action((x, c) => c.copy(maxDepth = x)) + opt[Int]("maxBins") + .text(s"max number of bins, default: ${defaultParams.maxBins}") + .action((x, c) => c.copy(maxBins = x)) + opt[Int]("minInstancesPerNode") + .text(s"min number of instances required at child nodes to create the parent split," + + s" default: ${defaultParams.minInstancesPerNode}") + .action((x, c) => c.copy(minInstancesPerNode = x)) + opt[Double]("minInfoGain") + .text(s"min info gain required to create a split, default: ${defaultParams.minInfoGain}") + .action((x, c) => c.copy(minInfoGain = x)) + opt[Int]("numTrees") + .text(s"number of trees in ensemble, default: ${defaultParams.numTrees}") + .action((x, c) => c.copy(numTrees = x)) + opt[String]("featureSubsetStrategy") + .text(s"number of features to use per node (supported:" + + s" ${RandomForestClassifier.supportedFeatureSubsetStrategies.mkString(",")})," + + s" default: ${defaultParams.numTrees}") + .action((x, c) => c.copy(featureSubsetStrategy = x)) + opt[Double]("fracTest") + .text(s"fraction of data to hold out for testing. If given option testInput, " + + s"this option is ignored. default: ${defaultParams.fracTest}") + .action((x, c) => c.copy(fracTest = x)) + opt[Boolean]("cacheNodeIds") + .text(s"whether to use node Id cache during training, " + + s"default: ${defaultParams.cacheNodeIds}") + .action((x, c) => c.copy(cacheNodeIds = x)) + opt[String]("checkpointDir") + .text(s"checkpoint directory where intermediate node Id caches will be stored, " + + s"default: ${ + defaultParams.checkpointDir match { + case Some(strVal) => strVal + case None => "None" + } + }") + .action((x, c) => c.copy(checkpointDir = Some(x))) + opt[Int]("checkpointInterval") + .text(s"how often to checkpoint the node Id cache, " + + s"default: ${defaultParams.checkpointInterval}") + .action((x, c) => c.copy(checkpointInterval = x)) + opt[String]("testInput") + .text(s"input path to test dataset. If given, option fracTest is ignored." 
+ + s" default: ${defaultParams.testInput}") + .action((x, c) => c.copy(testInput = x)) + opt[String]("") + .text("data format: libsvm (default), dense (deprecated in Spark v1.1)") + .action((x, c) => c.copy(dataFormat = x)) + arg[String]("") + .text("input path to labeled examples") + .required() + .action((x, c) => c.copy(input = x)) + checkConfig { params => + if (params.fracTest < 0 || params.fracTest >= 1) { + failure(s"fracTest ${params.fracTest} value incorrect; should be in [0,1).") + } else { + success + } + } + } + + parser.parse(args, defaultParams).map { params => + run(params) + }.getOrElse { + sys.exit(1) + } + } + + def run(params: Params) { + val conf = new SparkConf().setAppName(s"RandomForestExample with $params") + val sc = new SparkContext(conf) + params.checkpointDir.foreach(sc.setCheckpointDir) + val algo = params.algo.toLowerCase + + println(s"RandomForestExample with parameters:\n$params") + + // Load training and test data and cache it. + val (training: DataFrame, test: DataFrame) = DecisionTreeExample.loadDatasets(sc, params.input, + params.dataFormat, params.testInput, algo, params.fracTest) + + // Set up Pipeline + val stages = new mutable.ArrayBuffer[PipelineStage]() + // (1) For classification, re-index classes. + val labelColName = if (algo == "classification") "indexedLabel" else "label" + if (algo == "classification") { + val labelIndexer = new StringIndexer() + .setInputCol("labelString") + .setOutputCol(labelColName) + stages += labelIndexer + } + // (2) Identify categorical features using VectorIndexer. + // Features with more than maxCategories values will be treated as continuous. + val featuresIndexer = new VectorIndexer() + .setInputCol("features") + .setOutputCol("indexedFeatures") + .setMaxCategories(10) + stages += featuresIndexer + // (3) Learn Random Forest + val dt = algo match { + case "classification" => + new RandomForestClassifier() + .setFeaturesCol("indexedFeatures") + .setLabelCol(labelColName) + .setMaxDepth(params.maxDepth) + .setMaxBins(params.maxBins) + .setMinInstancesPerNode(params.minInstancesPerNode) + .setMinInfoGain(params.minInfoGain) + .setCacheNodeIds(params.cacheNodeIds) + .setCheckpointInterval(params.checkpointInterval) + .setFeatureSubsetStrategy(params.featureSubsetStrategy) + .setNumTrees(params.numTrees) + case "regression" => + new RandomForestRegressor() + .setFeaturesCol("indexedFeatures") + .setLabelCol(labelColName) + .setMaxDepth(params.maxDepth) + .setMaxBins(params.maxBins) + .setMinInstancesPerNode(params.minInstancesPerNode) + .setMinInfoGain(params.minInfoGain) + .setCacheNodeIds(params.cacheNodeIds) + .setCheckpointInterval(params.checkpointInterval) + .setFeatureSubsetStrategy(params.featureSubsetStrategy) + .setNumTrees(params.numTrees) + case _ => throw new IllegalArgumentException("Algo ${params.algo} not supported.") + } + stages += dt + val pipeline = new Pipeline().setStages(stages.toArray) + + // Fit the Pipeline + val startTime = System.nanoTime() + val pipelineModel = pipeline.fit(training) + val elapsedTime = (System.nanoTime() - startTime) / 1e9 + println(s"Training time: $elapsedTime seconds") + + // Get the trained Random Forest from the fitted PipelineModel + algo match { + case "classification" => + val rfModel = pipelineModel.getModel[RandomForestClassificationModel]( + dt.asInstanceOf[RandomForestClassifier]) + if (rfModel.totalNumNodes < 30) { + println(rfModel.toDebugString) // Print full model. + } else { + println(rfModel) // Print model summary. 
+ } + case "regression" => + val rfModel = pipelineModel.getModel[RandomForestRegressionModel]( + dt.asInstanceOf[RandomForestRegressor]) + if (rfModel.totalNumNodes < 30) { + println(rfModel.toDebugString) // Print full model. + } else { + println(rfModel) // Print model summary. + } + case _ => throw new IllegalArgumentException("Algo ${params.algo} not supported.") + } + + // Evaluate model on training, test data + algo match { + case "classification" => + println("Training data results:") + DecisionTreeExample.evaluateClassificationModel(pipelineModel, training, labelColName) + println("Test data results:") + DecisionTreeExample.evaluateClassificationModel(pipelineModel, test, labelColName) + case "regression" => + println("Training data results:") + DecisionTreeExample.evaluateRegressionModel(pipelineModel, training, labelColName) + println("Test data results:") + DecisionTreeExample.evaluateRegressionModel(pipelineModel, test, labelColName) + case _ => + throw new IllegalArgumentException("Algo ${params.algo} not supported.") + } + + sc.stop() + } +} diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostedTreesRunner.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostedTreesRunner.scala index 431ead8c0c165..0763a7736305a 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostedTreesRunner.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostedTreesRunner.scala @@ -25,6 +25,7 @@ import org.apache.spark.mllib.tree.GradientBoostedTrees import org.apache.spark.mllib.tree.configuration.{BoostingStrategy, Algo} import org.apache.spark.util.Utils + /** * An example runner for Gradient Boosting using decision trees as weak learners. Run with * {{{ diff --git a/mllib/src/main/scala/org/apache/spark/ml/Model.scala b/mllib/src/main/scala/org/apache/spark/ml/Model.scala index cae5082b51196..a491bc7ee8295 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Model.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Model.scala @@ -30,11 +30,13 @@ import org.apache.spark.ml.param.ParamMap abstract class Model[M <: Model[M]] extends Transformer { /** * The parent estimator that produced this model. + * Note: For ensembles' component Models, this value can be null. */ val parent: Estimator[M] /** * Fitting parameters, such that parent.fit(..., fittingParamMap) could reproduce the model. + * Note: For ensembles' component Models, this value can be null. */ val fittingParamMap: ParamMap } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala index 3855e396b5534..ee2a8dc6db171 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala @@ -43,8 +43,7 @@ import org.apache.spark.sql.DataFrame @AlphaComponent final class DecisionTreeClassifier extends Predictor[Vector, DecisionTreeClassifier, DecisionTreeClassificationModel] - with DecisionTreeParams - with TreeClassifierParams { + with DecisionTreeParams with TreeClassifierParams { // Override parameter setters from parent trait for Java API compatibility. 
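The two example runners above assemble the same pipeline shape before fitting: an optional StringIndexer for classification labels, a VectorIndexer for categorical features, and then the tree-based estimator. A minimal sketch of that pattern, assuming a DataFrame named `df` with "labelString" and "features" columns (the variable names, the `df` input, and the numTrees value are illustrative, not taken from the patch):
{{{
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.classification.{RandomForestClassificationModel, RandomForestClassifier}
import org.apache.spark.ml.feature.{StringIndexer, VectorIndexer}

val labelIndexer = new StringIndexer()
  .setInputCol("labelString")
  .setOutputCol("indexedLabel")
val featureIndexer = new VectorIndexer()
  .setInputCol("features")
  .setOutputCol("indexedFeatures")
  .setMaxCategories(10)
val rf = new RandomForestClassifier()
  .setLabelCol("indexedLabel")
  .setFeaturesCol("indexedFeatures")
  .setNumTrees(20)
val pipeline = new Pipeline().setStages(Array(labelIndexer, featureIndexer, rf))
val pipelineModel = pipeline.fit(df)  // df: DataFrame with "labelString" and "features" columns
val rfModel = pipelineModel.getModel[RandomForestClassificationModel](rf)
println(rfModel)  // model summary; use rfModel.toDebugString for the full set of trees
}}}
The fitted PipelineModel is then handed to the shared evaluateClassificationModel / evaluateRegressionModel helpers defined in DecisionTreeExample.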
@@ -59,11 +58,9 @@ final class DecisionTreeClassifier override def setMaxMemoryInMB(value: Int): this.type = super.setMaxMemoryInMB(value) - override def setCacheNodeIds(value: Boolean): this.type = - super.setCacheNodeIds(value) + override def setCacheNodeIds(value: Boolean): this.type = super.setCacheNodeIds(value) - override def setCheckpointInterval(value: Int): this.type = - super.setCheckpointInterval(value) + override def setCheckpointInterval(value: Int): this.type = super.setCheckpointInterval(value) override def setImpurity(value: String): this.type = super.setImpurity(value) @@ -75,8 +72,9 @@ final class DecisionTreeClassifier val numClasses: Int = MetadataUtils.getNumClasses(dataset.schema(paramMap(labelCol))) match { case Some(n: Int) => n case None => throw new IllegalArgumentException("DecisionTreeClassifier was given input" + - s" with invalid label column, without the number of classes specified.") - // TODO: Automatically index labels. + s" with invalid label column ${paramMap(labelCol)}, without the number of classes" + + " specified. See StringIndexer.") + // TODO: Automatically index labels: SPARK-7126 } val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset, paramMap) val strategy = getOldStrategy(categoricalFeatures, numClasses) @@ -85,18 +83,16 @@ final class DecisionTreeClassifier } /** (private[ml]) Create a Strategy instance to use with the old API. */ - override private[ml] def getOldStrategy( + private[ml] def getOldStrategy( categoricalFeatures: Map[Int, Int], numClasses: Int): OldStrategy = { - val strategy = super.getOldStrategy(categoricalFeatures, numClasses) - strategy.algo = OldAlgo.Classification - strategy.setImpurity(getOldImpurity) - strategy + super.getOldStrategy(categoricalFeatures, numClasses, OldAlgo.Classification, getOldImpurity, + subsamplingRate = 1.0) } } object DecisionTreeClassifier { - /** Accessor for supported impurities */ + /** Accessor for supported impurities: entropy, gini */ final val supportedImpurities: Array[String] = TreeClassifierParams.supportedImpurities } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala new file mode 100644 index 0000000000000..d2e052fbbbf22 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala @@ -0,0 +1,228 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.ml.classification + +import com.github.fommil.netlib.BLAS.{getInstance => blas} + +import org.apache.spark.Logging +import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.ml.impl.estimator.{PredictionModel, Predictor} +import org.apache.spark.ml.impl.tree._ +import org.apache.spark.ml.param.{Param, Params, ParamMap} +import org.apache.spark.ml.regression.DecisionTreeRegressionModel +import org.apache.spark.ml.tree.{DecisionTreeModel, TreeEnsembleModel} +import org.apache.spark.ml.util.MetadataUtils +import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.tree.{GradientBoostedTrees => OldGBT} +import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} +import org.apache.spark.mllib.tree.loss.{Loss => OldLoss, LogLoss => OldLogLoss} +import org.apache.spark.mllib.tree.model.{GradientBoostedTreesModel => OldGBTModel} +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.DataFrame + + +/** + * :: AlphaComponent :: + * + * [[http://en.wikipedia.org/wiki/Gradient_boosting Gradient-Boosted Trees (GBTs)]] + * learning algorithm for classification. + * It supports binary labels, as well as both continuous and categorical features. + * Note: Multiclass labels are not currently supported. + */ +@AlphaComponent +final class GBTClassifier + extends Predictor[Vector, GBTClassifier, GBTClassificationModel] + with GBTParams with TreeClassifierParams with Logging { + + // Override parameter setters from parent trait for Java API compatibility. + + // Parameters from TreeClassifierParams: + + override def setMaxDepth(value: Int): this.type = super.setMaxDepth(value) + + override def setMaxBins(value: Int): this.type = super.setMaxBins(value) + + override def setMinInstancesPerNode(value: Int): this.type = + super.setMinInstancesPerNode(value) + + override def setMinInfoGain(value: Double): this.type = super.setMinInfoGain(value) + + override def setMaxMemoryInMB(value: Int): this.type = super.setMaxMemoryInMB(value) + + override def setCacheNodeIds(value: Boolean): this.type = super.setCacheNodeIds(value) + + override def setCheckpointInterval(value: Int): this.type = super.setCheckpointInterval(value) + + /** + * The impurity setting is ignored for GBT models. + * Individual trees are built using impurity "Variance." + */ + override def setImpurity(value: String): this.type = { + logWarning("GBTClassifier.setImpurity should NOT be used") + this + } + + // Parameters from TreeEnsembleParams: + + override def setSubsamplingRate(value: Double): this.type = super.setSubsamplingRate(value) + + override def setSeed(value: Long): this.type = { + logWarning("The 'seed' parameter is currently ignored by Gradient Boosting.") + super.setSeed(value) + } + + // Parameters from GBTParams: + + override def setMaxIter(value: Int): this.type = super.setMaxIter(value) + + override def setStepSize(value: Double): this.type = super.setStepSize(value) + + // Parameters for GBTClassifier: + + /** + * Loss function which GBT tries to minimize. (case-insensitive) + * Supported: "logistic" + * (default = logistic) + * @group param + */ + val lossType: Param[String] = new Param[String](this, "lossType", "Loss function which GBT" + + " tries to minimize (case-insensitive). 
Supported options:" + + s" ${GBTClassifier.supportedLossTypes.mkString(", ")}") + + setDefault(lossType -> "logistic") + + /** @group setParam */ + def setLossType(value: String): this.type = { + val lossStr = value.toLowerCase + require(GBTClassifier.supportedLossTypes.contains(lossStr), "GBTClassifier was given bad loss" + + s" type: $value. Supported options: ${GBTClassifier.supportedLossTypes.mkString(", ")}") + set(lossType, lossStr) + this + } + + /** @group getParam */ + def getLossType: String = getOrDefault(lossType) + + /** (private[ml]) Convert new loss to old loss. */ + override private[ml] def getOldLossType: OldLoss = { + getLossType match { + case "logistic" => OldLogLoss + case _ => + // Should never happen because of check in setter method. + throw new RuntimeException(s"GBTClassifier was given bad loss type: $getLossType") + } + } + + override protected def train( + dataset: DataFrame, + paramMap: ParamMap): GBTClassificationModel = { + val categoricalFeatures: Map[Int, Int] = + MetadataUtils.getCategoricalFeatures(dataset.schema(paramMap(featuresCol))) + val numClasses: Int = MetadataUtils.getNumClasses(dataset.schema(paramMap(labelCol))) match { + case Some(n: Int) => n + case None => throw new IllegalArgumentException("GBTClassifier was given input" + + s" with invalid label column ${paramMap(labelCol)}, without the number of classes" + + " specified. See StringIndexer.") + // TODO: Automatically index labels: SPARK-7126 + } + require(numClasses == 2, + s"GBTClassifier only supports binary classification but was given numClasses = $numClasses") + val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset, paramMap) + val boostingStrategy = super.getOldBoostingStrategy(categoricalFeatures, OldAlgo.Classification) + val oldGBT = new OldGBT(boostingStrategy) + val oldModel = oldGBT.run(oldDataset) + GBTClassificationModel.fromOld(oldModel, this, paramMap, categoricalFeatures) + } +} + +object GBTClassifier { + // The losses below should be lowercase. + /** Accessor for supported loss settings: logistic */ + final val supportedLossTypes: Array[String] = Array("logistic").map(_.toLowerCase) +} + +/** + * :: AlphaComponent :: + * + * [[http://en.wikipedia.org/wiki/Gradient_boosting Gradient-Boosted Trees (GBTs)]] + * model for classification. + * It supports binary labels, as well as both continuous and categorical features. + * Note: Multiclass labels are not currently supported. + * @param _trees Decision trees in the ensemble. + * @param _treeWeights Weights for the decision trees in the ensemble. + */ +@AlphaComponent +final class GBTClassificationModel( + override val parent: GBTClassifier, + override val fittingParamMap: ParamMap, + private val _trees: Array[DecisionTreeRegressionModel], + private val _treeWeights: Array[Double]) + extends PredictionModel[Vector, GBTClassificationModel] + with TreeEnsembleModel with Serializable { + + require(numTrees > 0, "GBTClassificationModel requires at least 1 tree.") + require(_trees.length == _treeWeights.length, "GBTClassificationModel given trees, treeWeights" + + s" of non-matching lengths (${_trees.length}, ${_treeWeights.length}, respectively).") + + override def trees: Array[DecisionTreeModel] = _trees.asInstanceOf[Array[DecisionTreeModel]] + + override def treeWeights: Array[Double] = _treeWeights + + override protected def predict(features: Vector): Double = { + // TODO: Override transform() to broadcast model: SPARK-7127 + // TODO: When we add a generic Boosting class, handle transform there? 
SPARK-7129 + // Classifies by thresholding sum of weighted tree predictions + val treePredictions = _trees.map(_.rootNode.predict(features)) + val prediction = blas.ddot(numTrees, treePredictions, 1, _treeWeights, 1) + if (prediction > 0.0) 1.0 else 0.0 + } + + override protected def copy(): GBTClassificationModel = { + val m = new GBTClassificationModel(parent, fittingParamMap, _trees, _treeWeights) + Params.inheritValues(this.extractParamMap(), this, m) + m + } + + override def toString: String = { + s"GBTClassificationModel with $numTrees trees" + } + + /** (private[ml]) Convert to a model in the old API */ + private[ml] def toOld: OldGBTModel = { + new OldGBTModel(OldAlgo.Classification, _trees.map(_.toOld), _treeWeights) + } +} + +private[ml] object GBTClassificationModel { + + /** (private[ml]) Convert a model from the old API */ + def fromOld( + oldModel: OldGBTModel, + parent: GBTClassifier, + fittingParamMap: ParamMap, + categoricalFeatures: Map[Int, Int]): GBTClassificationModel = { + require(oldModel.algo == OldAlgo.Classification, "Cannot convert GradientBoostedTreesModel" + + s" with algo=${oldModel.algo} (old API) to GBTClassificationModel (new API).") + val newTrees = oldModel.trees.map { tree => + // parent, fittingParamMap for each tree is null since there are no good ways to set these. + DecisionTreeRegressionModel.fromOld(tree, null, null, categoricalFeatures) + } + new GBTClassificationModel(parent, fittingParamMap, newTrees, oldModel.treeWeights) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala new file mode 100644 index 0000000000000..cfd6508fce890 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala @@ -0,0 +1,185 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.ml.classification + +import scala.collection.mutable + +import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.ml.impl.estimator.{PredictionModel, Predictor} +import org.apache.spark.ml.impl.tree._ +import org.apache.spark.ml.param.{Params, ParamMap} +import org.apache.spark.ml.tree.{DecisionTreeModel, TreeEnsembleModel} +import org.apache.spark.ml.util.MetadataUtils +import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.tree.{RandomForest => OldRandomForest} +import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy} +import org.apache.spark.mllib.tree.model.{RandomForestModel => OldRandomForestModel} +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.DataFrame + + +/** + * :: AlphaComponent :: + * + * [[http://en.wikipedia.org/wiki/Random_forest Random Forest]] learning algorithm for + * classification. + * It supports both binary and multiclass labels, as well as both continuous and categorical + * features. + */ +@AlphaComponent +final class RandomForestClassifier + extends Predictor[Vector, RandomForestClassifier, RandomForestClassificationModel] + with RandomForestParams with TreeClassifierParams { + + // Override parameter setters from parent trait for Java API compatibility. + + // Parameters from TreeClassifierParams: + + override def setMaxDepth(value: Int): this.type = super.setMaxDepth(value) + + override def setMaxBins(value: Int): this.type = super.setMaxBins(value) + + override def setMinInstancesPerNode(value: Int): this.type = + super.setMinInstancesPerNode(value) + + override def setMinInfoGain(value: Double): this.type = super.setMinInfoGain(value) + + override def setMaxMemoryInMB(value: Int): this.type = super.setMaxMemoryInMB(value) + + override def setCacheNodeIds(value: Boolean): this.type = super.setCacheNodeIds(value) + + override def setCheckpointInterval(value: Int): this.type = super.setCheckpointInterval(value) + + override def setImpurity(value: String): this.type = super.setImpurity(value) + + // Parameters from TreeEnsembleParams: + + override def setSubsamplingRate(value: Double): this.type = super.setSubsamplingRate(value) + + override def setSeed(value: Long): this.type = super.setSeed(value) + + // Parameters from RandomForestParams: + + override def setNumTrees(value: Int): this.type = super.setNumTrees(value) + + override def setFeatureSubsetStrategy(value: String): this.type = + super.setFeatureSubsetStrategy(value) + + override protected def train( + dataset: DataFrame, + paramMap: ParamMap): RandomForestClassificationModel = { + val categoricalFeatures: Map[Int, Int] = + MetadataUtils.getCategoricalFeatures(dataset.schema(paramMap(featuresCol))) + val numClasses: Int = MetadataUtils.getNumClasses(dataset.schema(paramMap(labelCol))) match { + case Some(n: Int) => n + case None => throw new IllegalArgumentException("RandomForestClassifier was given input" + + s" with invalid label column ${paramMap(labelCol)}, without the number of classes" + + " specified. 
See StringIndexer.") + // TODO: Automatically index labels: SPARK-7126 + } + val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset, paramMap) + val strategy = + super.getOldStrategy(categoricalFeatures, numClasses, OldAlgo.Classification, getOldImpurity) + val oldModel = OldRandomForest.trainClassifier( + oldDataset, strategy, getNumTrees, getFeatureSubsetStrategy, getSeed.toInt) + RandomForestClassificationModel.fromOld(oldModel, this, paramMap, categoricalFeatures) + } +} + +object RandomForestClassifier { + /** Accessor for supported impurity settings: entropy, gini */ + final val supportedImpurities: Array[String] = TreeClassifierParams.supportedImpurities + + /** Accessor for supported featureSubsetStrategy settings: auto, all, onethird, sqrt, log2 */ + final val supportedFeatureSubsetStrategies: Array[String] = + RandomForestParams.supportedFeatureSubsetStrategies +} + +/** + * :: AlphaComponent :: + * + * [[http://en.wikipedia.org/wiki/Random_forest Random Forest]] model for classification. + * It supports both binary and multiclass labels, as well as both continuous and categorical + * features. + * @param _trees Decision trees in the ensemble. + * Warning: These have null parents. + */ +@AlphaComponent +final class RandomForestClassificationModel private[ml] ( + override val parent: RandomForestClassifier, + override val fittingParamMap: ParamMap, + private val _trees: Array[DecisionTreeClassificationModel]) + extends PredictionModel[Vector, RandomForestClassificationModel] + with TreeEnsembleModel with Serializable { + + require(numTrees > 0, "RandomForestClassificationModel requires at least 1 tree.") + + override def trees: Array[DecisionTreeModel] = _trees.asInstanceOf[Array[DecisionTreeModel]] + + // Note: We may add support for weights (based on tree performance) later on. + private lazy val _treeWeights: Array[Double] = Array.fill[Double](numTrees)(1.0) + + override def treeWeights: Array[Double] = _treeWeights + + override protected def predict(features: Vector): Double = { + // TODO: Override transform() to broadcast model. SPARK-7127 + // TODO: When we add a generic Bagging class, handle transform there: SPARK-7128 + // Classifies using majority votes. + // Ignore the weights since all are 1.0 for now. 
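+      // For example, if three trees predict 1.0, 0.0 and 1.0, votes becomes
+      // Map(1 -> 2.0, 0 -> 1.0) and votes.maxBy(_._2)._1 below returns class 1.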
+ val votes = mutable.Map.empty[Int, Double] + _trees.view.foreach { tree => + val prediction = tree.rootNode.predict(features).toInt + votes(prediction) = votes.getOrElse(prediction, 0.0) + 1.0 // 1.0 = weight + } + votes.maxBy(_._2)._1 + } + + override protected def copy(): RandomForestClassificationModel = { + val m = new RandomForestClassificationModel(parent, fittingParamMap, _trees) + Params.inheritValues(this.extractParamMap(), this, m) + m + } + + override def toString: String = { + s"RandomForestClassificationModel with $numTrees trees" + } + + /** (private[ml]) Convert to a model in the old API */ + private[ml] def toOld: OldRandomForestModel = { + new OldRandomForestModel(OldAlgo.Classification, _trees.map(_.toOld)) + } +} + +private[ml] object RandomForestClassificationModel { + + /** (private[ml]) Convert a model from the old API */ + def fromOld( + oldModel: OldRandomForestModel, + parent: RandomForestClassifier, + fittingParamMap: ParamMap, + categoricalFeatures: Map[Int, Int]): RandomForestClassificationModel = { + require(oldModel.algo == OldAlgo.Classification, "Cannot convert RandomForestModel" + + s" with algo=${oldModel.algo} (old API) to RandomForestClassificationModel (new API).") + val newTrees = oldModel.trees.map { tree => + // parent, fittingParamMap for each tree is null since there are no good ways to set these. + DecisionTreeClassificationModel.fromOld(tree, null, null, categoricalFeatures) + } + new RandomForestClassificationModel(parent, fittingParamMap, newTrees) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/impl/tree/treeParams.scala b/mllib/src/main/scala/org/apache/spark/ml/impl/tree/treeParams.scala index eb2609faef05a..ab6281b9b2e34 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/impl/tree/treeParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/impl/tree/treeParams.scala @@ -20,9 +20,12 @@ package org.apache.spark.ml.impl.tree import org.apache.spark.annotation.DeveloperApi import org.apache.spark.ml.impl.estimator.PredictorParams import org.apache.spark.ml.param._ -import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy} +import org.apache.spark.ml.param.shared.{HasSeed, HasMaxIter} +import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, + BoostingStrategy => OldBoostingStrategy, Strategy => OldStrategy} import org.apache.spark.mllib.tree.impurity.{Gini => OldGini, Entropy => OldEntropy, Impurity => OldImpurity, Variance => OldVariance} +import org.apache.spark.mllib.tree.loss.{Loss => OldLoss} /** @@ -117,79 +120,68 @@ private[ml] trait DecisionTreeParams extends PredictorParams { def setMaxDepth(value: Int): this.type = { require(value >= 0, s"maxDepth parameter must be >= 0. Given bad value: $value") set(maxDepth, value) - this } /** @group getParam */ - def getMaxDepth: Int = getOrDefault(maxDepth) + final def getMaxDepth: Int = getOrDefault(maxDepth) /** @group setParam */ def setMaxBins(value: Int): this.type = { require(value >= 2, s"maxBins parameter must be >= 2. Given bad value: $value") set(maxBins, value) - this } /** @group getParam */ - def getMaxBins: Int = getOrDefault(maxBins) + final def getMaxBins: Int = getOrDefault(maxBins) /** @group setParam */ def setMinInstancesPerNode(value: Int): this.type = { require(value >= 1, s"minInstancesPerNode parameter must be >= 1. 
Given bad value: $value") set(minInstancesPerNode, value) - this } /** @group getParam */ - def getMinInstancesPerNode: Int = getOrDefault(minInstancesPerNode) + final def getMinInstancesPerNode: Int = getOrDefault(minInstancesPerNode) /** @group setParam */ - def setMinInfoGain(value: Double): this.type = { - set(minInfoGain, value) - this - } + def setMinInfoGain(value: Double): this.type = set(minInfoGain, value) /** @group getParam */ - def getMinInfoGain: Double = getOrDefault(minInfoGain) + final def getMinInfoGain: Double = getOrDefault(minInfoGain) /** @group expertSetParam */ def setMaxMemoryInMB(value: Int): this.type = { require(value > 0, s"maxMemoryInMB parameter must be > 0. Given bad value: $value") set(maxMemoryInMB, value) - this } /** @group expertGetParam */ - def getMaxMemoryInMB: Int = getOrDefault(maxMemoryInMB) + final def getMaxMemoryInMB: Int = getOrDefault(maxMemoryInMB) /** @group expertSetParam */ - def setCacheNodeIds(value: Boolean): this.type = { - set(cacheNodeIds, value) - this - } + def setCacheNodeIds(value: Boolean): this.type = set(cacheNodeIds, value) /** @group expertGetParam */ - def getCacheNodeIds: Boolean = getOrDefault(cacheNodeIds) + final def getCacheNodeIds: Boolean = getOrDefault(cacheNodeIds) /** @group expertSetParam */ def setCheckpointInterval(value: Int): this.type = { require(value >= 1, s"checkpointInterval parameter must be >= 1. Given bad value: $value") set(checkpointInterval, value) - this } /** @group expertGetParam */ - def getCheckpointInterval: Int = getOrDefault(checkpointInterval) + final def getCheckpointInterval: Int = getOrDefault(checkpointInterval) - /** - * Create a Strategy instance to use with the old API. - * NOTE: The caller should set impurity and subsamplingRate (which is set to 1.0, - * the default for single trees). - */ + /** (private[ml]) Create a Strategy instance to use with the old API. */ private[ml] def getOldStrategy( categoricalFeatures: Map[Int, Int], - numClasses: Int): OldStrategy = { - val strategy = OldStrategy.defaultStategy(OldAlgo.Classification) + numClasses: Int, + oldAlgo: OldAlgo.Algo, + oldImpurity: OldImpurity, + subsamplingRate: Double): OldStrategy = { + val strategy = OldStrategy.defaultStategy(oldAlgo) + strategy.impurity = oldImpurity strategy.checkpointInterval = getCheckpointInterval strategy.maxBins = getMaxBins strategy.maxDepth = getMaxDepth @@ -199,13 +191,13 @@ private[ml] trait DecisionTreeParams extends PredictorParams { strategy.useNodeIdCache = getCacheNodeIds strategy.numClasses = numClasses strategy.categoricalFeaturesInfo = categoricalFeatures - strategy.subsamplingRate = 1.0 // default for individual trees + strategy.subsamplingRate = subsamplingRate strategy } } /** - * (private trait) Parameters for Decision Tree-based classification algorithms. + * Parameters for Decision Tree-based classification algorithms. */ private[ml] trait TreeClassifierParams extends Params { @@ -215,7 +207,7 @@ private[ml] trait TreeClassifierParams extends Params { * (default = gini) * @group param */ - val impurity: Param[String] = new Param[String](this, "impurity", "Criterion used for" + + final val impurity: Param[String] = new Param[String](this, "impurity", "Criterion used for" + " information gain calculation (case-insensitive). Supported options:" + s" ${TreeClassifierParams.supportedImpurities.mkString(", ")}") @@ -228,11 +220,10 @@ private[ml] trait TreeClassifierParams extends Params { s"Tree-based classifier was given unrecognized impurity: $value." 
+ s" Supported options: ${TreeClassifierParams.supportedImpurities.mkString(", ")}") set(impurity, impurityStr) - this } /** @group getParam */ - def getImpurity: String = getOrDefault(impurity) + final def getImpurity: String = getOrDefault(impurity) /** Convert new impurity to old impurity. */ private[ml] def getOldImpurity: OldImpurity = { @@ -249,11 +240,11 @@ private[ml] trait TreeClassifierParams extends Params { private[ml] object TreeClassifierParams { // These options should be lowercase. - val supportedImpurities: Array[String] = Array("entropy", "gini").map(_.toLowerCase) + final val supportedImpurities: Array[String] = Array("entropy", "gini").map(_.toLowerCase) } /** - * (private trait) Parameters for Decision Tree-based regression algorithms. + * Parameters for Decision Tree-based regression algorithms. */ private[ml] trait TreeRegressorParams extends Params { @@ -263,7 +254,7 @@ private[ml] trait TreeRegressorParams extends Params { * (default = variance) * @group param */ - val impurity: Param[String] = new Param[String](this, "impurity", "Criterion used for" + + final val impurity: Param[String] = new Param[String](this, "impurity", "Criterion used for" + " information gain calculation (case-insensitive). Supported options:" + s" ${TreeRegressorParams.supportedImpurities.mkString(", ")}") @@ -276,11 +267,10 @@ private[ml] trait TreeRegressorParams extends Params { s"Tree-based regressor was given unrecognized impurity: $value." + s" Supported options: ${TreeRegressorParams.supportedImpurities.mkString(", ")}") set(impurity, impurityStr) - this } /** @group getParam */ - def getImpurity: String = getOrDefault(impurity) + final def getImpurity: String = getOrDefault(impurity) /** Convert new impurity to old impurity. */ private[ml] def getOldImpurity: OldImpurity = { @@ -296,5 +286,186 @@ private[ml] trait TreeRegressorParams extends Params { private[ml] object TreeRegressorParams { // These options should be lowercase. - val supportedImpurities: Array[String] = Array("variance").map(_.toLowerCase) + final val supportedImpurities: Array[String] = Array("variance").map(_.toLowerCase) +} + +/** + * :: DeveloperApi :: + * Parameters for Decision Tree-based ensemble algorithms. + * + * Note: Marked as private and DeveloperApi since this may be made public in the future. + */ +@DeveloperApi +private[ml] trait TreeEnsembleParams extends DecisionTreeParams with HasSeed { + + /** + * Fraction of the training data used for learning each decision tree. + * (default = 1.0) + * @group param + */ + final val subsamplingRate: DoubleParam = new DoubleParam(this, "subsamplingRate", + "Fraction of the training data used for learning each decision tree.") + + setDefault(subsamplingRate -> 1.0) + + /** @group setParam */ + def setSubsamplingRate(value: Double): this.type = { + require(value > 0.0 && value <= 1.0, + s"Subsampling rate must be in range (0,1]. Bad rate: $value") + set(subsamplingRate, value) + } + + /** @group getParam */ + final def getSubsamplingRate: Double = getOrDefault(subsamplingRate) + + /** @group setParam */ + def setSeed(value: Long): this.type = set(seed, value) + + /** + * Create a Strategy instance to use with the old API. + * NOTE: The caller should set impurity and seed. 
+ */ + private[ml] def getOldStrategy( + categoricalFeatures: Map[Int, Int], + numClasses: Int, + oldAlgo: OldAlgo.Algo, + oldImpurity: OldImpurity): OldStrategy = { + super.getOldStrategy(categoricalFeatures, numClasses, oldAlgo, oldImpurity, getSubsamplingRate) + } +} + +/** + * :: DeveloperApi :: + * Parameters for Random Forest algorithms. + * + * Note: Marked as private and DeveloperApi since this may be made public in the future. + */ +@DeveloperApi +private[ml] trait RandomForestParams extends TreeEnsembleParams { + + /** + * Number of trees to train (>= 1). + * If 1, then no bootstrapping is used. If > 1, then bootstrapping is done. + * TODO: Change to always do bootstrapping (simpler). SPARK-7130 + * (default = 20) + * @group param + */ + final val numTrees: IntParam = new IntParam(this, "numTrees", "Number of trees to train (>= 1)") + + /** + * The number of features to consider for splits at each tree node. + * Supported options: + * - "auto": Choose automatically for task: + * If numTrees == 1, set to "all." + * If numTrees > 1 (forest), set to "sqrt" for classification and + * to "onethird" for regression. + * - "all": use all features + * - "onethird": use 1/3 of the features + * - "sqrt": use sqrt(number of features) + * - "log2": use log2(number of features) + * (default = "auto") + * + * These various settings are based on the following references: + * - log2: tested in Breiman (2001) + * - sqrt: recommended by Breiman manual for random forests + * - The defaults of sqrt (classification) and onethird (regression) match the R randomForest + * package. + * @see [[http://www.stat.berkeley.edu/~breiman/randomforest2001.pdf Breiman (2001)]] + * @see [[http://www.stat.berkeley.edu/~breiman/Using_random_forests_V3.1.pdf Breiman manual for + * random forests]] + * + * @group param + */ + final val featureSubsetStrategy: Param[String] = new Param[String](this, "featureSubsetStrategy", + "The number of features to consider for splits at each tree node." + + s" Supported options: ${RandomForestParams.supportedFeatureSubsetStrategies.mkString(", ")}") + + setDefault(numTrees -> 20, featureSubsetStrategy -> "auto") + + /** @group setParam */ + def setNumTrees(value: Int): this.type = { + require(value >= 1, s"Random Forest numTrees parameter cannot be $value; it must be >= 1.") + set(numTrees, value) + } + + /** @group getParam */ + final def getNumTrees: Int = getOrDefault(numTrees) + + /** @group setParam */ + def setFeatureSubsetStrategy(value: String): this.type = { + val strategyStr = value.toLowerCase + require(RandomForestParams.supportedFeatureSubsetStrategies.contains(strategyStr), + s"RandomForestParams was given unrecognized featureSubsetStrategy: $value. Supported" + + s" options: ${RandomForestParams.supportedFeatureSubsetStrategies.mkString(", ")}") + set(featureSubsetStrategy, strategyStr) + } + + /** @group getParam */ + final def getFeatureSubsetStrategy: String = getOrDefault(featureSubsetStrategy) +} + +private[ml] object RandomForestParams { + // These options should be lowercase. + final val supportedFeatureSubsetStrategies: Array[String] = + Array("auto", "all", "onethird", "sqrt", "log2").map(_.toLowerCase) +} + +/** + * :: DeveloperApi :: + * Parameters for Gradient-Boosted Tree algorithms. + * + * Note: Marked as private and DeveloperApi since this may be made public in the future. + */ +@DeveloperApi +private[ml] trait GBTParams extends TreeEnsembleParams with HasMaxIter { + + /** + * Step size (a.k.a. 
learning rate) in interval (0, 1] for shrinking the contribution of each + * estimator. + * (default = 0.1) + * @group param + */ + final val stepSize: DoubleParam = new DoubleParam(this, "stepSize", "Step size (a.k.a." + + " learning rate) in interval (0, 1] for shrinking the contribution of each estimator") + + /* TODO: Add this doc when we add this param. SPARK-7132 + * Threshold for stopping early when runWithValidation is used. + * If the error rate on the validation input changes by less than the validationTol, + * then learning will stop early (before [[numIterations]]). + * This parameter is ignored when run is used. + * (default = 1e-5) + * @group param + */ + // final val validationTol: DoubleParam = new DoubleParam(this, "validationTol", "") + // validationTol -> 1e-5 + + setDefault(maxIter -> 20, stepSize -> 0.1) + + /** @group setParam */ + def setMaxIter(value: Int): this.type = { + require(value >= 1, s"Gradient Boosting maxIter parameter cannot be $value; it must be >= 1.") + set(maxIter, value) + } + + /** @group setParam */ + def setStepSize(value: Double): this.type = { + require(value > 0.0 && value <= 1.0, + s"GBT given invalid step size ($value). Value should be in (0,1].") + set(stepSize, value) + } + + /** @group getParam */ + final def getStepSize: Double = getOrDefault(stepSize) + + /** (private[ml]) Create a BoostingStrategy instance to use with the old API. */ + private[ml] def getOldBoostingStrategy( + categoricalFeatures: Map[Int, Int], + oldAlgo: OldAlgo.Algo): OldBoostingStrategy = { + val strategy = super.getOldStrategy(categoricalFeatures, numClasses = 2, oldAlgo, OldVariance) + // NOTE: The old API does not support "seed" so we ignore it. + new OldBoostingStrategy(strategy, getOldLossType, getMaxIter, getStepSize) + } + + /** Get old Gradient Boosting Loss type */ + private[ml] def getOldLossType: OldLoss } diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala index 95d7e64790c79..e88c48741e99f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala @@ -45,7 +45,8 @@ private[shared] object SharedParamsCodeGen { ParamDesc[Array[String]]("inputCols", "input column names"), ParamDesc[String]("outputCol", "output column name"), ParamDesc[Int]("checkpointInterval", "checkpoint interval"), - ParamDesc[Boolean]("fitIntercept", "whether to fit an intercept term", Some("true"))) + ParamDesc[Boolean]("fitIntercept", "whether to fit an intercept term", Some("true")), + ParamDesc[Long]("seed", "random seed", Some("Utils.random.nextLong()"))) val code = genSharedParams(params) val file = "src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala" @@ -154,6 +155,7 @@ private[shared] object SharedParamsCodeGen { | |import org.apache.spark.annotation.DeveloperApi |import org.apache.spark.ml.param._ + |import org.apache.spark.util.Utils | |// DO NOT MODIFY THIS FILE! It was generated by SharedParamsCodeGen. 
| diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala index 72b08bf276483..a860b8834cff9 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala @@ -19,6 +19,7 @@ package org.apache.spark.ml.param.shared import org.apache.spark.annotation.DeveloperApi import org.apache.spark.ml.param._ +import org.apache.spark.util.Utils // DO NOT MODIFY THIS FILE! It was generated by SharedParamsCodeGen. @@ -256,4 +257,23 @@ trait HasFitIntercept extends Params { /** @group getParam */ final def getFitIntercept: Boolean = getOrDefault(fitIntercept) } + +/** + * :: DeveloperApi :: + * Trait for shared param seed (default: Utils.random.nextLong()). + */ +@DeveloperApi +trait HasSeed extends Params { + + /** + * Param for random seed. + * @group param + */ + final val seed: LongParam = new LongParam(this, "seed", "random seed") + + setDefault(seed, Utils.random.nextLong()) + + /** @group getParam */ + final def getSeed: Long = getOrDefault(seed) +} // scalastyle:on diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala index 49a8b77acf960..756725a64b0f3 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala @@ -42,8 +42,7 @@ import org.apache.spark.sql.DataFrame @AlphaComponent final class DecisionTreeRegressor extends Predictor[Vector, DecisionTreeRegressor, DecisionTreeRegressionModel] - with DecisionTreeParams - with TreeRegressorParams { + with DecisionTreeParams with TreeRegressorParams { // Override parameter setters from parent trait for Java API compatibility. @@ -60,8 +59,7 @@ final class DecisionTreeRegressor override def setCacheNodeIds(value: Boolean): this.type = super.setCacheNodeIds(value) - override def setCheckpointInterval(value: Int): this.type = - super.setCheckpointInterval(value) + override def setCheckpointInterval(value: Int): this.type = super.setCheckpointInterval(value) override def setImpurity(value: String): this.type = super.setImpurity(value) @@ -78,15 +76,13 @@ final class DecisionTreeRegressor /** (private[ml]) Create a Strategy instance to use with the old API. */ private[ml] def getOldStrategy(categoricalFeatures: Map[Int, Int]): OldStrategy = { - val strategy = super.getOldStrategy(categoricalFeatures, numClasses = 0) - strategy.algo = OldAlgo.Regression - strategy.setImpurity(getOldImpurity) - strategy + super.getOldStrategy(categoricalFeatures, numClasses = 0, OldAlgo.Regression, getOldImpurity, + subsamplingRate = 1.0) } } object DecisionTreeRegressor { - /** Accessor for supported impurities */ + /** Accessor for supported impurities: variance */ final val supportedImpurities: Array[String] = TreeRegressorParams.supportedImpurities } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala new file mode 100644 index 0000000000000..c784cf39ed31a --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala @@ -0,0 +1,218 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.regression + +import com.github.fommil.netlib.BLAS.{getInstance => blas} + +import org.apache.spark.Logging +import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.ml.impl.estimator.{PredictionModel, Predictor} +import org.apache.spark.ml.impl.tree._ +import org.apache.spark.ml.param.{Params, ParamMap, Param} +import org.apache.spark.ml.tree.{DecisionTreeModel, TreeEnsembleModel} +import org.apache.spark.ml.util.MetadataUtils +import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.tree.{GradientBoostedTrees => OldGBT} +import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} +import org.apache.spark.mllib.tree.loss.{AbsoluteError => OldAbsoluteError, Loss => OldLoss, + SquaredError => OldSquaredError} +import org.apache.spark.mllib.tree.model.{GradientBoostedTreesModel => OldGBTModel} +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.DataFrame + + +/** + * :: AlphaComponent :: + * + * [[http://en.wikipedia.org/wiki/Gradient_boosting Gradient-Boosted Trees (GBTs)]] + * learning algorithm for regression. + * It supports both continuous and categorical features. + */ +@AlphaComponent +final class GBTRegressor + extends Predictor[Vector, GBTRegressor, GBTRegressionModel] + with GBTParams with TreeRegressorParams with Logging { + + // Override parameter setters from parent trait for Java API compatibility. + + // Parameters from TreeRegressorParams: + + override def setMaxDepth(value: Int): this.type = super.setMaxDepth(value) + + override def setMaxBins(value: Int): this.type = super.setMaxBins(value) + + override def setMinInstancesPerNode(value: Int): this.type = + super.setMinInstancesPerNode(value) + + override def setMinInfoGain(value: Double): this.type = super.setMinInfoGain(value) + + override def setMaxMemoryInMB(value: Int): this.type = super.setMaxMemoryInMB(value) + + override def setCacheNodeIds(value: Boolean): this.type = super.setCacheNodeIds(value) + + override def setCheckpointInterval(value: Int): this.type = super.setCheckpointInterval(value) + + /** + * The impurity setting is ignored for GBT models. + * Individual trees are built using impurity "Variance." 
+ */ + override def setImpurity(value: String): this.type = { + logWarning("GBTRegressor.setImpurity should NOT be used") + this + } + + // Parameters from TreeEnsembleParams: + + override def setSubsamplingRate(value: Double): this.type = super.setSubsamplingRate(value) + + override def setSeed(value: Long): this.type = { + logWarning("The 'seed' parameter is currently ignored by Gradient Boosting.") + super.setSeed(value) + } + + // Parameters from GBTParams: + + override def setMaxIter(value: Int): this.type = super.setMaxIter(value) + + override def setStepSize(value: Double): this.type = super.setStepSize(value) + + // Parameters for GBTRegressor: + + /** + * Loss function which GBT tries to minimize. (case-insensitive) + * Supported: "squared" (L2) and "absolute" (L1) + * (default = squared) + * @group param + */ + val lossType: Param[String] = new Param[String](this, "lossType", "Loss function which GBT" + + " tries to minimize (case-insensitive). Supported options:" + + s" ${GBTRegressor.supportedLossTypes.mkString(", ")}") + + setDefault(lossType -> "squared") + + /** @group setParam */ + def setLossType(value: String): this.type = { + val lossStr = value.toLowerCase + require(GBTRegressor.supportedLossTypes.contains(lossStr), "GBTRegressor was given bad loss" + + s" type: $value. Supported options: ${GBTRegressor.supportedLossTypes.mkString(", ")}") + set(lossType, lossStr) + this + } + + /** @group getParam */ + def getLossType: String = getOrDefault(lossType) + + /** (private[ml]) Convert new loss to old loss. */ + override private[ml] def getOldLossType: OldLoss = { + getLossType match { + case "squared" => OldSquaredError + case "absolute" => OldAbsoluteError + case _ => + // Should never happen because of check in setter method. + throw new RuntimeException(s"GBTRegressorParams was given bad loss type: $getLossType") + } + } + + override protected def train( + dataset: DataFrame, + paramMap: ParamMap): GBTRegressionModel = { + val categoricalFeatures: Map[Int, Int] = + MetadataUtils.getCategoricalFeatures(dataset.schema(paramMap(featuresCol))) + val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset, paramMap) + val boostingStrategy = super.getOldBoostingStrategy(categoricalFeatures, OldAlgo.Regression) + val oldGBT = new OldGBT(boostingStrategy) + val oldModel = oldGBT.run(oldDataset) + GBTRegressionModel.fromOld(oldModel, this, paramMap, categoricalFeatures) + } +} + +object GBTRegressor { + // The losses below should be lowercase. + /** Accessor for supported loss settings: squared (L2), absolute (L1) */ + final val supportedLossTypes: Array[String] = Array("squared", "absolute").map(_.toLowerCase) +} + +/** + * :: AlphaComponent :: + * + * [[http://en.wikipedia.org/wiki/Gradient_boosting Gradient-Boosted Trees (GBTs)]] + * model for regression. + * It supports both continuous and categorical features. + * @param _trees Decision trees in the ensemble. + * @param _treeWeights Weights for the decision trees in the ensemble. 
+ */
+@AlphaComponent
+final class GBTRegressionModel(
+    override val parent: GBTRegressor,
+    override val fittingParamMap: ParamMap,
+    private val _trees: Array[DecisionTreeRegressionModel],
+    private val _treeWeights: Array[Double])
+  extends PredictionModel[Vector, GBTRegressionModel]
+  with TreeEnsembleModel with Serializable {
+
+  require(numTrees > 0, "GBTRegressionModel requires at least 1 tree.")
+  require(_trees.length == _treeWeights.length, "GBTRegressionModel given trees, treeWeights of" +
+    s" non-matching lengths (${_trees.length}, ${_treeWeights.length}, respectively).")
+
+  override def trees: Array[DecisionTreeModel] = _trees.asInstanceOf[Array[DecisionTreeModel]]
+
+  override def treeWeights: Array[Double] = _treeWeights
+
+  override protected def predict(features: Vector): Double = {
+    // TODO: Override transform() to broadcast model. SPARK-7127
+    // TODO: When we add a generic Boosting class, handle transform there? SPARK-7129
+    // Regression predicts the weighted sum of the individual tree predictions.
+    val treePredictions = _trees.map(_.rootNode.predict(features))
+    blas.ddot(numTrees, treePredictions, 1, _treeWeights, 1)
+  }
+
+  override protected def copy(): GBTRegressionModel = {
+    val m = new GBTRegressionModel(parent, fittingParamMap, _trees, _treeWeights)
+    Params.inheritValues(this.extractParamMap(), this, m)
+    m
+  }
+
+  override def toString: String = {
+    s"GBTRegressionModel with $numTrees trees"
+  }
+
+  /** (private[ml]) Convert to a model in the old API */
+  private[ml] def toOld: OldGBTModel = {
+    new OldGBTModel(OldAlgo.Regression, _trees.map(_.toOld), _treeWeights)
+  }
+}
+
+private[ml] object GBTRegressionModel {
+
+  /** (private[ml]) Convert a model from the old API */
+  def fromOld(
+      oldModel: OldGBTModel,
+      parent: GBTRegressor,
+      fittingParamMap: ParamMap,
+      categoricalFeatures: Map[Int, Int]): GBTRegressionModel = {
+    require(oldModel.algo == OldAlgo.Regression, "Cannot convert GradientBoostedTreesModel" +
+      s" with algo=${oldModel.algo} (old API) to GBTRegressionModel (new API).")
+    val newTrees = oldModel.trees.map { tree =>
+      // parent, fittingParamMap for each tree is null since there are no good ways to set these.
+      DecisionTreeRegressionModel.fromOld(tree, null, null, categoricalFeatures)
+    }
+    new GBTRegressionModel(parent, fittingParamMap, newTrees, oldModel.treeWeights)
+  }
+}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
new file mode 100644
index 0000000000000..2171ef3d32c26
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
@@ -0,0 +1,167 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
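A minimal end-to-end sketch of the new GBT regression API above (outside the patch; it assumes a hypothetical DataFrame named train with the default "label" and "features" columns):

  import org.apache.spark.ml.regression.GBTRegressor

  val gbt = new GBTRegressor()
    .setMaxIter(10)
    .setStepSize(0.1)
    .setLossType("squared")              // or "absolute"
  val model = gbt.fit(train)             // returns a GBTRegressionModel
  println(model.totalNumNodes)           // node count summed over all trees
  println(model.toDebugString)           // full description of every tree and its weight
  val predicted = model.transform(train) // appends the prediction column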
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.regression + +import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.ml.impl.estimator.{PredictionModel, Predictor} +import org.apache.spark.ml.impl.tree.{RandomForestParams, TreeRegressorParams} +import org.apache.spark.ml.param.{Params, ParamMap} +import org.apache.spark.ml.tree.{DecisionTreeModel, TreeEnsembleModel} +import org.apache.spark.ml.util.MetadataUtils +import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.tree.{RandomForest => OldRandomForest} +import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy} +import org.apache.spark.mllib.tree.model.{RandomForestModel => OldRandomForestModel} +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.DataFrame + + +/** + * :: AlphaComponent :: + * + * [[http://en.wikipedia.org/wiki/Random_forest Random Forest]] learning algorithm for regression. + * It supports both continuous and categorical features. + */ +@AlphaComponent +final class RandomForestRegressor + extends Predictor[Vector, RandomForestRegressor, RandomForestRegressionModel] + with RandomForestParams with TreeRegressorParams { + + // Override parameter setters from parent trait for Java API compatibility. + + // Parameters from TreeRegressorParams: + + override def setMaxDepth(value: Int): this.type = super.setMaxDepth(value) + + override def setMaxBins(value: Int): this.type = super.setMaxBins(value) + + override def setMinInstancesPerNode(value: Int): this.type = + super.setMinInstancesPerNode(value) + + override def setMinInfoGain(value: Double): this.type = super.setMinInfoGain(value) + + override def setMaxMemoryInMB(value: Int): this.type = super.setMaxMemoryInMB(value) + + override def setCacheNodeIds(value: Boolean): this.type = super.setCacheNodeIds(value) + + override def setCheckpointInterval(value: Int): this.type = super.setCheckpointInterval(value) + + override def setImpurity(value: String): this.type = super.setImpurity(value) + + // Parameters from TreeEnsembleParams: + + override def setSubsamplingRate(value: Double): this.type = super.setSubsamplingRate(value) + + override def setSeed(value: Long): this.type = super.setSeed(value) + + // Parameters from RandomForestParams: + + override def setNumTrees(value: Int): this.type = super.setNumTrees(value) + + override def setFeatureSubsetStrategy(value: String): this.type = + super.setFeatureSubsetStrategy(value) + + override protected def train( + dataset: DataFrame, + paramMap: ParamMap): RandomForestRegressionModel = { + val categoricalFeatures: Map[Int, Int] = + MetadataUtils.getCategoricalFeatures(dataset.schema(paramMap(featuresCol))) + val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset, paramMap) + val strategy = + super.getOldStrategy(categoricalFeatures, numClasses = 0, OldAlgo.Regression, getOldImpurity) + val oldModel = OldRandomForest.trainRegressor( + oldDataset, strategy, getNumTrees, getFeatureSubsetStrategy, getSeed.toInt) + RandomForestRegressionModel.fromOld(oldModel, this, paramMap, categoricalFeatures) + } +} + +object RandomForestRegressor { + /** Accessor for supported impurity settings: variance */ + final val supportedImpurities: Array[String] = TreeRegressorParams.supportedImpurities + + /** Accessor for supported featureSubsetStrategy settings: auto, all, onethird, sqrt, log2 */ + final val 
supportedFeatureSubsetStrategies: Array[String] = + RandomForestParams.supportedFeatureSubsetStrategies +} + +/** + * :: AlphaComponent :: + * + * [[http://en.wikipedia.org/wiki/Random_forest Random Forest]] model for regression. + * It supports both continuous and categorical features. + * @param _trees Decision trees in the ensemble. + */ +@AlphaComponent +final class RandomForestRegressionModel private[ml] ( + override val parent: RandomForestRegressor, + override val fittingParamMap: ParamMap, + private val _trees: Array[DecisionTreeRegressionModel]) + extends PredictionModel[Vector, RandomForestRegressionModel] + with TreeEnsembleModel with Serializable { + + require(numTrees > 0, "RandomForestRegressionModel requires at least 1 tree.") + + override def trees: Array[DecisionTreeModel] = _trees.asInstanceOf[Array[DecisionTreeModel]] + + // Note: We may add support for weights (based on tree performance) later on. + private lazy val _treeWeights: Array[Double] = Array.fill[Double](numTrees)(1.0) + + override def treeWeights: Array[Double] = _treeWeights + + override protected def predict(features: Vector): Double = { + // TODO: Override transform() to broadcast model. SPARK-7127 + // TODO: When we add a generic Bagging class, handle transform there. SPARK-7128 + // Predict average of tree predictions. + // Ignore the weights since all are 1.0 for now. + _trees.map(_.rootNode.predict(features)).sum / numTrees + } + + override protected def copy(): RandomForestRegressionModel = { + val m = new RandomForestRegressionModel(parent, fittingParamMap, _trees) + Params.inheritValues(this.extractParamMap(), this, m) + m + } + + override def toString: String = { + s"RandomForestRegressionModel with $numTrees trees" + } + + /** (private[ml]) Convert to a model in the old API */ + private[ml] def toOld: OldRandomForestModel = { + new OldRandomForestModel(OldAlgo.Regression, _trees.map(_.toOld)) + } +} + +private[ml] object RandomForestRegressionModel { + + /** (private[ml]) Convert a model from the old API */ + def fromOld( + oldModel: OldRandomForestModel, + parent: RandomForestRegressor, + fittingParamMap: ParamMap, + categoricalFeatures: Map[Int, Int]): RandomForestRegressionModel = { + require(oldModel.algo == OldAlgo.Regression, "Cannot convert RandomForestModel" + + s" with algo=${oldModel.algo} (old API) to RandomForestRegressionModel (new API).") + val newTrees = oldModel.trees.map { tree => + // parent, fittingParamMap for each tree is null since there are no good ways to set these. + DecisionTreeRegressionModel.fromOld(tree, null, null, categoricalFeatures) + } + new RandomForestRegressionModel(parent, fittingParamMap, newTrees) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala index d6e2203d9f937..d2dec0c76cb12 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala @@ -28,9 +28,9 @@ import org.apache.spark.mllib.tree.model.{InformationGainStats => OldInformation sealed abstract class Node extends Serializable { // TODO: Add aggregate stats (once available). This will happen after we move the DecisionTree - // code into the new API and deprecate the old API. + // code into the new API and deprecate the old API. 
SPARK-3727 - /** Prediction this node makes (or would make, if it is an internal node) */ + /** Prediction a leaf node makes, or which an internal node would make if it were a leaf node */ def prediction: Double /** Impurity measure at this node (for training data) */ @@ -194,7 +194,7 @@ private object InternalNode { s"$featureStr > ${contSplit.threshold}" } case catSplit: CategoricalSplit => - val categoriesStr = catSplit.getLeftCategories.mkString("{", ",", "}") + val categoriesStr = catSplit.leftCategories.mkString("{", ",", "}") if (left) { s"$featureStr in $categoriesStr" } else { diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala index 708c769087dd0..90f1d052764d3 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala @@ -44,7 +44,7 @@ private[tree] object Split { oldSplit.featureType match { case OldFeatureType.Categorical => new CategoricalSplit(featureIndex = oldSplit.feature, - leftCategories = oldSplit.categories.toArray, categoricalFeatures(oldSplit.feature)) + _leftCategories = oldSplit.categories.toArray, categoricalFeatures(oldSplit.feature)) case OldFeatureType.Continuous => new ContinuousSplit(featureIndex = oldSplit.feature, threshold = oldSplit.threshold) } @@ -54,30 +54,30 @@ private[tree] object Split { /** * Split which tests a categorical feature. * @param featureIndex Index of the feature to test - * @param leftCategories If the feature value is in this set of categories, then the split goes - * left. Otherwise, it goes right. + * @param _leftCategories If the feature value is in this set of categories, then the split goes + * left. Otherwise, it goes right. * @param numCategories Number of categories for this feature. */ final class CategoricalSplit private[ml] ( override val featureIndex: Int, - leftCategories: Array[Double], + _leftCategories: Array[Double], private val numCategories: Int) extends Split { - require(leftCategories.forall(cat => 0 <= cat && cat < numCategories), "Invalid leftCategories" + - s" (should be in range [0, $numCategories)): ${leftCategories.mkString(",")}") + require(_leftCategories.forall(cat => 0 <= cat && cat < numCategories), "Invalid leftCategories" + + s" (should be in range [0, $numCategories)): ${_leftCategories.mkString(",")}") /** * If true, then "categories" is the set of categories for splitting to the left, and vice versa. */ - private val isLeft: Boolean = leftCategories.length <= numCategories / 2 + private val isLeft: Boolean = _leftCategories.length <= numCategories / 2 /** Set of categories determining the splitting rule, along with [[isLeft]]. 
*/ private val categories: Set[Double] = { if (isLeft) { - leftCategories.toSet + _leftCategories.toSet } else { - setComplement(leftCategories.toSet) + setComplement(_leftCategories.toSet) } } @@ -107,13 +107,13 @@ final class CategoricalSplit private[ml] ( } /** Get sorted categories which split to the left */ - def getLeftCategories: Array[Double] = { + def leftCategories: Array[Double] = { val cats = if (isLeft) categories else setComplement(categories) cats.toArray.sorted } /** Get sorted categories which split to the right */ - def getRightCategories: Array[Double] = { + def rightCategories: Array[Double] = { val cats = if (isLeft) setComplement(categories) else categories cats.toArray.sorted } diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala index 8e3bc3849dcf0..1929f9d02156e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala @@ -17,18 +17,13 @@ package org.apache.spark.ml.tree -import org.apache.spark.annotation.AlphaComponent - /** - * :: AlphaComponent :: - * * Abstraction for Decision Tree models. * - * TODO: Add support for predicting probabilities and raw predictions + * TODO: Add support for predicting probabilities and raw predictions SPARK-3727 */ -@AlphaComponent -trait DecisionTreeModel { +private[ml] trait DecisionTreeModel { /** Root of the decision tree */ def rootNode: Node @@ -58,3 +53,40 @@ trait DecisionTreeModel { header + rootNode.subtreeToString(2) } } + +/** + * Abstraction for models which are ensembles of decision trees + * + * TODO: Add support for predicting probabilities and raw predictions SPARK-3727 + */ +private[ml] trait TreeEnsembleModel { + + // Note: We use getTrees since subclasses of TreeEnsembleModel will store subclasses of + // DecisionTreeModel. + + /** Trees in this ensemble. Warning: These have null parent Estimators. */ + def trees: Array[DecisionTreeModel] + + /** Weights for each tree, zippable with [[trees]] */ + def treeWeights: Array[Double] + + /** Summary of the model */ + override def toString: String = { + // Implementing classes should generally override this method to be more descriptive. + s"TreeEnsembleModel with $numTrees trees" + } + + /** Full description of model */ + def toDebugString: String = { + val header = toString + "\n" + header + trees.zip(treeWeights).zipWithIndex.map { case ((tree, weight), treeIndex) => + s" Tree $treeIndex (weight $weight):\n" + tree.rootNode.subtreeToString(4) + }.fold("")(_ + _) + } + + /** Number of trees in ensemble */ + val numTrees: Int = trees.length + + /** Total number of nodes, summed over all trees in the ensemble. 
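To make the complement bookkeeping in CategoricalSplit above concrete, a small sketch (outside the patch; the constructor is private[ml], so this only compiles from code inside org.apache.spark.ml, and the values are illustrative):

  val split = new CategoricalSplit(featureIndex = 0,
    _leftCategories = Array(0.0, 1.0, 2.0), numCategories = 4)
  // 3 of the 4 categories go left, so isLeft is false and only the complement {3.0} is stored,
  // yet the accessors still answer in terms of the original split:
  split.leftCategories    // Array(0.0, 1.0, 2.0), sorted
  split.rightCategories   // Array(3.0)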
*/ + lazy val totalNumNodes: Int = trees.map(_.numNodes).sum +} diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaDecisionTreeClassifierSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaDecisionTreeClassifierSuite.java index 43b8787f9dd7e..60f25e5cce437 100644 --- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaDecisionTreeClassifierSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaDecisionTreeClassifierSuite.java @@ -17,7 +17,6 @@ package org.apache.spark.ml.classification; -import java.io.File; import java.io.Serializable; import java.util.HashMap; import java.util.Map; @@ -32,7 +31,6 @@ import org.apache.spark.mllib.classification.LogisticRegressionSuite; import org.apache.spark.mllib.regression.LabeledPoint; import org.apache.spark.sql.DataFrame; -import org.apache.spark.util.Utils; public class JavaDecisionTreeClassifierSuite implements Serializable { @@ -57,7 +55,7 @@ public void runDT() { double B = -1.5; JavaRDD data = sc.parallelize( - LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache(); + LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache(); Map categoricalFeatures = new HashMap(); DataFrame dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 2); @@ -71,8 +69,8 @@ public void runDT() { .setCacheNodeIds(false) .setCheckpointInterval(10) .setMaxDepth(2); // duplicate setMaxDepth to check builder pattern - for (int i = 0; i < DecisionTreeClassifier.supportedImpurities().length; ++i) { - dt.setImpurity(DecisionTreeClassifier.supportedImpurities()[i]); + for (String impurity: DecisionTreeClassifier.supportedImpurities()) { + dt.setImpurity(impurity); } DecisionTreeClassificationModel model = dt.fit(dataFrame); @@ -82,7 +80,7 @@ public void runDT() { model.toDebugString(); /* - // TODO: Add test once save/load are implemented. + // TODO: Add test once save/load are implemented. SPARK-6725 File tempDir = Utils.createTempDir(System.getProperty("java.io.tmpdir"), "spark"); String path = tempDir.toURI().toString(); try { diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaGBTClassifierSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaGBTClassifierSuite.java new file mode 100644 index 0000000000000..3c69467fa119e --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaGBTClassifierSuite.java @@ -0,0 +1,100 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.ml.classification; + +import java.io.Serializable; +import java.util.HashMap; +import java.util.Map; + +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.ml.impl.TreeTests; +import org.apache.spark.mllib.classification.LogisticRegressionSuite; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.sql.DataFrame; + + +public class JavaGBTClassifierSuite implements Serializable { + + private transient JavaSparkContext sc; + + @Before + public void setUp() { + sc = new JavaSparkContext("local", "JavaGBTClassifierSuite"); + } + + @After + public void tearDown() { + sc.stop(); + sc = null; + } + + @Test + public void runDT() { + int nPoints = 20; + double A = 2.0; + double B = -1.5; + + JavaRDD data = sc.parallelize( + LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache(); + Map categoricalFeatures = new HashMap(); + DataFrame dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 2); + + // This tests setters. Training with various options is tested in Scala. + GBTClassifier rf = new GBTClassifier() + .setMaxDepth(2) + .setMaxBins(10) + .setMinInstancesPerNode(5) + .setMinInfoGain(0.0) + .setMaxMemoryInMB(256) + .setCacheNodeIds(false) + .setCheckpointInterval(10) + .setSubsamplingRate(1.0) + .setSeed(1234) + .setMaxIter(3) + .setStepSize(0.1) + .setMaxDepth(2); // duplicate setMaxDepth to check builder pattern + for (String lossType: GBTClassifier.supportedLossTypes()) { + rf.setLossType(lossType); + } + GBTClassificationModel model = rf.fit(dataFrame); + + model.transform(dataFrame); + model.totalNumNodes(); + model.toDebugString(); + model.trees(); + model.treeWeights(); + + /* + // TODO: Add test once save/load are implemented. SPARK-6725 + File tempDir = Utils.createTempDir(System.getProperty("java.io.tmpdir"), "spark"); + String path = tempDir.toURI().toString(); + try { + model3.save(sc.sc(), path); + GBTClassificationModel sameModel = GBTClassificationModel.load(sc.sc(), path); + TreeTests.checkEqual(model3, sameModel); + } finally { + Utils.deleteRecursively(tempDir); + } + */ + } +} diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaRandomForestClassifierSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaRandomForestClassifierSuite.java new file mode 100644 index 0000000000000..32d0b3856b7e2 --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaRandomForestClassifierSuite.java @@ -0,0 +1,103 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.ml.classification; + +import java.io.Serializable; +import java.util.HashMap; +import java.util.Map; + +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.ml.impl.TreeTests; +import org.apache.spark.mllib.classification.LogisticRegressionSuite; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.sql.DataFrame; + + +public class JavaRandomForestClassifierSuite implements Serializable { + + private transient JavaSparkContext sc; + + @Before + public void setUp() { + sc = new JavaSparkContext("local", "JavaRandomForestClassifierSuite"); + } + + @After + public void tearDown() { + sc.stop(); + sc = null; + } + + @Test + public void runDT() { + int nPoints = 20; + double A = 2.0; + double B = -1.5; + + JavaRDD data = sc.parallelize( + LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache(); + Map categoricalFeatures = new HashMap(); + DataFrame dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 2); + + // This tests setters. Training with various options is tested in Scala. + RandomForestClassifier rf = new RandomForestClassifier() + .setMaxDepth(2) + .setMaxBins(10) + .setMinInstancesPerNode(5) + .setMinInfoGain(0.0) + .setMaxMemoryInMB(256) + .setCacheNodeIds(false) + .setCheckpointInterval(10) + .setSubsamplingRate(1.0) + .setSeed(1234) + .setNumTrees(3) + .setMaxDepth(2); // duplicate setMaxDepth to check builder pattern + for (String impurity: RandomForestClassifier.supportedImpurities()) { + rf.setImpurity(impurity); + } + for (String featureSubsetStrategy: RandomForestClassifier.supportedFeatureSubsetStrategies()) { + rf.setFeatureSubsetStrategy(featureSubsetStrategy); + } + RandomForestClassificationModel model = rf.fit(dataFrame); + + model.transform(dataFrame); + model.totalNumNodes(); + model.toDebugString(); + model.trees(); + model.treeWeights(); + + /* + // TODO: Add test once save/load are implemented. 
SPARK-6725 + File tempDir = Utils.createTempDir(System.getProperty("java.io.tmpdir"), "spark"); + String path = tempDir.toURI().toString(); + try { + model3.save(sc.sc(), path); + RandomForestClassificationModel sameModel = + RandomForestClassificationModel.load(sc.sc(), path); + TreeTests.checkEqual(model3, sameModel); + } finally { + Utils.deleteRecursively(tempDir); + } + */ + } +} diff --git a/mllib/src/test/java/org/apache/spark/ml/regression/JavaDecisionTreeRegressorSuite.java b/mllib/src/test/java/org/apache/spark/ml/regression/JavaDecisionTreeRegressorSuite.java index a3a339004f31c..71b041818d7ee 100644 --- a/mllib/src/test/java/org/apache/spark/ml/regression/JavaDecisionTreeRegressorSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/regression/JavaDecisionTreeRegressorSuite.java @@ -17,7 +17,6 @@ package org.apache.spark.ml.regression; -import java.io.File; import java.io.Serializable; import java.util.HashMap; import java.util.Map; @@ -32,7 +31,6 @@ import org.apache.spark.mllib.classification.LogisticRegressionSuite; import org.apache.spark.mllib.regression.LabeledPoint; import org.apache.spark.sql.DataFrame; -import org.apache.spark.util.Utils; public class JavaDecisionTreeRegressorSuite implements Serializable { @@ -57,22 +55,22 @@ public void runDT() { double B = -1.5; JavaRDD data = sc.parallelize( - LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache(); + LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache(); Map categoricalFeatures = new HashMap(); DataFrame dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 2); // This tests setters. Training with various options is tested in Scala. DecisionTreeRegressor dt = new DecisionTreeRegressor() - .setMaxDepth(2) - .setMaxBins(10) - .setMinInstancesPerNode(5) - .setMinInfoGain(0.0) - .setMaxMemoryInMB(256) - .setCacheNodeIds(false) - .setCheckpointInterval(10) - .setMaxDepth(2); // duplicate setMaxDepth to check builder pattern - for (int i = 0; i < DecisionTreeRegressor.supportedImpurities().length; ++i) { - dt.setImpurity(DecisionTreeRegressor.supportedImpurities()[i]); + .setMaxDepth(2) + .setMaxBins(10) + .setMinInstancesPerNode(5) + .setMinInfoGain(0.0) + .setMaxMemoryInMB(256) + .setCacheNodeIds(false) + .setCheckpointInterval(10) + .setMaxDepth(2); // duplicate setMaxDepth to check builder pattern + for (String impurity: DecisionTreeRegressor.supportedImpurities()) { + dt.setImpurity(impurity); } DecisionTreeRegressionModel model = dt.fit(dataFrame); @@ -82,7 +80,7 @@ public void runDT() { model.toDebugString(); /* - // TODO: Add test once save/load are implemented. + // TODO: Add test once save/load are implemented. SPARK-6725 File tempDir = Utils.createTempDir(System.getProperty("java.io.tmpdir"), "spark"); String path = tempDir.toURI().toString(); try { diff --git a/mllib/src/test/java/org/apache/spark/ml/regression/JavaGBTRegressorSuite.java b/mllib/src/test/java/org/apache/spark/ml/regression/JavaGBTRegressorSuite.java new file mode 100644 index 0000000000000..fc8c13db07e6f --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/ml/regression/JavaGBTRegressorSuite.java @@ -0,0 +1,99 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.regression; + +import java.io.Serializable; +import java.util.HashMap; +import java.util.Map; + +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.ml.impl.TreeTests; +import org.apache.spark.mllib.classification.LogisticRegressionSuite; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.sql.DataFrame; + + +public class JavaGBTRegressorSuite implements Serializable { + + private transient JavaSparkContext sc; + + @Before + public void setUp() { + sc = new JavaSparkContext("local", "JavaGBTRegressorSuite"); + } + + @After + public void tearDown() { + sc.stop(); + sc = null; + } + + @Test + public void runDT() { + int nPoints = 20; + double A = 2.0; + double B = -1.5; + + JavaRDD data = sc.parallelize( + LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache(); + Map categoricalFeatures = new HashMap(); + DataFrame dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 0); + + GBTRegressor rf = new GBTRegressor() + .setMaxDepth(2) + .setMaxBins(10) + .setMinInstancesPerNode(5) + .setMinInfoGain(0.0) + .setMaxMemoryInMB(256) + .setCacheNodeIds(false) + .setCheckpointInterval(10) + .setSubsamplingRate(1.0) + .setSeed(1234) + .setMaxIter(3) + .setStepSize(0.1) + .setMaxDepth(2); // duplicate setMaxDepth to check builder pattern + for (String lossType: GBTRegressor.supportedLossTypes()) { + rf.setLossType(lossType); + } + GBTRegressionModel model = rf.fit(dataFrame); + + model.transform(dataFrame); + model.totalNumNodes(); + model.toDebugString(); + model.trees(); + model.treeWeights(); + + /* + // TODO: Add test once save/load are implemented. SPARK-6725 + File tempDir = Utils.createTempDir(System.getProperty("java.io.tmpdir"), "spark"); + String path = tempDir.toURI().toString(); + try { + model2.save(sc.sc(), path); + GBTRegressionModel sameModel = GBTRegressionModel.load(sc.sc(), path); + TreeTests.checkEqual(model2, sameModel); + } finally { + Utils.deleteRecursively(tempDir); + } + */ + } +} diff --git a/mllib/src/test/java/org/apache/spark/ml/regression/JavaRandomForestRegressorSuite.java b/mllib/src/test/java/org/apache/spark/ml/regression/JavaRandomForestRegressorSuite.java new file mode 100644 index 0000000000000..e306ebadfe7cf --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/ml/regression/JavaRandomForestRegressorSuite.java @@ -0,0 +1,102 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.regression; + +import java.io.Serializable; +import java.util.HashMap; +import java.util.Map; + +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.mllib.classification.LogisticRegressionSuite; +import org.apache.spark.ml.impl.TreeTests; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.sql.DataFrame; + + +public class JavaRandomForestRegressorSuite implements Serializable { + + private transient JavaSparkContext sc; + + @Before + public void setUp() { + sc = new JavaSparkContext("local", "JavaRandomForestRegressorSuite"); + } + + @After + public void tearDown() { + sc.stop(); + sc = null; + } + + @Test + public void runDT() { + int nPoints = 20; + double A = 2.0; + double B = -1.5; + + JavaRDD data = sc.parallelize( + LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache(); + Map categoricalFeatures = new HashMap(); + DataFrame dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 0); + + // This tests setters. Training with various options is tested in Scala. + RandomForestRegressor rf = new RandomForestRegressor() + .setMaxDepth(2) + .setMaxBins(10) + .setMinInstancesPerNode(5) + .setMinInfoGain(0.0) + .setMaxMemoryInMB(256) + .setCacheNodeIds(false) + .setCheckpointInterval(10) + .setSubsamplingRate(1.0) + .setSeed(1234) + .setNumTrees(3) + .setMaxDepth(2); // duplicate setMaxDepth to check builder pattern + for (String impurity: RandomForestRegressor.supportedImpurities()) { + rf.setImpurity(impurity); + } + for (String featureSubsetStrategy: RandomForestRegressor.supportedFeatureSubsetStrategies()) { + rf.setFeatureSubsetStrategy(featureSubsetStrategy); + } + RandomForestRegressionModel model = rf.fit(dataFrame); + + model.transform(dataFrame); + model.totalNumNodes(); + model.toDebugString(); + model.trees(); + model.treeWeights(); + + /* + // TODO: Add test once save/load are implemented. 
SPARK-6725 + File tempDir = Utils.createTempDir(System.getProperty("java.io.tmpdir"), "spark"); + String path = tempDir.toURI().toString(); + try { + model2.save(sc.sc(), path); + RandomForestRegressionModel sameModel = RandomForestRegressionModel.load(sc.sc(), path); + TreeTests.checkEqual(model2, sameModel); + } finally { + Utils.deleteRecursively(tempDir); + } + */ + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala index af88595df5245..9b31adecdcb1c 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala @@ -230,7 +230,7 @@ class DecisionTreeClassifierSuite extends FunSuite with MLlibTestSparkContext { // Tests of model save/load ///////////////////////////////////////////////////////////////////////////// - // TODO: Reinstate test once save/load are implemented + // TODO: Reinstate test once save/load are implemented SPARK-6725 /* test("model save/load") { val tempDir = Utils.createTempDir() diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala new file mode 100644 index 0000000000000..e6ccc2c93cba8 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala @@ -0,0 +1,136 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.classification + +import org.scalatest.FunSuite + +import org.apache.spark.ml.impl.TreeTests +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.tree.{EnsembleTestHelper, GradientBoostedTrees => OldGBT} +import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.DataFrame + + +/** + * Test suite for [[GBTClassifier]]. 
+ */ +class GBTClassifierSuite extends FunSuite with MLlibTestSparkContext { + + import GBTClassifierSuite.compareAPIs + + // Combinations for estimators, learning rates and subsamplingRate + private val testCombinations = + Array((10, 1.0, 1.0), (10, 0.1, 1.0), (10, 0.5, 0.75), (10, 0.1, 0.75)) + + private var data: RDD[LabeledPoint] = _ + private var trainData: RDD[LabeledPoint] = _ + private var validationData: RDD[LabeledPoint] = _ + + override def beforeAll() { + super.beforeAll() + data = sc.parallelize(EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 10, 100), 2) + trainData = + sc.parallelize(EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 20, 120), 2) + validationData = + sc.parallelize(EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 20, 80), 2) + } + + test("Binary classification with continuous features: Log Loss") { + val categoricalFeatures = Map.empty[Int, Int] + testCombinations.foreach { + case (maxIter, learningRate, subsamplingRate) => + val gbt = new GBTClassifier() + .setMaxDepth(2) + .setSubsamplingRate(subsamplingRate) + .setLossType("logistic") + .setMaxIter(maxIter) + .setStepSize(learningRate) + compareAPIs(data, None, gbt, categoricalFeatures) + } + } + + // TODO: Reinstate test once runWithValidation is implemented SPARK-7132 + /* + test("runWithValidation stops early and performs better on a validation dataset") { + val categoricalFeatures = Map.empty[Int, Int] + // Set maxIter large enough so that it stops early. + val maxIter = 20 + GBTClassifier.supportedLossTypes.foreach { loss => + val gbt = new GBTClassifier() + .setMaxIter(maxIter) + .setMaxDepth(2) + .setLossType(loss) + .setValidationTol(0.0) + compareAPIs(trainData, None, gbt, categoricalFeatures) + compareAPIs(trainData, Some(validationData), gbt, categoricalFeatures) + } + } + */ + + ///////////////////////////////////////////////////////////////////////////// + // Tests of model save/load + ///////////////////////////////////////////////////////////////////////////// + + // TODO: Reinstate test once save/load are implemented SPARK-6725 + /* + test("model save/load") { + val tempDir = Utils.createTempDir() + val path = tempDir.toURI.toString + + val trees = Range(0, 3).map(_ => OldDecisionTreeSuite.createModel(OldAlgo.Regression)).toArray + val treeWeights = Array(0.1, 0.3, 1.1) + val oldModel = new OldGBTModel(OldAlgo.Classification, trees, treeWeights) + val newModel = GBTClassificationModel.fromOld(oldModel) + + // Save model, load it back, and compare. + try { + newModel.save(sc, path) + val sameNewModel = GBTClassificationModel.load(sc, path) + TreeTests.checkEqual(newModel, sameNewModel) + } finally { + Utils.deleteRecursively(tempDir) + } + } + */ +} + +private object GBTClassifierSuite { + + /** + * Train 2 models on the given dataset, one using the old API and one using the new API. + * Convert the old model to the new format, compare them, and fail if they are not exactly equal. + */ + def compareAPIs( + data: RDD[LabeledPoint], + validationData: Option[RDD[LabeledPoint]], + gbt: GBTClassifier, + categoricalFeatures: Map[Int, Int]): Unit = { + val oldBoostingStrategy = + gbt.getOldBoostingStrategy(categoricalFeatures, OldAlgo.Classification) + val oldGBT = new OldGBT(oldBoostingStrategy) + val oldModel = oldGBT.run(data) + val newData: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses = 2) + val newModel = gbt.fit(newData) + // Use parent, fittingParamMap from newTree since these are not checked anyways. 
+ val oldModelAsNew = GBTClassificationModel.fromOld(oldModel, newModel.parent, + newModel.fittingParamMap, categoricalFeatures) + TreeTests.checkEqual(oldModelAsNew, newModel) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala new file mode 100644 index 0000000000000..ed41a9664f94f --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala @@ -0,0 +1,166 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.classification + +import org.scalatest.FunSuite + +import org.apache.spark.ml.impl.TreeTests +import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.tree.{EnsembleTestHelper, RandomForest => OldRandomForest} +import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.DataFrame + + +/** + * Test suite for [[RandomForestClassifier]]. + */ +class RandomForestClassifierSuite extends FunSuite with MLlibTestSparkContext { + + import RandomForestClassifierSuite.compareAPIs + + private var orderedLabeledPoints50_1000: RDD[LabeledPoint] = _ + private var orderedLabeledPoints5_20: RDD[LabeledPoint] = _ + + override def beforeAll() { + super.beforeAll() + orderedLabeledPoints50_1000 = + sc.parallelize(EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 50, 1000)) + orderedLabeledPoints5_20 = + sc.parallelize(EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 5, 20)) + } + + ///////////////////////////////////////////////////////////////////////////// + // Tests calling train() + ///////////////////////////////////////////////////////////////////////////// + + def binaryClassificationTestWithContinuousFeatures(rf: RandomForestClassifier) { + val categoricalFeatures = Map.empty[Int, Int] + val numClasses = 2 + val newRF = rf + .setImpurity("Gini") + .setMaxDepth(2) + .setNumTrees(1) + .setFeatureSubsetStrategy("auto") + .setSeed(123) + compareAPIs(orderedLabeledPoints50_1000, newRF, categoricalFeatures, numClasses) + } + + test("Binary classification with continuous features:" + + " comparing DecisionTree vs. RandomForest(numTrees = 1)") { + val rf = new RandomForestClassifier() + binaryClassificationTestWithContinuousFeatures(rf) + } + + test("Binary classification with continuous features and node Id cache:" + + " comparing DecisionTree vs. 
RandomForest(numTrees = 1)") { + val rf = new RandomForestClassifier() + .setCacheNodeIds(true) + binaryClassificationTestWithContinuousFeatures(rf) + } + + test("alternating categorical and continuous features with multiclass labels to test indexing") { + val arr = Array( + LabeledPoint(0.0, Vectors.dense(1.0, 0.0, 0.0, 3.0, 1.0)), + LabeledPoint(1.0, Vectors.dense(0.0, 1.0, 1.0, 1.0, 2.0)), + LabeledPoint(0.0, Vectors.dense(2.0, 0.0, 0.0, 6.0, 3.0)), + LabeledPoint(2.0, Vectors.dense(0.0, 2.0, 1.0, 3.0, 2.0)) + ) + val rdd = sc.parallelize(arr) + val categoricalFeatures = Map(0 -> 3, 2 -> 2, 4 -> 4) + val numClasses = 3 + + val rf = new RandomForestClassifier() + .setImpurity("Gini") + .setMaxDepth(5) + .setNumTrees(2) + .setFeatureSubsetStrategy("sqrt") + .setSeed(12345) + compareAPIs(rdd, rf, categoricalFeatures, numClasses) + } + + test("subsampling rate in RandomForest"){ + val rdd = orderedLabeledPoints5_20 + val categoricalFeatures = Map.empty[Int, Int] + val numClasses = 2 + + val rf1 = new RandomForestClassifier() + .setImpurity("Gini") + .setMaxDepth(2) + .setCacheNodeIds(true) + .setNumTrees(3) + .setFeatureSubsetStrategy("auto") + .setSeed(123) + compareAPIs(rdd, rf1, categoricalFeatures, numClasses) + + val rf2 = rf1.setSubsamplingRate(0.5) + compareAPIs(rdd, rf2, categoricalFeatures, numClasses) + } + + ///////////////////////////////////////////////////////////////////////////// + // Tests of model save/load + ///////////////////////////////////////////////////////////////////////////// + + // TODO: Reinstate test once save/load are implemented SPARK-6725 + /* + test("model save/load") { + val tempDir = Utils.createTempDir() + val path = tempDir.toURI.toString + + val trees = + Range(0, 3).map(_ => OldDecisionTreeSuite.createModel(OldAlgo.Classification)).toArray + val oldModel = new OldRandomForestModel(OldAlgo.Classification, trees) + val newModel = RandomForestClassificationModel.fromOld(oldModel) + + // Save model, load it back, and compare. + try { + newModel.save(sc, path) + val sameNewModel = RandomForestClassificationModel.load(sc, path) + TreeTests.checkEqual(newModel, sameNewModel) + } finally { + Utils.deleteRecursively(tempDir) + } + } + */ +} + +private object RandomForestClassifierSuite { + + /** + * Train 2 models on the given dataset, one using the old API and one using the new API. + * Convert the old model to the new format, compare them, and fail if they are not exactly equal. + */ + def compareAPIs( + data: RDD[LabeledPoint], + rf: RandomForestClassifier, + categoricalFeatures: Map[Int, Int], + numClasses: Int): Unit = { + val oldStrategy = + rf.getOldStrategy(categoricalFeatures, numClasses, OldAlgo.Classification, rf.getOldImpurity) + val oldModel = OldRandomForest.trainClassifier( + data, oldStrategy, rf.getNumTrees, rf.getFeatureSubsetStrategy, rf.getSeed.toInt) + val newData: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses) + val newModel = rf.fit(newData) + // Use parent, fittingParamMap from newTree since these are not checked anyways. 
+ val oldModelAsNew = RandomForestClassificationModel.fromOld(oldModel, newModel.parent, + newModel.fittingParamMap, categoricalFeatures) + TreeTests.checkEqual(oldModelAsNew, newModel) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/impl/TreeTests.scala b/mllib/src/test/scala/org/apache/spark/ml/impl/TreeTests.scala index 2e57d4ce37f1d..1505ad872536b 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/impl/TreeTests.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/impl/TreeTests.scala @@ -23,8 +23,7 @@ import org.scalatest.FunSuite import org.apache.spark.api.java.JavaRDD import org.apache.spark.ml.attribute.{AttributeGroup, NominalAttribute, NumericAttribute} -import org.apache.spark.ml.impl.tree._ -import org.apache.spark.ml.tree.{DecisionTreeModel, InternalNode, LeafNode, Node} +import org.apache.spark.ml.tree._ import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.rdd.RDD import org.apache.spark.sql.{SQLContext, DataFrame} @@ -111,22 +110,19 @@ private[ml] object TreeTests extends FunSuite { } } - // TODO: Reinstate after adding ensembles /** * Check if the two models are exactly the same. * If the models are not equal, this throws an exception. */ - /* def checkEqual(a: TreeEnsembleModel, b: TreeEnsembleModel): Unit = { try { - a.getTrees.zip(b.getTrees).foreach { case (treeA, treeB) => + a.trees.zip(b.trees).foreach { case (treeA, treeB) => TreeTests.checkEqual(treeA, treeB) } - assert(a.getTreeWeights === b.getTreeWeights) + assert(a.treeWeights === b.treeWeights) } catch { case ex: Exception => throw new AssertionError( "checkEqual failed since the two tree ensembles were not identical") } } - */ } diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala index 0b40fe33fae9d..c87a171b4b229 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala @@ -66,7 +66,7 @@ class DecisionTreeRegressorSuite extends FunSuite with MLlibTestSparkContext { // Tests of model save/load ///////////////////////////////////////////////////////////////////////////// - // TODO: test("model save/load") + // TODO: test("model save/load") SPARK-6725 } private[ml] object DecisionTreeRegressorSuite extends FunSuite { diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala new file mode 100644 index 0000000000000..4aec36948ac92 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala @@ -0,0 +1,137 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.regression + +import org.scalatest.FunSuite + +import org.apache.spark.ml.impl.TreeTests +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.tree.{EnsembleTestHelper, GradientBoostedTrees => OldGBT} +import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.DataFrame + + +/** + * Test suite for [[GBTRegressor]]. + */ +class GBTRegressorSuite extends FunSuite with MLlibTestSparkContext { + + import GBTRegressorSuite.compareAPIs + + // Combinations for estimators, learning rates and subsamplingRate + private val testCombinations = + Array((10, 1.0, 1.0), (10, 0.1, 1.0), (10, 0.5, 0.75), (10, 0.1, 0.75)) + + private var data: RDD[LabeledPoint] = _ + private var trainData: RDD[LabeledPoint] = _ + private var validationData: RDD[LabeledPoint] = _ + + override def beforeAll() { + super.beforeAll() + data = sc.parallelize(EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 10, 100), 2) + trainData = + sc.parallelize(EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 20, 120), 2) + validationData = + sc.parallelize(EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 20, 80), 2) + } + + test("Regression with continuous features: SquaredError") { + val categoricalFeatures = Map.empty[Int, Int] + GBTRegressor.supportedLossTypes.foreach { loss => + testCombinations.foreach { + case (maxIter, learningRate, subsamplingRate) => + val gbt = new GBTRegressor() + .setMaxDepth(2) + .setSubsamplingRate(subsamplingRate) + .setLossType(loss) + .setMaxIter(maxIter) + .setStepSize(learningRate) + compareAPIs(data, None, gbt, categoricalFeatures) + } + } + } + + // TODO: Reinstate test once runWithValidation is implemented SPARK-7132 + /* + test("runWithValidation stops early and performs better on a validation dataset") { + val categoricalFeatures = Map.empty[Int, Int] + // Set maxIter large enough so that it stops early. + val maxIter = 20 + GBTRegressor.supportedLossTypes.foreach { loss => + val gbt = new GBTRegressor() + .setMaxIter(maxIter) + .setMaxDepth(2) + .setLossType(loss) + .setValidationTol(0.0) + compareAPIs(trainData, None, gbt, categoricalFeatures) + compareAPIs(trainData, Some(validationData), gbt, categoricalFeatures) + } + } + */ + + ///////////////////////////////////////////////////////////////////////////// + // Tests of model save/load + ///////////////////////////////////////////////////////////////////////////// + + // TODO: Reinstate test once save/load are implemented SPARK-6725 + /* + test("model save/load") { + val tempDir = Utils.createTempDir() + val path = tempDir.toURI.toString + + val trees = Range(0, 3).map(_ => OldDecisionTreeSuite.createModel(OldAlgo.Regression)).toArray + val treeWeights = Array(0.1, 0.3, 1.1) + val oldModel = new OldGBTModel(OldAlgo.Regression, trees, treeWeights) + val newModel = GBTRegressionModel.fromOld(oldModel) + + // Save model, load it back, and compare. + try { + newModel.save(sc, path) + val sameNewModel = GBTRegressionModel.load(sc, path) + TreeTests.checkEqual(newModel, sameNewModel) + } finally { + Utils.deleteRecursively(tempDir) + } + } + */ +} + +private object GBTRegressorSuite { + + /** + * Train 2 models on the given dataset, one using the old API and one using the new API. 
+ * Convert the old model to the new format, compare them, and fail if they are not exactly equal. + */ + def compareAPIs( + data: RDD[LabeledPoint], + validationData: Option[RDD[LabeledPoint]], + gbt: GBTRegressor, + categoricalFeatures: Map[Int, Int]): Unit = { + val oldBoostingStrategy = gbt.getOldBoostingStrategy(categoricalFeatures, OldAlgo.Regression) + val oldGBT = new OldGBT(oldBoostingStrategy) + val oldModel = oldGBT.run(data) + val newData: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses = 0) + val newModel = gbt.fit(newData) + // Use parent, fittingParamMap from newTree since these are not checked anyways. + val oldModelAsNew = GBTRegressionModel.fromOld(oldModel, newModel.parent, + newModel.fittingParamMap, categoricalFeatures) + TreeTests.checkEqual(oldModelAsNew, newModel) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala new file mode 100644 index 0000000000000..c6dc1cc29b6ff --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala @@ -0,0 +1,122 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.regression + +import org.scalatest.FunSuite + +import org.apache.spark.ml.impl.TreeTests +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.tree.{EnsembleTestHelper, RandomForest => OldRandomForest} +import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.DataFrame + + +/** + * Test suite for [[RandomForestRegressor]]. + */ +class RandomForestRegressorSuite extends FunSuite with MLlibTestSparkContext { + + import RandomForestRegressorSuite.compareAPIs + + private var orderedLabeledPoints50_1000: RDD[LabeledPoint] = _ + + override def beforeAll() { + super.beforeAll() + orderedLabeledPoints50_1000 = + sc.parallelize(EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 50, 1000)) + } + + ///////////////////////////////////////////////////////////////////////////// + // Tests calling train() + ///////////////////////////////////////////////////////////////////////////// + + def regressionTestWithContinuousFeatures(rf: RandomForestRegressor) { + val categoricalFeaturesInfo = Map.empty[Int, Int] + val newRF = rf + .setImpurity("variance") + .setMaxDepth(2) + .setMaxBins(10) + .setNumTrees(1) + .setFeatureSubsetStrategy("auto") + .setSeed(123) + compareAPIs(orderedLabeledPoints50_1000, newRF, categoricalFeaturesInfo) + } + + test("Regression with continuous features:" + + " comparing DecisionTree vs. 
RandomForest(numTrees = 1)") { + val rf = new RandomForestRegressor() + regressionTestWithContinuousFeatures(rf) + } + + test("Regression with continuous features and node Id cache :" + + " comparing DecisionTree vs. RandomForest(numTrees = 1)") { + val rf = new RandomForestRegressor() + .setCacheNodeIds(true) + regressionTestWithContinuousFeatures(rf) + } + + ///////////////////////////////////////////////////////////////////////////// + // Tests of model save/load + ///////////////////////////////////////////////////////////////////////////// + + // TODO: Reinstate test once save/load are implemented SPARK-6725 + /* + test("model save/load") { + val tempDir = Utils.createTempDir() + val path = tempDir.toURI.toString + + val trees = Range(0, 3).map(_ => OldDecisionTreeSuite.createModel(OldAlgo.Regression)).toArray + val oldModel = new OldRandomForestModel(OldAlgo.Regression, trees) + val newModel = RandomForestRegressionModel.fromOld(oldModel) + + // Save model, load it back, and compare. + try { + newModel.save(sc, path) + val sameNewModel = RandomForestRegressionModel.load(sc, path) + TreeTests.checkEqual(newModel, sameNewModel) + } finally { + Utils.deleteRecursively(tempDir) + } + } + */ +} + +private object RandomForestRegressorSuite extends FunSuite { + + /** + * Train 2 models on the given dataset, one using the old API and one using the new API. + * Convert the old model to the new format, compare them, and fail if they are not exactly equal. + */ + def compareAPIs( + data: RDD[LabeledPoint], + rf: RandomForestRegressor, + categoricalFeatures: Map[Int, Int]): Unit = { + val oldStrategy = + rf.getOldStrategy(categoricalFeatures, numClasses = 0, OldAlgo.Regression, rf.getOldImpurity) + val oldModel = OldRandomForest.trainRegressor( + data, oldStrategy, rf.getNumTrees, rf.getFeatureSubsetStrategy, rf.getSeed.toInt) + val newData: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses = 0) + val newModel = rf.fit(newData) + // Use parent, fittingParamMap from newTree since these are not checked anyways. + val oldModelAsNew = RandomForestRegressionModel.fromOld(oldModel, newModel.parent, + newModel.fittingParamMap, categoricalFeatures) + TreeTests.checkEqual(oldModelAsNew, newModel) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala index 249b8eae19b17..ce983eb27fa35 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala @@ -998,7 +998,7 @@ object DecisionTreeSuite extends FunSuite { node.split = Some(new Split(feature = 1, threshold = 0.0, Categorical, categories = List(0.0, 1.0))) } - // TODO: The information gain stats should be consistent with the same info stored in children. + // TODO: The information gain stats should be consistent with info in children: SPARK-7131 node.stats = Some(new InformationGainStats(gain = 0.1, impurity = 0.2, leftImpurity = 0.3, rightImpurity = 0.4, new Predict(1.0, 0.4), new Predict(0.0, 0.6))) node @@ -1006,9 +1006,9 @@ object DecisionTreeSuite extends FunSuite { /** * Create a tree model. This is deterministic and contains a variety of node and feature types. - * TODO: Update this to be a correct tree (with matching probabilities, impurities, etc.) 
+ * TODO: Update to be a correct tree (with matching probabilities, impurities, etc.): SPARK-7131 */ - private[mllib] def createModel(algo: Algo): DecisionTreeModel = { + private[spark] def createModel(algo: Algo): DecisionTreeModel = { val topNode = createInternalNode(id = 1, Continuous) val (node2, node3) = (createLeafNode(id = 2), createInternalNode(id = 3, Categorical)) val (node6, node7) = (createLeafNode(id = 6), createLeafNode(id = 7)) From aa6966ff34dacc83c3ca675b5109b05e35015469 Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Sat, 25 Apr 2015 13:43:39 -0700 Subject: [PATCH 069/110] [SQL] Update SQL readme to include instructions on generating golden answer files based on Hive 0.13.1. Author: Yin Huai Closes #5702 from yhuai/howToGenerateGoldenFiles and squashes the following commits: 9c4a7f8 [Yin Huai] Update readme to include instructions on generating golden answer files based on Hive 0.13.1. --- sql/README.md | 23 ++++++++++++++++++++++- 1 file changed, 22 insertions(+), 1 deletion(-) diff --git a/sql/README.md b/sql/README.md index 237620e3fa808..46aec7cef7984 100644 --- a/sql/README.md +++ b/sql/README.md @@ -12,7 +12,10 @@ Spark SQL is broken up into four subprojects: Other dependencies for developers --------------------------------- -In order to create new hive test cases , you will need to set several environmental variables. +In order to create new hive test cases (i.e. a test suite based on `HiveComparisonTest`), +you will need to setup your development environment based on the following instructions. + +If you are working with Hive 0.12.0, you will need to set several environmental variables as follows. ``` export HIVE_HOME="/hive/build/dist" @@ -20,6 +23,24 @@ export HIVE_DEV_HOME="/hive/" export HADOOP_HOME="/hadoop-1.0.4" ``` +If you are working with Hive 0.13.1, the following steps are needed: + +1. Download Hive's [0.13.1](https://hive.apache.org/downloads.html) and set `HIVE_HOME` with `export HIVE_HOME=""`. Please do not set `HIVE_DEV_HOME` (See [SPARK-4119](https://issues.apache.org/jira/browse/SPARK-4119)). +2. Set `HADOOP_HOME` with `export HADOOP_HOME=""` +3. Download all Hive 0.13.1a jars (Hive jars actually used by Spark) from [here](http://mvnrepository.com/artifact/org.spark-project.hive) and replace corresponding original 0.13.1 jars in `$HIVE_HOME/lib`. +4. Download [Kryo 2.21 jar](http://mvnrepository.com/artifact/com.esotericsoftware.kryo/kryo/2.21) (Note: 2.22 jar does not work) and [Javolution 5.5.1 jar](http://mvnrepository.com/artifact/javolution/javolution/5.5.1) to `$HIVE_HOME/lib`. +5. This step is optional. But, when generating golden answer files, if a Hive query fails and you find that Hive tries to talk to HDFS or you find weird runtime NPEs, set the following in your test suite... + +``` +val testTempDir = Utils.createTempDir() +// We have to use kryo to let Hive correctly serialize some plans. +sql("set hive.plan.serialization.format=kryo") +// Explicitly set fs to local fs. +sql(s"set fs.default.name=file://$testTempDir/") +// Ask Hive to run jobs in-process as a single map and reduce task. +sql("set mapred.job.tracker=local") +``` + Using the console ================= An interactive scala console can be invoked by running `build/sbt hive/console`. 
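As a reference for the Hive 0.13.1 instructions in the README patch above, here is a minimal, illustrative sketch of a golden-file test suite. It is not part of any patch in this series: it assumes a subclass of `HiveComparisonTest`, that `createQueryTest` is provided by that base class, and that the standard `src` test table is available; the suite and test names are hypothetical.

```scala
package org.apache.spark.sql.hive.execution

import org.apache.spark.sql.hive.test.TestHive._
import org.apache.spark.util.Utils

// Hypothetical suite name; any HiveComparisonTest subclass works the same way.
class GoldenFileGenerationSuite extends HiveComparisonTest {

  // Optional workarounds from the README, applied once when the suite is instantiated,
  // before any of the registered query tests run.
  private val testTempDir = Utils.createTempDir()
  sql("set hive.plan.serialization.format=kryo")
  sql(s"set fs.default.name=file://$testTempDir/")
  sql("set mapred.job.tracker=local")

  // Each createQueryTest call registers a test that generates (or verifies)
  // a golden answer file named after the test case.
  createQueryTest("simple golden file example",
    "SELECT key, value FROM src ORDER BY key LIMIT 10")
}
```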
From a11c8683c76c67f45749a1b50a0912a731fd2487 Mon Sep 17 00:00:00 2001 From: Prashant Sharma Date: Sat, 25 Apr 2015 18:07:34 -0400 Subject: [PATCH 070/110] [SPARK-7092] Update spark scala version to 2.11.6 Author: Prashant Sharma Closes #5662 from ScrapCodes/SPARK-7092/scala-update-2.11.6 and squashes the following commits: 58cf4f9 [Prashant Sharma] [SPARK-7092] Update spark scala version to 2.11.6 --- pom.xml | 4 ++-- .../src/main/scala/org/apache/spark/repl/SparkIMain.scala | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/pom.xml b/pom.xml index 4b0b0c85eff21..9fbce1d639d8b 100644 --- a/pom.xml +++ b/pom.xml @@ -1745,9 +1745,9 @@ scala-2.11 - 2.11.2 + 2.11.6 2.11 - 2.12 + 2.12.1 jline diff --git a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkIMain.scala b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkIMain.scala index 1bb62c84abddc..1cb910f376060 100644 --- a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkIMain.scala +++ b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkIMain.scala @@ -1129,7 +1129,7 @@ class SparkIMain(@BeanProperty val factory: ScriptEngineFactory, initialSettings def apply(line: String): Result = debugging(s"""parse("$line")""") { var isIncomplete = false - currentRun.reporting.withIncompleteHandler((_, _) => isIncomplete = true) { + currentRun.parsing.withIncompleteHandler((_, _) => isIncomplete = true) { reporter.reset() val trees = newUnitParser(line).parseStats() if (reporter.hasErrors) Error From f5473c2bbf66cc1144a90b4c29f3ce54ad7cc419 Mon Sep 17 00:00:00 2001 From: Nishkam Ravi Date: Sat, 25 Apr 2015 20:02:23 -0400 Subject: [PATCH 071/110] [SPARK-6014] [CORE] [HOTFIX] Add try-catch block around ShutDownHook Add a try/catch block around removeShutDownHook else IllegalStateException thrown in YARN cluster mode (see https://github.com/apache/spark/pull/4690) cc andrewor14, srowen Author: Nishkam Ravi Author: nishkamravi2 Author: nravi Closes #5672 from nishkamravi2/master_nravi and squashes the following commits: 0f1abd0 [nishkamravi2] Update Utils.scala 474e3bf [nishkamravi2] Update DiskBlockManager.scala 97c383e [nishkamravi2] Update Utils.scala 8691e0c [Nishkam Ravi] Add a try/catch block around Utils.removeShutdownHook 2be1e76 [Nishkam Ravi] Merge branch 'master' of https://github.com/apache/spark into master_nravi 1c13b79 [Nishkam Ravi] Merge branch 'master' of https://github.com/apache/spark into master_nravi bad4349 [nishkamravi2] Update Main.java 36a6f87 [Nishkam Ravi] Minor changes and bug fixes b7f4ae7 [Nishkam Ravi] Merge branch 'master' of https://github.com/apache/spark into master_nravi 4a45d6a [Nishkam Ravi] Merge branch 'master' of https://github.com/apache/spark into master_nravi 458af39 [Nishkam Ravi] Locate the jar using getLocation, obviates the need to pass assembly path as an argument d9658d6 [Nishkam Ravi] Changes for SPARK-6406 ccdc334 [Nishkam Ravi] Merge branch 'master' of https://github.com/apache/spark into master_nravi 3faa7a4 [Nishkam Ravi] Launcher library changes (SPARK-6406) 345206a [Nishkam Ravi] spark-class merge Merge branch 'master_nravi' of https://github.com/nishkamravi2/spark into master_nravi ac58975 [Nishkam Ravi] spark-class changes 06bfeb0 [nishkamravi2] Update spark-class 35af990 [Nishkam Ravi] Merge branch 'master' of https://github.com/apache/spark into master_nravi 32c3ab3 [nishkamravi2] Update AbstractCommandBuilder.java 4bd4489 [nishkamravi2] Update AbstractCommandBuilder.java 746f35b [Nishkam Ravi] "hadoop" string in the assembly name should not be 
mandatory (everywhere else in spark we mandate spark-assembly*hadoop*.jar) bfe96e0 [Nishkam Ravi] Merge branch 'master' of https://github.com/apache/spark into master_nravi ee902fa [Nishkam Ravi] Merge branch 'master' of https://github.com/apache/spark into master_nravi d453197 [nishkamravi2] Update NewHadoopRDD.scala 6f41a1d [nishkamravi2] Update NewHadoopRDD.scala 0ce2c32 [nishkamravi2] Update HadoopRDD.scala f7e33c2 [Nishkam Ravi] Merge branch 'master_nravi' of https://github.com/nishkamravi2/spark into master_nravi ba1eb8b [Nishkam Ravi] Try-catch block around the two occurrences of removeShutDownHook. Deletion of semi-redundant occurrences of expensive operation inShutDown. 71d0e17 [Nishkam Ravi] Merge branch 'master' of https://github.com/apache/spark into master_nravi 494d8c0 [nishkamravi2] Update DiskBlockManager.scala 3c5ddba [nishkamravi2] Update DiskBlockManager.scala f0d12de [Nishkam Ravi] Workaround for IllegalStateException caused by recent changes to BlockManager.stop 79ea8b4 [Nishkam Ravi] Merge branch 'master' of https://github.com/apache/spark into master_nravi b446edc [Nishkam Ravi] Merge branch 'master' of https://github.com/apache/spark into master_nravi 5c9a4cb [nishkamravi2] Update TaskSetManagerSuite.scala 535295a [nishkamravi2] Update TaskSetManager.scala 3e1b616 [Nishkam Ravi] Modify test for maxResultSize 9f6583e [Nishkam Ravi] Changes to maxResultSize code (improve error message and add condition to check if maxResultSize > 0) 5f8f9ed [Nishkam Ravi] Merge branch 'master' of https://github.com/apache/spark into master_nravi 636a9ff [nishkamravi2] Update YarnAllocator.scala 8f76c8b [Nishkam Ravi] Doc change for yarn memory overhead 35daa64 [Nishkam Ravi] Slight change in the doc for yarn memory overhead 5ac2ec1 [Nishkam Ravi] Remove out dac1047 [Nishkam Ravi] Additional documentation for yarn memory overhead issue 42c2c3d [Nishkam Ravi] Additional changes for yarn memory overhead issue 362da5e [Nishkam Ravi] Additional changes for yarn memory overhead c726bd9 [Nishkam Ravi] Merge branch 'master' of https://github.com/apache/spark into master_nravi f00fa31 [Nishkam Ravi] Improving logging for AM memoryOverhead 1cf2d1e [nishkamravi2] Update YarnAllocator.scala ebcde10 [Nishkam Ravi] Modify default YARN memory_overhead-- from an additive constant to a multiplier (redone to resolve merge conflicts) 2e69f11 [Nishkam Ravi] Merge branch 'master' of https://github.com/apache/spark into master_nravi efd688a [Nishkam Ravi] Merge branch 'master' of https://github.com/apache/spark 2b630f9 [nravi] Accept memory input as "30g", "512M" instead of an int value, to be consistent with rest of Spark 3bf8fad [nravi] Merge branch 'master' of https://github.com/apache/spark 5423a03 [nravi] Merge branch 'master' of https://github.com/apache/spark eb663ca [nravi] Merge branch 'master' of https://github.com/apache/spark df2aeb1 [nravi] Improved fix for ConcurrentModificationIssue (Spark-1097, Hadoop-10456) 6b840f0 [nravi] Undo the fix for SPARK-1758 (the problem is fixed) 5108700 [nravi] Fix in Spark for the Concurrent thread modification issue (SPARK-1097, HADOOP-10456) 681b36f [nravi] Fix for SPARK-1758: failing test org.apache.spark.JavaAPISuite.wholeTextFiles --- .../scala/org/apache/spark/storage/DiskBlockManager.scala | 7 ++++++- core/src/main/scala/org/apache/spark/util/Utils.scala | 3 +-- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala 
b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala index 7ea5e54f9e1fe..5764c16902c66 100644 --- a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala @@ -148,7 +148,12 @@ private[spark] class DiskBlockManager(blockManager: BlockManager, conf: SparkCon /** Cleanup local dirs and stop shuffle sender. */ private[spark] def stop() { // Remove the shutdown hook. It causes memory leaks if we leave it around. - Utils.removeShutdownHook(shutdownHook) + try { + Utils.removeShutdownHook(shutdownHook) + } catch { + case e: Exception => + logError(s"Exception while removing shutdown hook.", e) + } doStop() } diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 667aa168e7ef3..c6c6df7cfa56e 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -2172,7 +2172,7 @@ private [util] class SparkShutdownHookManager { def runAll(): Unit = synchronized { shuttingDown = true while (!hooks.isEmpty()) { - Utils.logUncaughtExceptions(hooks.poll().run()) + Try(Utils.logUncaughtExceptions(hooks.poll().run())) } } @@ -2184,7 +2184,6 @@ private [util] class SparkShutdownHookManager { } def remove(ref: AnyRef): Boolean = synchronized { - checkState() hooks.remove(ref) } From 9a5bbe05fc1b1141e139d32661821fef47d7a13c Mon Sep 17 00:00:00 2001 From: Alain Date: Sun, 26 Apr 2015 07:14:24 -0400 Subject: [PATCH 072/110] [MINOR] [MLLIB] Refactor toString method in MLLIB MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 1. predict(predict.toString) has already output prefix “predict” thus it’s duplicated to print ", predict = " again 2. 
there are some extra spaces Author: Alain Closes #5687 from AiHe/tree-node-issue-2 and squashes the following commits: 9862b9a [Alain] Pass scala coding style checking 44ba947 [Alain] Minor][MLLIB] Format toString method in MLLIB bdc402f [Alain] [Minor][MLLIB] Fix a formatting bug in toString method in Node 426eee7 [Alain] [Minor][MLLIB] Fix a formatting bug in toString method in Node.scala --- .../main/scala/org/apache/spark/mllib/linalg/Vectors.scala | 2 +- .../org/apache/spark/mllib/regression/LabeledPoint.scala | 2 +- .../apache/spark/mllib/tree/model/InformationGainStats.scala | 4 ++-- .../main/scala/org/apache/spark/mllib/tree/model/Node.scala | 4 ++-- .../scala/org/apache/spark/mllib/tree/model/Predict.scala | 4 +--- .../main/scala/org/apache/spark/mllib/tree/model/Split.scala | 5 ++--- 6 files changed, 9 insertions(+), 12 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala index 4ef171f4f0419..166c00cff634d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala @@ -526,7 +526,7 @@ class SparseVector( s" ${values.size} values.") override def toString: String = - "(%s,%s,%s)".format(size, indices.mkString("[", ",", "]"), values.mkString("[", ",", "]")) + s"($size,${indices.mkString("[", ",", "]")},${values.mkString("[", ",", "]")})" override def toArray: Array[Double] = { val data = new Array[Double](size) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/LabeledPoint.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/LabeledPoint.scala index 2067b36f246b3..d5fea822ad77b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/LabeledPoint.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/LabeledPoint.scala @@ -32,7 +32,7 @@ import org.apache.spark.SparkException @BeanInfo case class LabeledPoint(label: Double, features: Vector) { override def toString: String = { - "(%s,%s)".format(label, features) + s"($label,$features)" } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala index f209fdafd3653..2d087c967f679 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala @@ -39,8 +39,8 @@ class InformationGainStats( val rightPredict: Predict) extends Serializable { override def toString: String = { - "gain = %f, impurity = %f, left impurity = %f, right impurity = %f" - .format(gain, impurity, leftImpurity, rightImpurity) + s"gain = $gain, impurity = $impurity, left impurity = $leftImpurity, " + + s"right impurity = $rightImpurity" } override def equals(o: Any): Boolean = o match { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala index 86390a20cb5cc..431a839817eac 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala @@ -51,8 +51,8 @@ class Node ( var stats: Option[InformationGainStats]) extends Serializable with Logging { override def toString: String = { - "id = " + id + ", isLeaf = " + isLeaf + ", predict = " + predict + ", " + - "impurity = " + impurity + ", split = " + split + ", stats = " 
+ stats + s"id = $id, isLeaf = $isLeaf, predict = $predict, impurity = $impurity, " + + s"split = $split, stats = $stats" } /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Predict.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Predict.scala index 25990af7c6cf7..5cbe7c280dbee 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Predict.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Predict.scala @@ -29,9 +29,7 @@ class Predict( val predict: Double, val prob: Double = 0.0) extends Serializable { - override def toString: String = { - "predict = %f, prob = %f".format(predict, prob) - } + override def toString: String = s"$predict (prob = $prob)" override def equals(other: Any): Boolean = { other match { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala index fb35e70a8d077..be6c9b3de5479 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala @@ -39,8 +39,8 @@ case class Split( categories: List[Double]) { override def toString: String = { - "Feature = " + feature + ", threshold = " + threshold + ", featureType = " + featureType + - ", categories = " + categories + s"Feature = $feature, threshold = $threshold, featureType = $featureType, " + + s"categories = $categories" } } @@ -68,4 +68,3 @@ private[tree] class DummyHighSplit(feature: Int, featureType: FeatureType) */ private[tree] class DummyCategoricalSplit(feature: Int, featureType: FeatureType) extends Split(feature, Double.MaxValue, featureType, List()) - From ca55dc95b777d96b27d4e4c0457dd25145dcd6e9 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Sun, 26 Apr 2015 11:46:58 -0700 Subject: [PATCH 073/110] [SPARK-7152][SQL] Add a Column expression for partition ID. Author: Reynold Xin Closes #5705 from rxin/df-pid and squashes the following commits: 401018f [Reynold Xin] [SPARK-7152][SQL] Add a Column expression for partition ID. --- python/pyspark/sql/functions.py | 30 +++++++++----- .../expressions/SparkPartitionID.scala | 39 +++++++++++++++++++ .../sql/execution/expressions/package.scala | 23 +++++++++++ .../org/apache/spark/sql/functions.scala | 29 +++++++++----- .../spark/sql/ColumnExpressionSuite.scala | 8 ++++ 5 files changed, 110 insertions(+), 19 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/SparkPartitionID.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/package.scala diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index bb47923f24b82..f48b7b5d10af7 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -75,6 +75,20 @@ def _(col): __all__.sort() +def approxCountDistinct(col, rsd=None): + """Returns a new :class:`Column` for approximate distinct count of ``col``. + + >>> df.agg(approxCountDistinct(df.age).alias('c')).collect() + [Row(c=2)] + """ + sc = SparkContext._active_spark_context + if rsd is None: + jc = sc._jvm.functions.approxCountDistinct(_to_java_column(col)) + else: + jc = sc._jvm.functions.approxCountDistinct(_to_java_column(col), rsd) + return Column(jc) + + def countDistinct(col, *cols): """Returns a new :class:`Column` for distinct count of ``col`` or ``cols``. 
@@ -89,18 +103,16 @@ def countDistinct(col, *cols): return Column(jc) -def approxCountDistinct(col, rsd=None): - """Returns a new :class:`Column` for approximate distinct count of ``col``. +def sparkPartitionId(): + """Returns a column for partition ID of the Spark task. - >>> df.agg(approxCountDistinct(df.age).alias('c')).collect() - [Row(c=2)] + Note that this is indeterministic because it depends on data partitioning and task scheduling. + + >>> df.repartition(1).select(sparkPartitionId().alias("pid")).collect() + [Row(pid=0), Row(pid=0)] """ sc = SparkContext._active_spark_context - if rsd is None: - jc = sc._jvm.functions.approxCountDistinct(_to_java_column(col)) - else: - jc = sc._jvm.functions.approxCountDistinct(_to_java_column(col), rsd) - return Column(jc) + return Column(sc._jvm.functions.sparkPartitionId()) class UserDefinedFunction(object): diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/SparkPartitionID.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/SparkPartitionID.scala new file mode 100644 index 0000000000000..fe7607c6ac340 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/SparkPartitionID.scala @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.expressions + +import org.apache.spark.TaskContext +import org.apache.spark.sql.catalyst.expressions.{Row, Expression} +import org.apache.spark.sql.catalyst.trees +import org.apache.spark.sql.types.{IntegerType, DataType} + + +/** + * Expression that returns the current partition id of the Spark task. + */ +case object SparkPartitionID extends Expression with trees.LeafNode[Expression] { + self: Product => + + override type EvaluatedType = Int + + override def nullable: Boolean = false + + override def dataType: DataType = IntegerType + + override def eval(input: Row): Int = TaskContext.get().partitionId() +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/package.scala new file mode 100644 index 0000000000000..568b7ac2c5987 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/package.scala @@ -0,0 +1,23 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +/** + * Package containing expressions that are specific to Spark runtime. + */ +package object expressions diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index ff91e1d74bc2c..9738fd4f93bad 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -276,6 +276,13 @@ object functions { // Non-aggregate functions ////////////////////////////////////////////////////////////////////////////////////////////// + /** + * Computes the absolute value. + * + * @group normal_funcs + */ + def abs(e: Column): Column = Abs(e.expr) + /** * Returns the first column that is not null. * {{{ @@ -287,6 +294,13 @@ object functions { @scala.annotation.varargs def coalesce(e: Column*): Column = Coalesce(e.map(_.expr)) + /** + * Converts a string exprsesion to lower case. + * + * @group normal_funcs + */ + def lower(e: Column): Column = Lower(e.expr) + /** * Unary minus, i.e. negate the expression. * {{{ @@ -317,18 +331,13 @@ object functions { def not(e: Column): Column = !e /** - * Converts a string expression to upper case. + * Partition ID of the Spark task. * - * @group normal_funcs - */ - def upper(e: Column): Column = Upper(e.expr) - - /** - * Converts a string exprsesion to lower case. + * Note that this is indeterministic because it depends on data partitioning and task scheduling. * * @group normal_funcs */ - def lower(e: Column): Column = Lower(e.expr) + def sparkPartitionId(): Column = execution.expressions.SparkPartitionID /** * Computes the square root of the specified float value. @@ -338,11 +347,11 @@ object functions { def sqrt(e: Column): Column = Sqrt(e.expr) /** - * Computes the absolutle value. + * Converts a string expression to upper case. 
* * @group normal_funcs */ - def abs(e: Column): Column = Abs(e.expr) + def upper(e: Column): Column = Upper(e.expr) ////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala index bc8fae100db6a..904073b8cb2aa 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala @@ -310,6 +310,14 @@ class ColumnExpressionSuite extends QueryTest { ) } + test("sparkPartitionId") { + val df = TestSQLContext.sparkContext.parallelize(1 to 1, 1).map(i => (i, i)).toDF("a", "b") + checkAnswer( + df.select(sparkPartitionId()), + Row(0) + ) + } + test("lift alias out of cast") { compareExpressions( col("1234").as("name").cast("int").expr, From d188b8bad82836bf654e57f9dd4e1ddde1d530f4 Mon Sep 17 00:00:00 2001 From: wangfei Date: Sun, 26 Apr 2015 21:08:47 -0700 Subject: [PATCH 074/110] [SQL][Minor] rename DataTypeParser.apply to DataTypeParser.parse rename DataTypeParser.apply to DataTypeParser.parse to make it more clear and readable. /cc rxin Author: wangfei Closes #5710 from scwf/apply and squashes the following commits: c319977 [wangfei] rename apply to parse --- .../org/apache/spark/sql/catalyst/planning/patterns.scala | 2 +- .../scala/org/apache/spark/sql/types/DataTypeParser.scala | 2 +- .../org/apache/spark/sql/types/DataTypeParserSuite.scala | 4 ++-- sql/core/src/main/scala/org/apache/spark/sql/Column.scala | 2 +- .../org/apache/spark/sql/hive/HiveMetastoreCatalog.scala | 2 +- 5 files changed, 6 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala index 9c8c643f7d17a..4574934d910db 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala @@ -92,7 +92,7 @@ object PhysicalOperation extends PredicateHelper { } def collectAliases(fields: Seq[Expression]): Map[Attribute, Expression] = fields.collect { - case a @ Alias(child, _) => a.toAttribute.asInstanceOf[Attribute] -> child + case a @ Alias(child, _) => a.toAttribute -> child }.toMap def substitute(aliases: Map[Attribute, Expression])(expr: Expression): Expression = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataTypeParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataTypeParser.scala index 5163f05879e42..04f3379afb38d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataTypeParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataTypeParser.scala @@ -108,7 +108,7 @@ private[sql] object DataTypeParser { override val lexical = new SqlLexical } - def apply(dataTypeString: String): DataType = dataTypeParser.toDataType(dataTypeString) + def parse(dataTypeString: String): DataType = dataTypeParser.toDataType(dataTypeString) } /** The exception thrown from the [[DataTypeParser]]. 
*/ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeParserSuite.scala index 169125264a803..3e7cf7cbb5e63 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeParserSuite.scala @@ -23,13 +23,13 @@ class DataTypeParserSuite extends FunSuite { def checkDataType(dataTypeString: String, expectedDataType: DataType): Unit = { test(s"parse ${dataTypeString.replace("\n", "")}") { - assert(DataTypeParser(dataTypeString) === expectedDataType) + assert(DataTypeParser.parse(dataTypeString) === expectedDataType) } } def unsupported(dataTypeString: String): Unit = { test(s"$dataTypeString is not supported") { - intercept[DataTypeException](DataTypeParser(dataTypeString)) + intercept[DataTypeException](DataTypeParser.parse(dataTypeString)) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index edb229c059e6b..33f9d0b37d006 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -647,7 +647,7 @@ class Column(protected[sql] val expr: Expression) extends Logging { * * @group expr_ops */ - def cast(to: String): Column = cast(DataTypeParser(to)) + def cast(to: String): Column = cast(DataTypeParser.parse(to)) /** * Returns an ordering used in sorting. diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index f1c0bd92aa23d..4d222cf88e5e8 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -871,7 +871,7 @@ private[hive] case class MetastoreRelation private[hive] object HiveMetastoreTypes { - def toDataType(metastoreType: String): DataType = DataTypeParser(metastoreType) + def toDataType(metastoreType: String): DataType = DataTypeParser.parse(metastoreType) def toMetastoreType(dt: DataType): String = dt match { case ArrayType(elementType, _) => s"array<${toMetastoreType(elementType)}>" From 82bb7fd41a2c7992e0aea69623c504bd439744f7 Mon Sep 17 00:00:00 2001 From: baishuo Date: Mon, 27 Apr 2015 14:08:05 +0800 Subject: [PATCH 075/110] [SPARK-6505] [SQL] Remove the reflection call in HiveFunctionWrapper MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit according liancheng‘s comment in https://issues.apache.org/jira/browse/SPARK-6505, this patch remove the reflection call in HiveFunctionWrapper, and implement the functions named "deserializeObjectByKryo" and "serializeObjectByKryo" according the functions with the save name in org.apache.hadoop.hive.ql.exec.Utilities.java Author: baishuo Closes #5660 from baishuo/SPARK-6505-20150423 and squashes the following commits: ae61ec4 [baishuo] modify code style 78d9fa3 [baishuo] modify code style 0b522a7 [baishuo] modify code style a5ff9c7 [baishuo] Remove the reflection call in HiveFunctionWrapper --- .../org/apache/spark/sql/hive/Shim13.scala | 44 +++++++++---------- 1 file changed, 22 insertions(+), 22 deletions(-) diff --git a/sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala b/sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala index d331c210e8939..dbc5e029e2047 100644 --- 
a/sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala +++ b/sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala @@ -19,11 +19,15 @@ package org.apache.spark.sql.hive import java.rmi.server.UID import java.util.{Properties, ArrayList => JArrayList} +import java.io.{OutputStream, InputStream} import scala.collection.JavaConversions._ import scala.language.implicitConversions +import scala.reflect.ClassTag import com.esotericsoftware.kryo.Kryo +import com.esotericsoftware.kryo.io.Input +import com.esotericsoftware.kryo.io.Output import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path import org.apache.hadoop.hive.common.StatsSetupConst @@ -46,6 +50,7 @@ import org.apache.hadoop.{io => hadoopIo} import org.apache.spark.Logging import org.apache.spark.sql.types.{Decimal, DecimalType, UTF8String} +import org.apache.spark.util.Utils._ /** * This class provides the UDF creation and also the UDF instance serialization and @@ -61,39 +66,34 @@ private[hive] case class HiveFunctionWrapper(var functionClassName: String) // for Serialization def this() = this(null) - import org.apache.spark.util.Utils._ - @transient - private val methodDeSerialize = { - val method = classOf[Utilities].getDeclaredMethod( - "deserializeObjectByKryo", - classOf[Kryo], - classOf[java.io.InputStream], - classOf[Class[_]]) - method.setAccessible(true) - - method + def deserializeObjectByKryo[T: ClassTag]( + kryo: Kryo, + in: InputStream, + clazz: Class[_]): T = { + val inp = new Input(in) + val t: T = kryo.readObject(inp,clazz).asInstanceOf[T] + inp.close() + t } @transient - private val methodSerialize = { - val method = classOf[Utilities].getDeclaredMethod( - "serializeObjectByKryo", - classOf[Kryo], - classOf[Object], - classOf[java.io.OutputStream]) - method.setAccessible(true) - - method + def serializeObjectByKryo( + kryo: Kryo, + plan: Object, + out: OutputStream ) { + val output: Output = new Output(out) + kryo.writeObject(output, plan) + output.close() } def deserializePlan[UDFType](is: java.io.InputStream, clazz: Class[_]): UDFType = { - methodDeSerialize.invoke(null, Utilities.runtimeSerializationKryo.get(), is, clazz) + deserializeObjectByKryo(Utilities.runtimeSerializationKryo.get(), is, clazz) .asInstanceOf[UDFType] } def serializePlan(function: AnyRef, out: java.io.OutputStream): Unit = { - methodSerialize.invoke(null, Utilities.runtimeSerializationKryo.get(), function, out) + serializeObjectByKryo(Utilities.runtimeSerializationKryo.get(), function, out) } private var instance: AnyRef = null From 998aac21f0a0588a70f8cf123ae4080163c612fb Mon Sep 17 00:00:00 2001 From: Misha Chernetsov Date: Mon, 27 Apr 2015 11:27:56 -0700 Subject: [PATCH 076/110] [SPARK-4925] Publish Spark SQL hive-thriftserver maven artifact turned on hive-thriftserver profile in release script Author: Misha Chernetsov Closes #5429 from chernetsov/master and squashes the following commits: 9cc36af [Misha Chernetsov] [SPARK-4925] Publish Spark SQL hive-thriftserver maven artifact turned on hive-thriftserver profile in release script for scala 2.10 --- dev/create-release/create-release.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dev/create-release/create-release.sh b/dev/create-release/create-release.sh index b5a67dd783b93..3dbb35f7054a2 100755 --- a/dev/create-release/create-release.sh +++ b/dev/create-release/create-release.sh @@ -119,7 +119,7 @@ if [[ ! 
"$@" =~ --skip-publish ]]; then rm -rf $SPARK_REPO build/mvn -DskipTests -Dhadoop.version=2.2.0 -Dyarn.version=2.2.0 \ - -Pyarn -Phive -Phadoop-2.2 -Pspark-ganglia-lgpl -Pkinesis-asl \ + -Pyarn -Phive -Phive-thriftserver -Phadoop-2.2 -Pspark-ganglia-lgpl -Pkinesis-asl \ clean install ./dev/change-version-to-2.11.sh From 7078f6028bf012235c664b02ec3541cbb0a248a7 Mon Sep 17 00:00:00 2001 From: Jeff Harrison Date: Mon, 27 Apr 2015 13:38:25 -0700 Subject: [PATCH 077/110] [SPARK-6856] [R] Make RDD information more useful in SparkR Author: Jeff Harrison Closes #5667 from His-name-is-Joof/joofspark and squashes the following commits: f8814a6 [Jeff Harrison] newline added after RDD show() output 4d9d972 [Jeff Harrison] Merge branch 'master' into joofspark 9d2295e [Jeff Harrison] parallelize with 1:10 878b830 [Jeff Harrison] Merge branch 'master' into joofspark c8c0b80 [Jeff Harrison] add test for RDD function show() 123be65 [Jeff Harrison] SPARK-6856 --- R/pkg/R/RDD.R | 5 +++++ R/pkg/inst/tests/test_rdd.R | 5 +++++ 2 files changed, 10 insertions(+) diff --git a/R/pkg/R/RDD.R b/R/pkg/R/RDD.R index 1662d6bb3b1ac..f90c26b253455 100644 --- a/R/pkg/R/RDD.R +++ b/R/pkg/R/RDD.R @@ -66,6 +66,11 @@ setMethod("initialize", "RDD", function(.Object, jrdd, serializedMode, .Object }) +setMethod("show", "RDD", + function(.Object) { + cat(paste(callJMethod(.Object@jrdd, "toString"), "\n", sep="")) + }) + setMethod("initialize", "PipelinedRDD", function(.Object, prev, func, jrdd_val) { .Object@env <- new.env() .Object@env$isCached <- FALSE diff --git a/R/pkg/inst/tests/test_rdd.R b/R/pkg/inst/tests/test_rdd.R index d55af93e3e50a..03207353c31c6 100644 --- a/R/pkg/inst/tests/test_rdd.R +++ b/R/pkg/inst/tests/test_rdd.R @@ -759,6 +759,11 @@ test_that("collectAsMap() on a pairwise RDD", { expect_equal(vals, list(`1` = "a", `2` = "b")) }) +test_that("show()", { + rdd <- parallelize(sc, list(1:10)) + expect_output(show(rdd), "ParallelCollectionRDD\\[\\d+\\] at parallelize at RRDD\\.scala:\\d+") +}) + test_that("sampleByKey() on pairwise RDDs", { rdd <- parallelize(sc, 1:2000) pairsRDD <- lapply(rdd, function(x) { if (x %% 2 == 0) list("a", x) else list("b", x) }) From ef82bddc11d1aea42e22d2f85613a869cbe9a990 Mon Sep 17 00:00:00 2001 From: tedyu Date: Mon, 27 Apr 2015 14:42:40 -0700 Subject: [PATCH 078/110] SPARK-7107 Add parameter for zookeeper.znode.parent to hbase_inputformat... ....py Author: tedyu Closes #5673 from tedyu/master and squashes the following commits: ab7c72b [tedyu] SPARK-7107 Adjust indentation to pass Python style tests 6e25939 [tedyu] Adjust line length to be shorter than 100 characters 18d172a [tedyu] SPARK-7107 Add parameter for zookeeper.znode.parent to hbase_inputformat.py --- examples/src/main/python/hbase_inputformat.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/examples/src/main/python/hbase_inputformat.py b/examples/src/main/python/hbase_inputformat.py index e17819d5feb76..5b82a14fba413 100644 --- a/examples/src/main/python/hbase_inputformat.py +++ b/examples/src/main/python/hbase_inputformat.py @@ -54,8 +54,9 @@ Run with example jar: ./bin/spark-submit --driver-class-path /path/to/example/jar \ - /path/to/examples/hbase_inputformat.py + /path/to/examples/hbase_inputformat.py
[<znode>] Assumes you have some data in HBase already, running on <host>, in <table>
+ optionally, you can specify parent znode for your hbase cluster - """, file=sys.stderr) exit(-1) @@ -64,6 +65,9 @@ sc = SparkContext(appName="HBaseInputFormat") conf = {"hbase.zookeeper.quorum": host, "hbase.mapreduce.inputtable": table} + if len(sys.argv) > 3: + conf = {"hbase.zookeeper.quorum": host, "zookeeper.znode.parent": sys.argv[3], + "hbase.mapreduce.inputtable": table} keyConv = "org.apache.spark.examples.pythonconverters.ImmutableBytesWritableToStringConverter" valueConv = "org.apache.spark.examples.pythonconverters.HBaseResultToStringConverter" From ca9f4ebb8e510e521bf4df0331375ddb385fb9d2 Mon Sep 17 00:00:00 2001 From: hlin09 Date: Mon, 27 Apr 2015 15:04:37 -0700 Subject: [PATCH 079/110] [SPARK-6991] [SPARKR] Adds support for zipPartitions. Author: hlin09 Closes #5568 from hlin09/zipPartitions and squashes the following commits: 12c08a5 [hlin09] Fix comments d2d32db [hlin09] Merge branch 'master' into zipPartitions ec56d2f [hlin09] Fix test. 27655d3 [hlin09] Adds support for zipPartitions. --- R/pkg/NAMESPACE | 1 + R/pkg/R/RDD.R | 46 +++++++++++++++++++++++++ R/pkg/R/generics.R | 5 +++ R/pkg/inst/tests/test_binary_function.R | 33 ++++++++++++++++++ 4 files changed, 85 insertions(+) diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index 80283643861ac..e077eace74375 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -71,6 +71,7 @@ exportMethods( "unpersist", "value", "values", + "zipPartitions", "zipRDD", "zipWithIndex", "zipWithUniqueId" diff --git a/R/pkg/R/RDD.R b/R/pkg/R/RDD.R index f90c26b253455..a3a0421a0746d 100644 --- a/R/pkg/R/RDD.R +++ b/R/pkg/R/RDD.R @@ -1595,3 +1595,49 @@ setMethod("intersection", keys(filterRDD(cogroup(rdd1, rdd2, numPartitions = numPartitions), filterFunction)) }) + +#' Zips an RDD's partitions with one (or more) RDD(s). +#' Same as zipPartitions in Spark. +#' +#' @param ... RDDs to be zipped. +#' @param func A function to transform zipped partitions. +#' @return A new RDD by applying a function to the zipped partitions. +#' Assumes that all the RDDs have the *same number of partitions*, but +#' does *not* require them to have the same number of elements in each partition. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd1 <- parallelize(sc, 1:2, 2L) # 1, 2 +#' rdd2 <- parallelize(sc, 1:4, 2L) # 1:2, 3:4 +#' rdd3 <- parallelize(sc, 1:6, 2L) # 1:3, 4:6 +#' collect(zipPartitions(rdd1, rdd2, rdd3, +#' func = function(x, y, z) { list(list(x, y, z))} )) +#' # list(list(1, c(1,2), c(1,2,3)), list(2, c(3,4), c(4,5,6))) +#'} +#' @rdname zipRDD +#' @aliases zipPartitions,RDD +setMethod("zipPartitions", + "RDD", + function(..., func) { + rrdds <- list(...) + if (length(rrdds) == 1) { + return(rrdds[[1]]) + } + nPart <- sapply(rrdds, numPartitions) + if (length(unique(nPart)) != 1) { + stop("Can only zipPartitions RDDs which have the same number of partitions.") + } + + rrdds <- lapply(rrdds, function(rdd) { + mapPartitionsWithIndex(rdd, function(partIndex, part) { + print(length(part)) + list(list(partIndex, part)) + }) + }) + union.rdd <- Reduce(unionRDD, rrdds) + zipped.rdd <- values(groupByKey(union.rdd, numPartitions = nPart[1])) + res <- mapPartitions(zipped.rdd, function(plist) { + do.call(func, plist[[1]]) + }) + res + }) diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 34dbe84051c50..e88729387ef95 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -217,6 +217,11 @@ setGeneric("unpersist", function(x, ...) 
{ standardGeneric("unpersist") }) #' @export setGeneric("zipRDD", function(x, other) { standardGeneric("zipRDD") }) +#' @rdname zipRDD +#' @export +setGeneric("zipPartitions", function(..., func) { standardGeneric("zipPartitions") }, + signature = "...") + #' @rdname zipWithIndex #' @seealso zipWithUniqueId #' @export diff --git a/R/pkg/inst/tests/test_binary_function.R b/R/pkg/inst/tests/test_binary_function.R index c15553ba28517..6785a7bdae8cb 100644 --- a/R/pkg/inst/tests/test_binary_function.R +++ b/R/pkg/inst/tests/test_binary_function.R @@ -66,3 +66,36 @@ test_that("cogroup on two RDDs", { expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) }) + +test_that("zipPartitions() on RDDs", { + rdd1 <- parallelize(sc, 1:2, 2L) # 1, 2 + rdd2 <- parallelize(sc, 1:4, 2L) # 1:2, 3:4 + rdd3 <- parallelize(sc, 1:6, 2L) # 1:3, 4:6 + actual <- collect(zipPartitions(rdd1, rdd2, rdd3, + func = function(x, y, z) { list(list(x, y, z))} )) + expect_equal(actual, + list(list(1, c(1,2), c(1,2,3)), list(2, c(3,4), c(4,5,6)))) + + mockFile = c("Spark is pretty.", "Spark is awesome.") + fileName <- tempfile(pattern="spark-test", fileext=".tmp") + writeLines(mockFile, fileName) + + rdd <- textFile(sc, fileName, 1) + actual <- collect(zipPartitions(rdd, rdd, + func = function(x, y) { list(paste(x, y, sep = "\n")) })) + expected <- list(paste(mockFile, mockFile, sep = "\n")) + expect_equal(actual, expected) + + rdd1 <- parallelize(sc, 0:1, 1) + actual <- collect(zipPartitions(rdd1, rdd, + func = function(x, y) { list(x + nchar(y)) })) + expected <- list(0:1 + nchar(mockFile)) + expect_equal(actual, expected) + + rdd <- map(rdd, function(x) { x }) + actual <- collect(zipPartitions(rdd, rdd1, + func = function(x, y) { list(y + nchar(x)) })) + expect_equal(actual, expected) + + unlink(fileName) +}) From b9de9e040aff371c6acf9b3f3d1ff8b360c0cd56 Mon Sep 17 00:00:00 2001 From: Steven She Date: Mon, 27 Apr 2015 18:55:02 -0400 Subject: [PATCH 080/110] [SPARK-7103] Fix crash with SparkContext.union when RDD has no partitioner Added a check to the SparkContext.union method to check that a partitioner is defined on all RDDs when instantiating a PartitionerAwareUnionRDD. Author: Steven She Closes #5679 from stevencanopy/SPARK-7103 and squashes the following commits: 5a3d846 [Steven She] SPARK-7103: Fix crash with SparkContext.union when at least one RDD has no partitioner --- .../scala/org/apache/spark/SparkContext.scala | 2 +- .../spark/rdd/PartitionerAwareUnionRDD.scala | 1 + .../scala/org/apache/spark/rdd/RDDSuite.scala | 21 +++++++++++++++++++ 3 files changed, 23 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 86269eac52db0..ea4ddcc2e265d 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -1055,7 +1055,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli /** Build the union of a list of RDDs. 
*/ def union[T: ClassTag](rdds: Seq[RDD[T]]): RDD[T] = { val partitioners = rdds.flatMap(_.partitioner).toSet - if (partitioners.size == 1) { + if (rdds.forall(_.partitioner.isDefined) && partitioners.size == 1) { new PartitionerAwareUnionRDD(this, rdds) } else { new UnionRDD(this, rdds) diff --git a/core/src/main/scala/org/apache/spark/rdd/PartitionerAwareUnionRDD.scala b/core/src/main/scala/org/apache/spark/rdd/PartitionerAwareUnionRDD.scala index 92b0641d0fb6e..7598ff617b399 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PartitionerAwareUnionRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PartitionerAwareUnionRDD.scala @@ -60,6 +60,7 @@ class PartitionerAwareUnionRDD[T: ClassTag]( var rdds: Seq[RDD[T]] ) extends RDD[T](sc, rdds.map(x => new OneToOneDependency(x))) { require(rdds.length > 0) + require(rdds.forall(_.partitioner.isDefined)) require(rdds.flatMap(_.partitioner).toSet.size == 1, "Parent RDDs have different partitioners: " + rdds.flatMap(_.partitioner)) diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala index df42faab64505..ef8c36a28655b 100644 --- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala @@ -99,6 +99,27 @@ class RDDSuite extends FunSuite with SharedSparkContext { assert(sc.union(Seq(nums, nums)).collect().toList === List(1, 2, 3, 4, 1, 2, 3, 4)) } + test("SparkContext.union creates UnionRDD if at least one RDD has no partitioner") { + val rddWithPartitioner = sc.parallelize(Seq(1->true)).partitionBy(new HashPartitioner(1)) + val rddWithNoPartitioner = sc.parallelize(Seq(2->true)) + val unionRdd = sc.union(rddWithNoPartitioner, rddWithPartitioner) + assert(unionRdd.isInstanceOf[UnionRDD[_]]) + } + + test("SparkContext.union creates PartitionAwareUnionRDD if all RDDs have partitioners") { + val rddWithPartitioner = sc.parallelize(Seq(1->true)).partitionBy(new HashPartitioner(1)) + val unionRdd = sc.union(rddWithPartitioner, rddWithPartitioner) + assert(unionRdd.isInstanceOf[PartitionerAwareUnionRDD[_]]) + } + + test("PartitionAwareUnionRDD raises exception if at least one RDD has no partitioner") { + val rddWithPartitioner = sc.parallelize(Seq(1->true)).partitionBy(new HashPartitioner(1)) + val rddWithNoPartitioner = sc.parallelize(Seq(2->true)) + intercept[IllegalArgumentException] { + new PartitionerAwareUnionRDD(sc, Seq(rddWithNoPartitioner, rddWithPartitioner)) + } + } + test("partitioner aware union") { def makeRDDWithPartitioner(seq: Seq[Int]): RDD[Int] = { sc.makeRDD(seq, 1) From 8e1c00dbf4b60962908626dead744e5d73c8085e Mon Sep 17 00:00:00 2001 From: Hong Shen Date: Mon, 27 Apr 2015 18:57:31 -0400 Subject: [PATCH 081/110] [SPARK-6738] [CORE] Improve estimate the size of a large array Currently, SizeEstimator.visitArray is not correct in the follow case, ``` array size > 200, elem has the share object ``` when I add a debug log in SizeTracker.scala: ``` System.err.println(s"numUpdates:$numUpdates, size:$ts, bytesPerUpdate:$bytesPerUpdate, cost time:$b") ``` I get the following log: ``` numUpdates:1, size:262448, bytesPerUpdate:0.0, cost time:35 numUpdates:2, size:420698, bytesPerUpdate:158250.0, cost time:35 numUpdates:4, size:420754, bytesPerUpdate:28.0, cost time:32 numUpdates:7, size:420754, bytesPerUpdate:0.0, cost time:27 numUpdates:12, size:420754, bytesPerUpdate:0.0, cost time:28 numUpdates:20, size:420754, bytesPerUpdate:0.0, cost time:25 numUpdates:32, size:420754, bytesPerUpdate:0.0, cost 
time:21 numUpdates:52, size:420754, bytesPerUpdate:0.0, cost time:20 numUpdates:84, size:420754, bytesPerUpdate:0.0, cost time:20 numUpdates:135, size:420754, bytesPerUpdate:0.0, cost time:20 numUpdates:216, size:420754, bytesPerUpdate:0.0, cost time:11 numUpdates:346, size:420754, bytesPerUpdate:0.0, cost time:6 numUpdates:554, size:488911, bytesPerUpdate:327.67788461538464, cost time:8 numUpdates:887, size:2312259426, bytesPerUpdate:6942253.798798799, cost time:198 15/04/21 14:27:26 INFO collection.ExternalAppendOnlyMap: Thread 51 spilling in-memory map of 3.0 GB to disk (1 time so far) 15/04/21 14:27:26 INFO collection.ExternalAppendOnlyMap: /data11/yarnenv/local/usercache/spark/appcache/application_1426746631567_11745/spark-local-20150421142719-c001/30/temp_local_066af981-c2fc-4b70-a00e-110e23006fbc ``` But in fact the file size is only 162K: ``` $ ll -h /data11/yarnenv/local/usercache/spark/appcache/application_1426746631567_11745/spark-local-20150421142719-c001/30/temp_local_066af981-c2fc-4b70-a00e-110e23006fbc -rw-r----- 1 spark users 162K Apr 21 14:27 /data11/yarnenv/local/usercache/spark/appcache/application_1426746631567_11745/spark-local-20150421142719-c001/30/temp_local_066af981-c2fc-4b70-a00e-110e23006fbc ``` In order to test case, I change visitArray to: ``` var size = 0l for (i <- 0 until length) { val obj = JArray.get(array, i) size += SizeEstimator.estimate(obj, state.visited).toLong } state.size += size ``` I get the following log: ``` ... 14895 277016088 566.9046118590662 time:8470 23832 281840544 552.3308270676691 time:8031 38132 289891824 539.8294729775092 time:7897 61012 302803640 563.0265734265735 time:13044 97620 322904416 564.3276223776223 time:13554 15/04/14 11:46:43 INFO collection.ExternalAppendOnlyMap: Thread 51 spilling in-memory map of 314.5 MB to disk (1 time so far) 15/04/14 11:46:43 INFO collection.ExternalAppendOnlyMap: /data1/yarnenv/local/usercache/spark/appcache/application_1426746631567_8477/spark-local-20150414114020-2fcb/14/temp_local_5b6b98d5-5bfa-47e2-8216-059482ccbda0 ``` the file size is 85M. ``` $ ll -h /data1/yarnenv/local/usercache/spark/appcache/application_1426746631567_8477/spark- local-20150414114020-2fcb/14/ total 85M -rw-r----- 1 spark users 85M Apr 14 11:46 temp_local_5b6b98d5-5bfa-47e2-8216-059482ccbda0 ``` The following log is when I use this patch, ``` .... numUpdates:32, size:365484, bytesPerUpdate:0.0, cost time:7 numUpdates:52, size:365484, bytesPerUpdate:0.0, cost time:5 numUpdates:84, size:365484, bytesPerUpdate:0.0, cost time:5 numUpdates:135, size:372208, bytesPerUpdate:131.84313725490196, cost time:86 numUpdates:216, size:379020, bytesPerUpdate:84.09876543209876, cost time:21 numUpdates:346, size:1865208, bytesPerUpdate:11432.215384615385, cost time:23 numUpdates:554, size:2052380, bytesPerUpdate:899.8653846153846, cost time:16 numUpdates:887, size:2142820, bytesPerUpdate:271.59159159159157, cost time:15 .. 
numUpdates:14895, size:251675500, bytesPerUpdate:438.5263157894737, cost time:13 numUpdates:23832, size:257010268, bytesPerUpdate:596.9305135951662, cost time:14 numUpdates:38132, size:263922396, bytesPerUpdate:483.3655944055944, cost time:15 numUpdates:61012, size:268962596, bytesPerUpdate:220.28846153846155, cost time:24 numUpdates:97620, size:286980644, bytesPerUpdate:492.1888111888112, cost time:22 15/04/21 14:45:12 INFO collection.ExternalAppendOnlyMap: Thread 53 spilling in-memory map of 328.7 MB to disk (1 time so far) 15/04/21 14:45:12 INFO collection.ExternalAppendOnlyMap: /data4/yarnenv/local/usercache/spark/appcache/application_1426746631567_11758/spark-local-20150421144456-a2a5/2a/temp_local_9c109510-af16-4468-8f23-48cad04da88f ``` the file size is 88M. ``` $ ll -h /data4/yarnenv/local/usercache/spark/appcache/application_1426746631567_11758/spark-local-20150421144456-a2a5/2a/ total 88M -rw-r----- 1 spark users 88M Apr 21 14:45 temp_local_9c109510-af16-4468-8f23-48cad04da88f ``` Author: Hong Shen Closes #5608 from shenh062326/my_change5 and squashes the following commits: 5506bae [Hong Shen] Fix compile error c275dd3 [Hong Shen] Alter code style fe202a2 [Hong Shen] Change the code style and add documentation. a9fca84 [Hong Shen] Add test case for SizeEstimator 4877eee [Hong Shen] Improve estimate the size of a large array a2ea7ac [Hong Shen] Alter code style 4c28e36 [Hong Shen] Improve estimate the size of a large array --- .../org/apache/spark/util/SizeEstimator.scala | 45 ++++++++++++------- .../spark/util/SizeEstimatorSuite.scala | 18 ++++++++ 2 files changed, 48 insertions(+), 15 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala b/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala index 26ffbf9350388..4dd7ab9e0767b 100644 --- a/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala +++ b/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala @@ -179,7 +179,7 @@ private[spark] object SizeEstimator extends Logging { } // Estimate the size of arrays larger than ARRAY_SIZE_FOR_SAMPLING by sampling. - private val ARRAY_SIZE_FOR_SAMPLING = 200 + private val ARRAY_SIZE_FOR_SAMPLING = 400 private val ARRAY_SAMPLE_SIZE = 100 // should be lower than ARRAY_SIZE_FOR_SAMPLING private def visitArray(array: AnyRef, arrayClass: Class[_], state: SearchState) { @@ -204,25 +204,40 @@ private[spark] object SizeEstimator extends Logging { } } else { // Estimate the size of a large array by sampling elements without replacement. - var size = 0.0 + // To exclude the shared objects that the array elements may link, sample twice + // and use the min one to caculate array size. 
val rand = new Random(42) - val drawn = new OpenHashSet[Int](ARRAY_SAMPLE_SIZE) - var numElementsDrawn = 0 - while (numElementsDrawn < ARRAY_SAMPLE_SIZE) { - var index = 0 - do { - index = rand.nextInt(length) - } while (drawn.contains(index)) - drawn.add(index) - val elem = ScalaRunTime.array_apply(array, index).asInstanceOf[AnyRef] - size += SizeEstimator.estimate(elem, state.visited) - numElementsDrawn += 1 - } - state.size += ((length / (ARRAY_SAMPLE_SIZE * 1.0)) * size).toLong + val drawn = new OpenHashSet[Int](2 * ARRAY_SAMPLE_SIZE) + val s1 = sampleArray(array, state, rand, drawn, length) + val s2 = sampleArray(array, state, rand, drawn, length) + val size = math.min(s1, s2) + state.size += math.max(s1, s2) + + (size * ((length - ARRAY_SAMPLE_SIZE) / (ARRAY_SAMPLE_SIZE))).toLong } } } + private def sampleArray( + array: AnyRef, + state: SearchState, + rand: Random, + drawn: OpenHashSet[Int], + length: Int): Long = { + var size = 0L + for (i <- 0 until ARRAY_SAMPLE_SIZE) { + var index = 0 + do { + index = rand.nextInt(length) + } while (drawn.contains(index)) + drawn.add(index) + val obj = ScalaRunTime.array_apply(array, index).asInstanceOf[AnyRef] + if (obj != null) { + size += SizeEstimator.estimate(obj, state.visited).toLong + } + } + size + } + private def primitiveSize(cls: Class[_]): Long = { if (cls == classOf[Byte]) { BYTE_SIZE diff --git a/core/src/test/scala/org/apache/spark/util/SizeEstimatorSuite.scala b/core/src/test/scala/org/apache/spark/util/SizeEstimatorSuite.scala index 67a9f75ff2187..28915bd53354e 100644 --- a/core/src/test/scala/org/apache/spark/util/SizeEstimatorSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/SizeEstimatorSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.util +import scala.collection.mutable.ArrayBuffer + import org.scalatest.{BeforeAndAfterEach, BeforeAndAfterAll, FunSuite, PrivateMethodTester} class DummyClass1 {} @@ -96,6 +98,22 @@ class SizeEstimatorSuite // Past size 100, our samples 100 elements, but we should still get the right size. assertResult(28016)(SizeEstimator.estimate(Array.fill(1000)(new DummyClass3))) + + val arr = new Array[Char](100000) + assertResult(200016)(SizeEstimator.estimate(arr)) + assertResult(480032)(SizeEstimator.estimate(Array.fill(10000)(new DummyString(arr)))) + + val buf = new ArrayBuffer[DummyString]() + for (i <- 0 until 5000) { + buf.append(new DummyString(new Array[Char](10))) + } + assertResult(340016)(SizeEstimator.estimate(buf.toArray)) + + for (i <- 0 until 5000) { + buf.append(new DummyString(arr)) + } + assertResult(683912)(SizeEstimator.estimate(buf.toArray)) + // If an array contains the *same* element many times, we should only count it once. val d1 = new DummyClass1 // 10 pointers plus 8-byte object From 5d45e1f60059e2f2fc8ad64778b9ddcc8887c570 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Mon, 27 Apr 2015 19:46:17 -0400 Subject: [PATCH 082/110] [SPARK-3090] [CORE] Stop SparkContext if user forgets to. Set up a shutdown hook to try to stop the Spark context in case the user forgets to do it. The main effect is that any open logs files are flushed and closed, which is particularly interesting for event logs. Author: Marcelo Vanzin Closes #5696 from vanzin/SPARK-3090 and squashes the following commits: 3b554b5 [Marcelo Vanzin] [SPARK-3090] [core] Stop SparkContext if user forgets to. 
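As a rough illustration of the behavior this patch adds (a minimal sketch only, not the actual Spark code; it uses a plain JVM hook via `sys.addShutdownHook` instead of the prioritized `Utils.addShutdownHook` that SparkContext now registers internally, and the application skeleton, master, and app name are placeholders):

```scala
import org.apache.spark.{SparkConf, SparkContext}

object ShutdownHookSketch {
  def main(args: Array[String]): Unit = {
    // Hypothetical application skeleton; master/app name are placeholders.
    val conf = new SparkConf().setMaster("local[*]").setAppName("shutdown-hook-sketch")
    val sc = new SparkContext(conf)

    // If the application forgets to call sc.stop(), this hook does it on JVM exit,
    // so open event log files are flushed and closed. stop() returns early when the
    // context is already stopped, so the hook is harmless in the normal case.
    sys.addShutdownHook {
      sc.stop()
    }

    // ... application code ...
  }
}
```

SparkContext itself now registers an equivalent hook at SPARK_CONTEXT_SHUTDOWN_PRIORITY, which is lower than the default priority so other hooks run before the context is torn down; the YARN ApplicationMaster registers its own hook at a priority just below that, so it runs after the context has been stopped.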
--- .../scala/org/apache/spark/SparkContext.scala | 38 ++++++++++++------- .../scala/org/apache/spark/util/Utils.scala | 10 ++++- .../spark/deploy/yarn/ApplicationMaster.scala | 10 +---- 3 files changed, 35 insertions(+), 23 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index ea4ddcc2e265d..65b903a55d5bd 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -223,6 +223,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli private var _listenerBusStarted: Boolean = false private var _jars: Seq[String] = _ private var _files: Seq[String] = _ + private var _shutdownHookRef: AnyRef = _ /* ------------------------------------------------------------------------------------- * | Accessors and public fields. These provide access to the internal state of the | @@ -517,6 +518,14 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli _taskScheduler.postStartHook() _env.metricsSystem.registerSource(new DAGSchedulerSource(dagScheduler)) _env.metricsSystem.registerSource(new BlockManagerSource(_env.blockManager)) + + // Make sure the context is stopped if the user forgets about it. This avoids leaving + // unfinished event logs around after the JVM exits cleanly. It doesn't help if the JVM + // is killed, though. + _shutdownHookRef = Utils.addShutdownHook(Utils.SPARK_CONTEXT_SHUTDOWN_PRIORITY) { () => + logInfo("Invoking stop() from shutdown hook") + stop() + } } catch { case NonFatal(e) => logError("Error initializing SparkContext.", e) @@ -1481,6 +1490,9 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli logInfo("SparkContext already stopped.") return } + if (_shutdownHookRef != null) { + Utils.removeShutdownHook(_shutdownHookRef) + } postApplicationEnd() _ui.foreach(_.stop()) @@ -1891,7 +1903,7 @@ object SparkContext extends Logging { * * Access to this field is guarded by SPARK_CONTEXT_CONSTRUCTOR_LOCK. */ - private val activeContext: AtomicReference[SparkContext] = + private val activeContext: AtomicReference[SparkContext] = new AtomicReference[SparkContext](null) /** @@ -1944,11 +1956,11 @@ object SparkContext extends Logging { } /** - * This function may be used to get or instantiate a SparkContext and register it as a - * singleton object. Because we can only have one active SparkContext per JVM, - * this is useful when applications may wish to share a SparkContext. + * This function may be used to get or instantiate a SparkContext and register it as a + * singleton object. Because we can only have one active SparkContext per JVM, + * this is useful when applications may wish to share a SparkContext. * - * Note: This function cannot be used to create multiple SparkContext instances + * Note: This function cannot be used to create multiple SparkContext instances * even if multiple contexts are allowed. */ def getOrCreate(config: SparkConf): SparkContext = { @@ -1961,17 +1973,17 @@ object SparkContext extends Logging { activeContext.get() } } - + /** - * This function may be used to get or instantiate a SparkContext and register it as a - * singleton object. Because we can only have one active SparkContext per JVM, + * This function may be used to get or instantiate a SparkContext and register it as a + * singleton object. 
Because we can only have one active SparkContext per JVM, * this is useful when applications may wish to share a SparkContext. - * + * * This method allows not passing a SparkConf (useful if just retrieving). - * - * Note: This function cannot be used to create multiple SparkContext instances - * even if multiple contexts are allowed. - */ + * + * Note: This function cannot be used to create multiple SparkContext instances + * even if multiple contexts are allowed. + */ def getOrCreate(): SparkContext = { getOrCreate(new SparkConf()) } diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index c6c6df7cfa56e..342bc9a06db47 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -67,6 +67,12 @@ private[spark] object Utils extends Logging { val DEFAULT_SHUTDOWN_PRIORITY = 100 + /** + * The shutdown priority of the SparkContext instance. This is lower than the default + * priority, so that by default hooks are run before the context is shut down. + */ + val SPARK_CONTEXT_SHUTDOWN_PRIORITY = 50 + private val MAX_DIR_CREATION_ATTEMPTS: Int = 10 @volatile private var localRootDirs: Array[String] = null @@ -2116,7 +2122,7 @@ private[spark] object Utils extends Logging { * @return A handle that can be used to unregister the shutdown hook. */ def addShutdownHook(hook: () => Unit): AnyRef = { - addShutdownHook(DEFAULT_SHUTDOWN_PRIORITY, hook) + addShutdownHook(DEFAULT_SHUTDOWN_PRIORITY)(hook) } /** @@ -2126,7 +2132,7 @@ private[spark] object Utils extends Logging { * @param hook The code to run during shutdown. * @return A handle that can be used to unregister the shutdown hook. */ - def addShutdownHook(priority: Int, hook: () => Unit): AnyRef = { + def addShutdownHook(priority: Int)(hook: () => Unit): AnyRef = { shutdownHooks.add(priority, hook) } diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index 93ae45133ce24..70cb57ffd8c69 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -95,14 +95,8 @@ private[spark] class ApplicationMaster( val fs = FileSystem.get(yarnConf) - Utils.addShutdownHook { () => - // If the SparkContext is still registered, shut it down as a best case effort in case - // users do not call sc.stop or do System.exit(). - val sc = sparkContextRef.get() - if (sc != null) { - logInfo("Invoking sc stop from shutdown hook") - sc.stop() - } + // This shutdown hook should run *after* the SparkContext is shut down. 
+ Utils.addShutdownHook(Utils.SPARK_CONTEXT_SHUTDOWN_PRIORITY - 1) { () => val maxAppAttempts = client.getMaxRegAttempts(sparkConf, yarnConf) val isLastAttempt = client.getAttemptId().getAttemptId() >= maxAppAttempts From ab5adb7a973eec9d95c7575c864cba9f8d83a0fd Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Mon, 27 Apr 2015 19:50:55 -0400 Subject: [PATCH 083/110] [SPARK-7145] [CORE] commons-lang (2.x) classes used instead of commons-lang3 (3.x); commons-io used without dependency Remove use of commons-lang in favor of commons-lang3 classes; remove commons-io use in favor of Guava Author: Sean Owen Closes #5703 from srowen/SPARK-7145 and squashes the following commits: 21fbe03 [Sean Owen] Remove use of commons-lang in favor of commons-lang3 classes; remove commons-io use in favor of Guava --- .../test/scala/org/apache/spark/FileServerSuite.scala | 7 +++---- .../apache/spark/metrics/InputOutputMetricsSuite.scala | 4 ++-- .../netty/NettyBlockTransferSecuritySuite.scala | 10 +++++++--- external/flume-sink/pom.xml | 4 ++++ .../flume/sink/SparkAvroCallbackHandler.scala | 4 ++-- .../main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala | 6 +++++- .../sql/hive/thriftserver/AbstractSparkSQLDriver.scala | 4 ++-- .../sql/hive/thriftserver/SparkSQLCLIDriver.scala | 8 +++----- .../apache/spark/sql/hive/execution/UDFListString.java | 6 +++--- .../spark/sql/hive/MetastoreDataSourcesSuite.scala | 9 ++++----- 10 files changed, 35 insertions(+), 27 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/FileServerSuite.scala b/core/src/test/scala/org/apache/spark/FileServerSuite.scala index a69e9b761f9a7..c0439f934813e 100644 --- a/core/src/test/scala/org/apache/spark/FileServerSuite.scala +++ b/core/src/test/scala/org/apache/spark/FileServerSuite.scala @@ -22,8 +22,7 @@ import java.net.URI import java.util.jar.{JarEntry, JarOutputStream} import javax.net.ssl.SSLException -import com.google.common.io.ByteStreams -import org.apache.commons.io.{FileUtils, IOUtils} +import com.google.common.io.{ByteStreams, Files} import org.apache.commons.lang3.RandomUtils import org.scalatest.FunSuite @@ -239,7 +238,7 @@ class FileServerSuite extends FunSuite with LocalSparkContext { def fileTransferTest(server: HttpFileServer, sm: SecurityManager = null): Unit = { val randomContent = RandomUtils.nextBytes(100) val file = File.createTempFile("FileServerSuite", "sslTests", tmpDir) - FileUtils.writeByteArrayToFile(file, randomContent) + Files.write(randomContent, file) server.addFile(file) val uri = new URI(server.serverUri + "/files/" + file.getName) @@ -254,7 +253,7 @@ class FileServerSuite extends FunSuite with LocalSparkContext { Utils.setupSecureURLConnection(connection, sm) } - val buf = IOUtils.toByteArray(connection.getInputStream) + val buf = ByteStreams.toByteArray(connection.getInputStream) assert(buf === randomContent) } diff --git a/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala b/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala index 190b08d950a02..ef3e213f1fcce 100644 --- a/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala +++ b/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala @@ -21,7 +21,7 @@ import java.io.{File, FileWriter, PrintWriter} import scala.collection.mutable.ArrayBuffer -import org.apache.commons.lang.math.RandomUtils +import org.apache.commons.lang3.RandomUtils import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileSystem, Path} import 
org.apache.hadoop.io.{LongWritable, Text} @@ -60,7 +60,7 @@ class InputOutputMetricsSuite extends FunSuite with SharedSparkContext tmpFile = new File(testTempDir, getClass.getSimpleName + ".txt") val pw = new PrintWriter(new FileWriter(tmpFile)) for (x <- 1 to numRecords) { - pw.println(RandomUtils.nextInt(numBuckets)) + pw.println(RandomUtils.nextInt(0, numBuckets)) } pw.close() diff --git a/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala b/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala index 94bfa67451892..46d2e5173acae 100644 --- a/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala +++ b/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala @@ -17,14 +17,16 @@ package org.apache.spark.network.netty +import java.io.InputStreamReader import java.nio._ +import java.nio.charset.Charset import java.util.concurrent.TimeUnit import scala.concurrent.duration._ import scala.concurrent.{Await, Promise} import scala.util.{Failure, Success, Try} -import org.apache.commons.io.IOUtils +import com.google.common.io.CharStreams import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer} import org.apache.spark.network.shuffle.BlockFetchingListener import org.apache.spark.network.{BlockDataManager, BlockTransferService} @@ -32,7 +34,7 @@ import org.apache.spark.storage.{BlockId, ShuffleBlockId} import org.apache.spark.{SecurityManager, SparkConf} import org.mockito.Mockito._ import org.scalatest.mock.MockitoSugar -import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach, FunSuite, ShouldMatchers} +import org.scalatest.{FunSuite, ShouldMatchers} class NettyBlockTransferSecuritySuite extends FunSuite with MockitoSugar with ShouldMatchers { test("security default off") { @@ -113,7 +115,9 @@ class NettyBlockTransferSecuritySuite extends FunSuite with MockitoSugar with Sh val result = fetchBlock(exec0, exec1, "1", blockId) match { case Success(buf) => - IOUtils.toString(buf.createInputStream()) should equal(blockString) + val actualString = CharStreams.toString( + new InputStreamReader(buf.createInputStream(), Charset.forName("UTF-8"))) + actualString should equal(blockString) buf.release() Success() case Failure(t) => diff --git a/external/flume-sink/pom.xml b/external/flume-sink/pom.xml index 67907bbfb6d1b..1f3e619d97a24 100644 --- a/external/flume-sink/pom.xml +++ b/external/flume-sink/pom.xml @@ -35,6 +35,10 @@ http://spark.apache.org/ + + org.apache.commons + commons-lang3 + org.apache.flume flume-ng-sdk diff --git a/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkAvroCallbackHandler.scala b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkAvroCallbackHandler.scala index 4373be443e67d..fd01807fc3ac4 100644 --- a/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkAvroCallbackHandler.scala +++ b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkAvroCallbackHandler.scala @@ -21,9 +21,9 @@ import java.util.concurrent.atomic.AtomicLong import scala.collection.mutable -import org.apache.flume.Channel -import org.apache.commons.lang.RandomStringUtils import com.google.common.util.concurrent.ThreadFactoryBuilder +import org.apache.flume.Channel +import org.apache.commons.lang3.RandomStringUtils /** * Class that implements the SparkFlumeProtocol, that is used by the Avro Netty Server to process diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala index f326510042122..f3b5455574d1a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala @@ -20,7 +20,8 @@ package org.apache.spark.sql.jdbc import java.sql.{Connection, DriverManager, ResultSet, ResultSetMetaData, SQLException} import java.util.Properties -import org.apache.commons.lang.StringEscapeUtils.escapeSql +import org.apache.commons.lang3.StringUtils + import org.apache.spark.{Logging, Partition, SparkContext, TaskContext} import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.expressions.{Row, SpecificMutableRow} @@ -239,6 +240,9 @@ private[sql] class JDBCRDD( case _ => value } + private def escapeSql(value: String): String = + if (value == null) null else StringUtils.replace(value, "'", "''") + /** * Turns a single Filter into a String representing a SQL expression. * Returns null for an unhandled filter. diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/AbstractSparkSQLDriver.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/AbstractSparkSQLDriver.scala index 59f3a75768082..48ac9062af96a 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/AbstractSparkSQLDriver.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/AbstractSparkSQLDriver.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.hive.thriftserver import scala.collection.JavaConversions._ -import org.apache.commons.lang.exception.ExceptionUtils +import org.apache.commons.lang3.exception.ExceptionUtils import org.apache.hadoop.hive.metastore.api.{FieldSchema, Schema} import org.apache.hadoop.hive.ql.Driver import org.apache.hadoop.hive.ql.processors.CommandProcessorResponse @@ -61,7 +61,7 @@ private[hive] abstract class AbstractSparkSQLDriver( } catch { case cause: Throwable => logError(s"Failed in [$command]", cause) - new CommandProcessorResponse(1, ExceptionUtils.getFullStackTrace(cause), null) + new CommandProcessorResponse(1, ExceptionUtils.getStackTrace(cause), null) } } diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala index 7e307bb4ad1e8..b7b6925aa87f7 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala @@ -24,18 +24,16 @@ import java.util.{ArrayList => JArrayList} import jline.{ConsoleReader, History} -import org.apache.commons.lang.StringUtils +import org.apache.commons.lang3.StringUtils import org.apache.commons.logging.LogFactory import org.apache.hadoop.conf.Configuration import org.apache.hadoop.hive.cli.{CliDriver, CliSessionState, OptionsProcessor} -import org.apache.hadoop.hive.common.LogUtils.LogInitializationException -import org.apache.hadoop.hive.common.{HiveInterruptCallback, HiveInterruptUtils, LogUtils} +import org.apache.hadoop.hive.common.{HiveInterruptCallback, HiveInterruptUtils} import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.ql.Driver import org.apache.hadoop.hive.ql.exec.Utilities -import org.apache.hadoop.hive.ql.processors.{AddResourceProcessor, SetProcessor, 
CommandProcessor, CommandProcessorFactory} +import org.apache.hadoop.hive.ql.processors.{AddResourceProcessor, SetProcessor, CommandProcessor} import org.apache.hadoop.hive.ql.session.SessionState -import org.apache.hadoop.hive.shims.ShimLoader import org.apache.thrift.transport.TSocket import org.apache.spark.Logging diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFListString.java b/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFListString.java index efd34df293c88..f33210ebdae1b 100644 --- a/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFListString.java +++ b/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFListString.java @@ -17,10 +17,10 @@ package org.apache.spark.sql.hive.execution; -import org.apache.hadoop.hive.ql.exec.UDF; - import java.util.List; -import org.apache.commons.lang.StringUtils; + +import org.apache.commons.lang3.StringUtils; +import org.apache.hadoop.hive.ql.exec.UDF; public class UDFListString extends UDF { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala index e09c702c8969e..0538aa203c5a0 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala @@ -23,7 +23,6 @@ import scala.collection.mutable.ArrayBuffer import org.scalatest.BeforeAndAfterEach -import org.apache.commons.io.FileUtils import org.apache.hadoop.fs.Path import org.apache.hadoop.hive.metastore.TableType import org.apache.hadoop.hive.ql.metadata.Table @@ -174,7 +173,7 @@ class MetastoreDataSourcesSuite extends QueryTest with BeforeAndAfterEach { sql("SELECT * FROM jsonTable"), Row("a", "b")) - FileUtils.deleteDirectory(tempDir) + Utils.deleteRecursively(tempDir) sparkContext.parallelize(("a1", "b1", "c1") :: Nil).toDF() .toJSON.saveAsTextFile(tempDir.getCanonicalPath) @@ -190,7 +189,7 @@ class MetastoreDataSourcesSuite extends QueryTest with BeforeAndAfterEach { checkAnswer( sql("SELECT * FROM jsonTable"), Row("a1", "b1", "c1")) - FileUtils.deleteDirectory(tempDir) + Utils.deleteRecursively(tempDir) } test("drop, change, recreate") { @@ -212,7 +211,7 @@ class MetastoreDataSourcesSuite extends QueryTest with BeforeAndAfterEach { sql("SELECT * FROM jsonTable"), Row("a", "b")) - FileUtils.deleteDirectory(tempDir) + Utils.deleteRecursively(tempDir) sparkContext.parallelize(("a", "b", "c") :: Nil).toDF() .toJSON.saveAsTextFile(tempDir.getCanonicalPath) @@ -231,7 +230,7 @@ class MetastoreDataSourcesSuite extends QueryTest with BeforeAndAfterEach { checkAnswer( sql("SELECT * FROM jsonTable"), Row("a", "b", "c")) - FileUtils.deleteDirectory(tempDir) + Utils.deleteRecursively(tempDir) } test("invalidate cache and reload") { From 62888a4ded91b3c2cbb05936c374c7ebfc10799e Mon Sep 17 00:00:00 2001 From: GuoQiang Li Date: Mon, 27 Apr 2015 19:52:41 -0400 Subject: [PATCH 084/110] [SPARK-7162] [YARN] Launcher error in yarn-client jira: https://issues.apache.org/jira/browse/SPARK-7162 Author: GuoQiang Li Closes #5716 from witgo/SPARK-7162 and squashes the following commits: b64564c [GuoQiang Li] Launcher error in yarn-client --- yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index 
019afbd1a1743..741239c953794 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -354,7 +354,7 @@ private[spark] class Client( val dir = new File(path) if (dir.isDirectory()) { dir.listFiles().foreach { file => - if (!hadoopConfFiles.contains(file.getName())) { + if (file.isFile && !hadoopConfFiles.contains(file.getName())) { hadoopConfFiles(file.getName()) = file } } From 4d9e560b5470029143926827b1cb9d72a0bfbeff Mon Sep 17 00:00:00 2001 From: Yuhao Yang Date: Mon, 27 Apr 2015 19:02:51 -0700 Subject: [PATCH 085/110] [SPARK-7090] [MLLIB] Introduce LDAOptimizer to LDA to further improve extensibility jira: https://issues.apache.org/jira/browse/SPARK-7090 LDA was implemented with extensibility in mind. With the development of OnlineLDA and Gibbs sampling, we are collecting more detailed requirements from different algorithms. As Joseph Bradley jkbradley proposed in https://github.com/apache/spark/pull/4807, and after some further discussion, we'd like to adjust the code structure a little to present the common interface and extension point clearly. Basically, class LDA is the common entry point for LDA computation, and each LDA object refers to an LDAOptimizer for the concrete algorithm implementation. Users can customize an LDAOptimizer with specific parameters and assign it to the LDA instance.
Concrete changes:
1. Add a trait `LDAOptimizer`, which defines the common interface for concrete implementations. Each subclass is a wrapper for a specific LDA algorithm.
2. Move EMOptimizer to the file LDAOptimizer, make it inherit from LDAOptimizer, and rename it to EMLDAOptimizer (in case a more generic EMOptimizer comes along in the future).
   - Adjust the constructor of EMOptimizer, since all parameters should be passed in through the initialState method. This avoids unwanted confusion or accidental overwrites.
   - Move the code from LDA.initialState to the initialState method of EMLDAOptimizer.
3. Add the property ldaOptimizer to LDA along with its getter/setter; EMLDAOptimizer is the default optimizer.
4. Change the return type of LDA.run from DistributedLDAModel to LDAModel.
Further work: add OnlineLDAOptimizer and other possible optimizers once they are ready.
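To make the new extension point concrete, here is a usage sketch against this patch's API (illustrative only; it assumes an existing SparkContext and a toy bag-of-words corpus, and the function name is a placeholder):

```scala
import org.apache.spark.SparkContext
import org.apache.spark.mllib.clustering.{DistributedLDAModel, EMLDAOptimizer, LDA}
import org.apache.spark.mllib.linalg.Vectors

def ldaOptimizerSketch(sc: SparkContext): Unit = {
  // (docId, termCount vector) pairs; the vocabulary here has 3 terms.
  val corpus = sc.parallelize(Seq(
    (0L, Vectors.dense(1.0, 2.0, 0.0)),
    (1L, Vectors.dense(0.0, 3.0, 1.0))))

  val lda = new LDA()
    .setK(2)
    .setMaxIterations(10)
    .setOptimizer(new EMLDAOptimizer)  // or .setOptimizer("em"); EM is also the default

  // run() now returns the base LDAModel type; the EM optimizer still produces a
  // DistributedLDAModel, so callers that need the distributed API cast, as the
  // updated examples and tests do.
  val model = lda.run(corpus).asInstanceOf[DistributedLDAModel]
  println(s"Learned ${model.k} topics over a vocabulary of ${model.vocabSize} terms")
}
```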
Author: Yuhao Yang Closes #5661 from hhbyyh/ldaRefactor and squashes the following commits: 0e2e006 [Yuhao Yang] respond to review comments 08a45da [Yuhao Yang] Merge remote-tracking branch 'upstream/master' into ldaRefactor e756ce4 [Yuhao Yang] solve mima exception d74fd8f [Yuhao Yang] Merge remote-tracking branch 'upstream/master' into ldaRefactor 0bb8400 [Yuhao Yang] refactor LDA with Optimizer ec2f857 [Yuhao Yang] protoptype for discussion --- .../spark/examples/mllib/JavaLDAExample.java | 2 +- .../spark/examples/mllib/LDAExample.scala | 4 +- .../apache/spark/mllib/clustering/LDA.scala | 181 +++------------ .../spark/mllib/clustering/LDAModel.scala | 2 +- .../spark/mllib/clustering/LDAOptimizer.scala | 210 ++++++++++++++++++ .../spark/mllib/clustering/JavaLDASuite.java | 2 +- .../spark/mllib/clustering/LDASuite.scala | 2 +- project/MimaExcludes.scala | 4 + 8 files changed, 256 insertions(+), 151 deletions(-) create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaLDAExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaLDAExample.java index 36207ae38d9a9..fd53c81cc4974 100644 --- a/examples/src/main/java/org/apache/spark/examples/mllib/JavaLDAExample.java +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaLDAExample.java @@ -58,7 +58,7 @@ public Tuple2 call(Tuple2 doc_id) { corpus.cache(); // Cluster the documents into three topics using LDA - DistributedLDAModel ldaModel = new LDA().setK(3).run(corpus); + DistributedLDAModel ldaModel = (DistributedLDAModel)new LDA().setK(3).run(corpus); // Output topics. Each is a distribution over words (matching word count vectors) System.out.println("Learned topics (as distributions over vocab of " + ldaModel.vocabSize() diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala index 08a93595a2e17..a1850390c0a86 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala @@ -26,7 +26,7 @@ import scopt.OptionParser import org.apache.log4j.{Level, Logger} import org.apache.spark.{SparkContext, SparkConf} -import org.apache.spark.mllib.clustering.LDA +import org.apache.spark.mllib.clustering.{DistributedLDAModel, LDA} import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.rdd.RDD @@ -137,7 +137,7 @@ object LDAExample { sc.setCheckpointDir(params.checkpointDir.get) } val startTime = System.nanoTime() - val ldaModel = lda.run(corpus) + val ldaModel = lda.run(corpus).asInstanceOf[DistributedLDAModel] val elapsed = (System.nanoTime() - startTime) / 1e9 println(s"Finished training LDA model. 
Summary:") diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala index d006b39acb213..37bf88b73b911 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala @@ -17,16 +17,11 @@ package org.apache.spark.mllib.clustering -import java.util.Random - -import breeze.linalg.{DenseVector => BDV, normalize} - +import breeze.linalg.{DenseVector => BDV} import org.apache.spark.Logging import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.JavaPairRDD import org.apache.spark.graphx._ -import org.apache.spark.graphx.impl.GraphImpl -import org.apache.spark.mllib.impl.PeriodicGraphCheckpointer import org.apache.spark.mllib.linalg.Vector import org.apache.spark.rdd.RDD import org.apache.spark.util.Utils @@ -42,16 +37,9 @@ import org.apache.spark.util.Utils * - "token": instance of a term appearing in a document * - "topic": multinomial distribution over words representing some concept * - * Currently, the underlying implementation uses Expectation-Maximization (EM), implemented - * according to the Asuncion et al. (2009) paper referenced below. - * * References: * - Original LDA paper (journal version): * Blei, Ng, and Jordan. "Latent Dirichlet Allocation." JMLR, 2003. - * - This class implements their "smoothed" LDA model. - * - Paper which clearly explains several algorithms, including EM: - * Asuncion, Welling, Smyth, and Teh. - * "On Smoothing and Inference for Topic Models." UAI, 2009. * * @see [[http://en.wikipedia.org/wiki/Latent_Dirichlet_allocation Latent Dirichlet allocation * (Wikipedia)]] @@ -63,10 +51,11 @@ class LDA private ( private var docConcentration: Double, private var topicConcentration: Double, private var seed: Long, - private var checkpointInterval: Int) extends Logging { + private var checkpointInterval: Int, + private var ldaOptimizer: LDAOptimizer) extends Logging { def this() = this(k = 10, maxIterations = 20, docConcentration = -1, topicConcentration = -1, - seed = Utils.random.nextLong(), checkpointInterval = 10) + seed = Utils.random.nextLong(), checkpointInterval = 10, ldaOptimizer = new EMLDAOptimizer) /** * Number of topics to infer. I.e., the number of soft cluster centers. @@ -220,6 +209,32 @@ class LDA private ( this } + + /** LDAOptimizer used to perform the actual calculation */ + def getOptimizer: LDAOptimizer = ldaOptimizer + + /** + * LDAOptimizer used to perform the actual calculation (default = EMLDAOptimizer) + */ + def setOptimizer(optimizer: LDAOptimizer): this.type = { + this.ldaOptimizer = optimizer + this + } + + /** + * Set the LDAOptimizer used to perform the actual calculation by algorithm name. + * Currently "em" is supported. + */ + def setOptimizer(optimizerName: String): this.type = { + this.ldaOptimizer = + optimizerName.toLowerCase match { + case "em" => new EMLDAOptimizer + case other => + throw new IllegalArgumentException(s"Only em is supported but got $other.") + } + this + } + /** * Learn an LDA model using the given dataset. * @@ -229,9 +244,9 @@ class LDA private ( * Document IDs must be unique and >= 0. 
* @return Inferred LDA model */ - def run(documents: RDD[(Long, Vector)]): DistributedLDAModel = { - val state = LDA.initialState(documents, k, getDocConcentration, getTopicConcentration, seed, - checkpointInterval) + def run(documents: RDD[(Long, Vector)]): LDAModel = { + val state = ldaOptimizer.initialState(documents, k, getDocConcentration, getTopicConcentration, + seed, checkpointInterval) var iter = 0 val iterationTimes = Array.fill[Double](maxIterations)(0) while (iter < maxIterations) { @@ -241,12 +256,11 @@ class LDA private ( iterationTimes(iter) = elapsedSeconds iter += 1 } - state.graphCheckpointer.deleteAllCheckpoints() - new DistributedLDAModel(state, iterationTimes) + state.getLDAModel(iterationTimes) } /** Java-friendly version of [[run()]] */ - def run(documents: JavaPairRDD[java.lang.Long, Vector]): DistributedLDAModel = { + def run(documents: JavaPairRDD[java.lang.Long, Vector]): LDAModel = { run(documents.rdd.asInstanceOf[RDD[(Long, Vector)]]) } } @@ -320,88 +334,10 @@ private[clustering] object LDA { private[clustering] def isTermVertex(v: (VertexId, _)): Boolean = v._1 < 0 - /** - * Optimizer for EM algorithm which stores data + parameter graph, plus algorithm parameters. - * - * @param graph EM graph, storing current parameter estimates in vertex descriptors and - * data (token counts) in edge descriptors. - * @param k Number of topics - * @param vocabSize Number of unique terms - * @param docConcentration "alpha" - * @param topicConcentration "beta" or "eta" - */ - private[clustering] class EMOptimizer( - var graph: Graph[TopicCounts, TokenCount], - val k: Int, - val vocabSize: Int, - val docConcentration: Double, - val topicConcentration: Double, - checkpointInterval: Int) { - - private[LDA] val graphCheckpointer = new PeriodicGraphCheckpointer[TopicCounts, TokenCount]( - graph, checkpointInterval) - - def next(): EMOptimizer = { - val eta = topicConcentration - val W = vocabSize - val alpha = docConcentration - - val N_k = globalTopicTotals - val sendMsg: EdgeContext[TopicCounts, TokenCount, (Boolean, TopicCounts)] => Unit = - (edgeContext) => { - // Compute N_{wj} gamma_{wjk} - val N_wj = edgeContext.attr - // E-STEP: Compute gamma_{wjk} (smoothed topic distributions), scaled by token count - // N_{wj}. - val scaledTopicDistribution: TopicCounts = - computePTopic(edgeContext.srcAttr, edgeContext.dstAttr, N_k, W, eta, alpha) *= N_wj - edgeContext.sendToDst((false, scaledTopicDistribution)) - edgeContext.sendToSrc((false, scaledTopicDistribution)) - } - // This is a hack to detect whether we could modify the values in-place. - // TODO: Add zero/seqOp/combOp option to aggregateMessages. (SPARK-5438) - val mergeMsg: ((Boolean, TopicCounts), (Boolean, TopicCounts)) => (Boolean, TopicCounts) = - (m0, m1) => { - val sum = - if (m0._1) { - m0._2 += m1._2 - } else if (m1._1) { - m1._2 += m0._2 - } else { - m0._2 + m1._2 - } - (true, sum) - } - // M-STEP: Aggregation computes new N_{kj}, N_{wk} counts. - val docTopicDistributions: VertexRDD[TopicCounts] = - graph.aggregateMessages[(Boolean, TopicCounts)](sendMsg, mergeMsg) - .mapValues(_._2) - // Update the vertex descriptors with the new counts. - val newGraph = GraphImpl.fromExistingRDDs(docTopicDistributions, graph.edges) - graph = newGraph - graphCheckpointer.updateGraph(newGraph) - globalTopicTotals = computeGlobalTopicTotals() - this - } - - /** - * Aggregate distributions over topics from all term vertices. - * - * Note: This executes an action on the graph RDDs. 
- */ - var globalTopicTotals: TopicCounts = computeGlobalTopicTotals() - - private def computeGlobalTopicTotals(): TopicCounts = { - val numTopics = k - graph.vertices.filter(isTermVertex).values.fold(BDV.zeros[Double](numTopics))(_ += _) - } - - } - /** * Compute gamma_{wjk}, a distribution over topics k. */ - private def computePTopic( + private[clustering] def computePTopic( docTopicCounts: TopicCounts, termTopicCounts: TopicCounts, totalTopicCounts: TopicCounts, @@ -427,49 +363,4 @@ private[clustering] object LDA { // normalize BDV(gamma_wj) /= sum } - - /** - * Compute bipartite term/doc graph. - */ - private def initialState( - docs: RDD[(Long, Vector)], - k: Int, - docConcentration: Double, - topicConcentration: Double, - randomSeed: Long, - checkpointInterval: Int): EMOptimizer = { - // For each document, create an edge (Document -> Term) for each unique term in the document. - val edges: RDD[Edge[TokenCount]] = docs.flatMap { case (docID: Long, termCounts: Vector) => - // Add edges for terms with non-zero counts. - termCounts.toBreeze.activeIterator.filter(_._2 != 0.0).map { case (term, cnt) => - Edge(docID, term2index(term), cnt) - } - } - - val vocabSize = docs.take(1).head._2.size - - // Create vertices. - // Initially, we use random soft assignments of tokens to topics (random gamma). - def createVertices(): RDD[(VertexId, TopicCounts)] = { - val verticesTMP: RDD[(VertexId, TopicCounts)] = - edges.mapPartitionsWithIndex { case (partIndex, partEdges) => - val random = new Random(partIndex + randomSeed) - partEdges.flatMap { edge => - val gamma = normalize(BDV.fill[Double](k)(random.nextDouble()), 1.0) - val sum = gamma * edge.attr - Seq((edge.srcId, sum), (edge.dstId, sum)) - } - } - verticesTMP.reduceByKey(_ + _) - } - - val docTermVertices = createVertices() - - // Partition such that edges are grouped by document - val graph = Graph(docTermVertices, edges) - .partitionBy(PartitionStrategy.EdgePartition1D) - - new EMOptimizer(graph, k, vocabSize, docConcentration, topicConcentration, checkpointInterval) - } - } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala index 0a3f21ecee0dc..6cf26445f20a0 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala @@ -203,7 +203,7 @@ class DistributedLDAModel private ( import LDA._ - private[clustering] def this(state: LDA.EMOptimizer, iterationTimes: Array[Double]) = { + private[clustering] def this(state: EMLDAOptimizer, iterationTimes: Array[Double]) = { this(state.graph, state.globalTopicTotals, state.k, state.vocabSize, state.docConcentration, state.topicConcentration, iterationTimes) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala new file mode 100644 index 0000000000000..ffd72a294c6c6 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala @@ -0,0 +1,210 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.clustering + +import java.util.Random + +import breeze.linalg.{DenseVector => BDV, normalize} + +import org.apache.spark.annotation.Experimental +import org.apache.spark.graphx._ +import org.apache.spark.graphx.impl.GraphImpl +import org.apache.spark.mllib.impl.PeriodicGraphCheckpointer +import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.rdd.RDD + +/** + * :: Experimental :: + * + * An LDAOptimizer specifies which optimization/learning/inference algorithm to use, and it can + * hold optimizer-specific parameters for users to set. + */ +@Experimental +trait LDAOptimizer{ + + /* + DEVELOPERS NOTE: + + An LDAOptimizer contains an algorithm for LDA and performs the actual computation, which + stores internal data structure (Graph or Matrix) and other parameters for the algorithm. + The interface is isolated to improve the extensibility of LDA. + */ + + /** + * Initializer for the optimizer. LDA passes the common parameters to the optimizer and + * the internal structure can be initialized properly. + */ + private[clustering] def initialState( + docs: RDD[(Long, Vector)], + k: Int, + docConcentration: Double, + topicConcentration: Double, + randomSeed: Long, + checkpointInterval: Int): LDAOptimizer + + private[clustering] def next(): LDAOptimizer + + private[clustering] def getLDAModel(iterationTimes: Array[Double]): LDAModel +} + +/** + * :: Experimental :: + * + * Optimizer for EM algorithm which stores data + parameter graph, plus algorithm parameters. + * + * Currently, the underlying implementation uses Expectation-Maximization (EM), implemented + * according to the Asuncion et al. (2009) paper referenced below. + * + * References: + * - Original LDA paper (journal version): + * Blei, Ng, and Jordan. "Latent Dirichlet Allocation." JMLR, 2003. + * - This class implements their "smoothed" LDA model. + * - Paper which clearly explains several algorithms, including EM: + * Asuncion, Welling, Smyth, and Teh. + * "On Smoothing and Inference for Topic Models." UAI, 2009. + * + */ +@Experimental +class EMLDAOptimizer extends LDAOptimizer{ + + import LDA._ + + /** + * Following fields will only be initialized through initialState method + */ + private[clustering] var graph: Graph[TopicCounts, TokenCount] = null + private[clustering] var k: Int = 0 + private[clustering] var vocabSize: Int = 0 + private[clustering] var docConcentration: Double = 0 + private[clustering] var topicConcentration: Double = 0 + private[clustering] var checkpointInterval: Int = 10 + private var graphCheckpointer: PeriodicGraphCheckpointer[TopicCounts, TokenCount] = null + + /** + * Compute bipartite term/doc graph. + */ + private[clustering] override def initialState( + docs: RDD[(Long, Vector)], + k: Int, + docConcentration: Double, + topicConcentration: Double, + randomSeed: Long, + checkpointInterval: Int): LDAOptimizer = { + // For each document, create an edge (Document -> Term) for each unique term in the document. 
+ val edges: RDD[Edge[TokenCount]] = docs.flatMap { case (docID: Long, termCounts: Vector) => + // Add edges for terms with non-zero counts. + termCounts.toBreeze.activeIterator.filter(_._2 != 0.0).map { case (term, cnt) => + Edge(docID, term2index(term), cnt) + } + } + + val vocabSize = docs.take(1).head._2.size + + // Create vertices. + // Initially, we use random soft assignments of tokens to topics (random gamma). + def createVertices(): RDD[(VertexId, TopicCounts)] = { + val verticesTMP: RDD[(VertexId, TopicCounts)] = + edges.mapPartitionsWithIndex { case (partIndex, partEdges) => + val random = new Random(partIndex + randomSeed) + partEdges.flatMap { edge => + val gamma = normalize(BDV.fill[Double](k)(random.nextDouble()), 1.0) + val sum = gamma * edge.attr + Seq((edge.srcId, sum), (edge.dstId, sum)) + } + } + verticesTMP.reduceByKey(_ + _) + } + + val docTermVertices = createVertices() + + // Partition such that edges are grouped by document + this.graph = Graph(docTermVertices, edges).partitionBy(PartitionStrategy.EdgePartition1D) + this.k = k + this.vocabSize = vocabSize + this.docConcentration = docConcentration + this.topicConcentration = topicConcentration + this.checkpointInterval = checkpointInterval + this.graphCheckpointer = new + PeriodicGraphCheckpointer[TopicCounts, TokenCount](graph, checkpointInterval) + this.globalTopicTotals = computeGlobalTopicTotals() + this + } + + private[clustering] override def next(): EMLDAOptimizer = { + require(graph != null, "graph is null, EMLDAOptimizer not initialized.") + + val eta = topicConcentration + val W = vocabSize + val alpha = docConcentration + + val N_k = globalTopicTotals + val sendMsg: EdgeContext[TopicCounts, TokenCount, (Boolean, TopicCounts)] => Unit = + (edgeContext) => { + // Compute N_{wj} gamma_{wjk} + val N_wj = edgeContext.attr + // E-STEP: Compute gamma_{wjk} (smoothed topic distributions), scaled by token count + // N_{wj}. + val scaledTopicDistribution: TopicCounts = + computePTopic(edgeContext.srcAttr, edgeContext.dstAttr, N_k, W, eta, alpha) *= N_wj + edgeContext.sendToDst((false, scaledTopicDistribution)) + edgeContext.sendToSrc((false, scaledTopicDistribution)) + } + // This is a hack to detect whether we could modify the values in-place. + // TODO: Add zero/seqOp/combOp option to aggregateMessages. (SPARK-5438) + val mergeMsg: ((Boolean, TopicCounts), (Boolean, TopicCounts)) => (Boolean, TopicCounts) = + (m0, m1) => { + val sum = + if (m0._1) { + m0._2 += m1._2 + } else if (m1._1) { + m1._2 += m0._2 + } else { + m0._2 + m1._2 + } + (true, sum) + } + // M-STEP: Aggregation computes new N_{kj}, N_{wk} counts. + val docTopicDistributions: VertexRDD[TopicCounts] = + graph.aggregateMessages[(Boolean, TopicCounts)](sendMsg, mergeMsg) + .mapValues(_._2) + // Update the vertex descriptors with the new counts. + val newGraph = GraphImpl.fromExistingRDDs(docTopicDistributions, graph.edges) + graph = newGraph + graphCheckpointer.updateGraph(newGraph) + globalTopicTotals = computeGlobalTopicTotals() + this + } + + /** + * Aggregate distributions over topics from all term vertices. + * + * Note: This executes an action on the graph RDDs. 
+ */ + private[clustering] var globalTopicTotals: TopicCounts = null + + private def computeGlobalTopicTotals(): TopicCounts = { + val numTopics = k + graph.vertices.filter(isTermVertex).values.fold(BDV.zeros[Double](numTopics))(_ += _) + } + + private[clustering] override def getLDAModel(iterationTimes: Array[Double]): LDAModel = { + require(graph != null, "graph is null, EMLDAOptimizer not initialized.") + this.graphCheckpointer.deleteAllCheckpoints() + new DistributedLDAModel(this, iterationTimes) + } +} diff --git a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java index dc10aa67c7c1f..fbe171b4b1ab1 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java @@ -88,7 +88,7 @@ public void distributedLDAModel() { .setMaxIterations(5) .setSeed(12345); - DistributedLDAModel model = lda.run(corpus); + DistributedLDAModel model = (DistributedLDAModel)lda.run(corpus); // Check: basic parameters LocalLDAModel localModel = model.toLocal(); diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala index cc747dabb9968..41ec794146c69 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala @@ -68,7 +68,7 @@ class LDASuite extends FunSuite with MLlibTestSparkContext { .setSeed(12345) val corpus = sc.parallelize(tinyCorpus, 2) - val model: DistributedLDAModel = lda.run(corpus) + val model: DistributedLDAModel = lda.run(corpus).asInstanceOf[DistributedLDAModel] // Check: basic parameters val localModel = model.toLocal diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 7ef363a2f07ad..967961c2bf5c3 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -72,6 +72,10 @@ object MimaExcludes { // SPARK-6703 Add getOrCreate method to SparkContext ProblemFilters.exclude[IncompatibleResultTypeProblem] ("org.apache.spark.SparkContext.org$apache$spark$SparkContext$$activeContext") + )++ Seq( + // SPARK-7090 Introduce LDAOptimizer to LDA to further improve extensibility + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.mllib.clustering.LDA$EMOptimizer") ) case v if v.startsWith("1.3") => From 874a2ca93d095a0dfa1acfdacf0e9d80388c4422 Mon Sep 17 00:00:00 2001 From: zsxwing Date: Mon, 27 Apr 2015 21:45:40 -0700 Subject: [PATCH 086/110] [SPARK-7174][Core] Move calling `TaskScheduler.executorHeartbeatReceived` to another thread `HeartbeatReceiver` calls `TaskScheduler.executorHeartbeatReceived`, which is a blocking operation because it eventually calls ```Scala blockManagerMaster.driverEndpoint.askWithReply[Boolean]( BlockManagerHeartbeat(blockManagerId), 600 seconds) ``` Even though the ask targets a local actor, it can block the current Akka thread: the reply may be dispatched to the very thread that issued the ask, so the reply can never be processed. An extreme case is setting the Akka dispatcher thread pool size to 1.
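The jstack dump below shows a dispatcher thread parked inside that blocking ask. The shape of the fix can be sketched as follows (illustrative only; a plain single-thread executor stands in for the `heartbeat-receiver-event-loop-thread` this patch creates via ThreadUtils, and the helper names are placeholders):

```scala
import java.util.concurrent.Executors

object HeartbeatHandlingSketch {
  // Single helper thread for potentially slow work, so RPC dispatcher threads return quickly.
  private val eventLoopThread = Executors.newSingleThreadScheduledExecutor()

  // `checkExecutor` stands in for TaskScheduler.executorHeartbeatReceived (which may block on
  // an internal ask); `reply` stands in for RpcCallContext.reply.
  def onHeartbeat(executorId: String,
                  checkExecutor: String => Boolean,
                  reply: Boolean => Unit): Unit = {
    // Cheap bookkeeping (e.g. recording the last-seen timestamp) stays on the caller's thread;
    // the blocking scheduler call is handed off so it cannot stall the dispatcher.
    eventLoopThread.submit(new Runnable {
      override def run(): Unit = {
        val unknownExecutor = !checkExecutor(executorId) // may block on a nested RPC ask
        reply(unknownExecutor)                           // true => ask executor to re-register
      }
    })
    // The dispatcher thread returns immediately and stays free to deliver the nested ask's reply.
  }
}
```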
jstack log: ``` "sparkDriver-akka.actor.default-dispatcher-14" daemon prio=10 tid=0x00007f2a8c02d000 nid=0x725 waiting on condition [0x00007f2b1d6d0000] java.lang.Thread.State: TIMED_WAITING (parking) at sun.misc.Unsafe.park(Native Method) - parking to wait for <0x00000006197a0868> (a scala.concurrent.impl.Promise$CompletionLatch) at java.util.concurrent.locks.LockSupport.parkNanos(LockSupport.java:226) at java.util.concurrent.locks.AbstractQueuedSynchronizer.doAcquireSharedNanos(AbstractQueuedSynchronizer.java:1033) at java.util.concurrent.locks.AbstractQueuedSynchronizer.tryAcquireSharedNanos(AbstractQueuedSynchronizer.java:1326) at scala.concurrent.impl.Promise$DefaultPromise.tryAwait(Promise.scala:208) at scala.concurrent.impl.Promise$DefaultPromise.ready(Promise.scala:218) at scala.concurrent.impl.Promise$DefaultPromise.result(Promise.scala:223) at scala.concurrent.Await$$anonfun$result$1.apply(package.scala:107) at akka.dispatch.MonitorableThreadFactory$AkkaForkJoinWorkerThread$$anon$3.block(ThreadPoolBuilder.scala:169) at scala.concurrent.forkjoin.ForkJoinPool.managedBlock(ForkJoinPool.java:3640) at akka.dispatch.MonitorableThreadFactory$AkkaForkJoinWorkerThread.blockOn(ThreadPoolBuilder.scala:167) at scala.concurrent.Await$.result(package.scala:107) at org.apache.spark.rpc.RpcEndpointRef.askWithReply(RpcEnv.scala:355) at org.apache.spark.scheduler.DAGScheduler.executorHeartbeatReceived(DAGScheduler.scala:169) at org.apache.spark.scheduler.TaskSchedulerImpl.executorHeartbeatReceived(TaskSchedulerImpl.scala:367) at org.apache.spark.HeartbeatReceiver$$anonfun$receiveAndReply$1.applyOrElse(HeartbeatReceiver.scala:103) at org.apache.spark.rpc.akka.AkkaRpcEnv.org$apache$spark$rpc$akka$AkkaRpcEnv$$processMessage(AkkaRpcEnv.scala:182) at org.apache.spark.rpc.akka.AkkaRpcEnv$$anonfun$actorRef$lzycompute$1$1$$anon$1$$anonfun$receiveWithLogging$1$$anonfun$applyOrElse$4.apply$mcV$sp(AkkaRpcEnv.scala:128) at org.apache.spark.rpc.akka.AkkaRpcEnv.org$apache$spark$rpc$akka$AkkaRpcEnv$$safelyCall(AkkaRpcEnv.scala:203) at org.apache.spark.rpc.akka.AkkaRpcEnv$$anonfun$actorRef$lzycompute$1$1$$anon$1$$anonfun$receiveWithLogging$1.applyOrElse(AkkaRpcEnv.scala:127) at scala.runtime.AbstractPartialFunction$mcVL$sp.apply$mcVL$sp(AbstractPartialFunction.scala:33) at scala.runtime.AbstractPartialFunction$mcVL$sp.apply(AbstractPartialFunction.scala:33) at scala.runtime.AbstractPartialFunction$mcVL$sp.apply(AbstractPartialFunction.scala:25) at org.apache.spark.util.ActorLogReceive$$anon$1.apply(ActorLogReceive.scala:59) at org.apache.spark.util.ActorLogReceive$$anon$1.apply(ActorLogReceive.scala:42) at scala.PartialFunction$class.applyOrElse(PartialFunction.scala:118) at org.apache.spark.util.ActorLogReceive$$anon$1.applyOrElse(ActorLogReceive.scala:42) at akka.actor.Actor$class.aroundReceive(Actor.scala:465) at org.apache.spark.rpc.akka.AkkaRpcEnv$$anonfun$actorRef$lzycompute$1$1$$anon$1.aroundReceive(AkkaRpcEnv.scala:94) at akka.actor.ActorCell.receiveMessage(ActorCell.scala:516) at akka.actor.ActorCell.invoke(ActorCell.scala:487) at akka.dispatch.Mailbox.processMailbox(Mailbox.scala:238) at akka.dispatch.Mailbox.run(Mailbox.scala:220) at akka.dispatch.ForkJoinExecutorConfigurator$AkkaForkJoinTask.exec(AbstractDispatcher.scala:393) at scala.concurrent.forkjoin.ForkJoinTask.doExec(ForkJoinTask.java:260) at scala.concurrent.forkjoin.ForkJoinPool$WorkQueue.runTask(ForkJoinPool.java:1339) at scala.concurrent.forkjoin.ForkJoinPool.runWorker(ForkJoinPool.java:1979) at 
scala.concurrent.forkjoin.ForkJoinWorkerThread.run(ForkJoinWorkerThread.java:107) ``` This PR moved this blocking operation to a separated thread. Author: zsxwing Closes #5723 from zsxwing/SPARK-7174 and squashes the following commits: 98bfe48 [zsxwing] Use a single thread for checking timeout and reporting executorHeartbeatReceived 5b3b545 [zsxwing] Move calling `TaskScheduler.executorHeartbeatReceived` to another thread to avoid blocking the Akka thread pool --- .../org/apache/spark/HeartbeatReceiver.scala | 26 ++++++++++++------- 1 file changed, 17 insertions(+), 9 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala b/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala index 68d05d5b02537..f2b024ff6cb67 100644 --- a/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala +++ b/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala @@ -76,13 +76,15 @@ private[spark] class HeartbeatReceiver(sc: SparkContext) private var timeoutCheckingTask: ScheduledFuture[_] = null - private val timeoutCheckingThread = - ThreadUtils.newDaemonSingleThreadScheduledExecutor("heartbeat-timeout-checking-thread") + // "eventLoopThread" is used to run some pretty fast actions. The actions running in it should not + // block the thread for a long time. + private val eventLoopThread = + ThreadUtils.newDaemonSingleThreadScheduledExecutor("heartbeat-receiver-event-loop-thread") private val killExecutorThread = ThreadUtils.newDaemonSingleThreadExecutor("kill-executor-thread") override def onStart(): Unit = { - timeoutCheckingTask = timeoutCheckingThread.scheduleAtFixedRate(new Runnable { + timeoutCheckingTask = eventLoopThread.scheduleAtFixedRate(new Runnable { override def run(): Unit = Utils.tryLogNonFatalError { Option(self).foreach(_.send(ExpireDeadHosts)) } @@ -99,11 +101,15 @@ private[spark] class HeartbeatReceiver(sc: SparkContext) override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { case heartbeat @ Heartbeat(executorId, taskMetrics, blockManagerId) => if (scheduler != null) { - val unknownExecutor = !scheduler.executorHeartbeatReceived( - executorId, taskMetrics, blockManagerId) - val response = HeartbeatResponse(reregisterBlockManager = unknownExecutor) executorLastSeen(executorId) = System.currentTimeMillis() - context.reply(response) + eventLoopThread.submit(new Runnable { + override def run(): Unit = Utils.tryLogNonFatalError { + val unknownExecutor = !scheduler.executorHeartbeatReceived( + executorId, taskMetrics, blockManagerId) + val response = HeartbeatResponse(reregisterBlockManager = unknownExecutor) + context.reply(response) + } + }) } else { // Because Executor will sleep several seconds before sending the first "Heartbeat", this // case rarely happens. 
However, if it really happens, log it and ask the executor to @@ -125,7 +131,9 @@ private[spark] class HeartbeatReceiver(sc: SparkContext) if (sc.supportDynamicAllocation) { // Asynchronously kill the executor to avoid blocking the current thread killExecutorThread.submit(new Runnable { - override def run(): Unit = sc.killExecutor(executorId) + override def run(): Unit = Utils.tryLogNonFatalError { + sc.killExecutor(executorId) + } }) } executorLastSeen.remove(executorId) @@ -137,7 +145,7 @@ private[spark] class HeartbeatReceiver(sc: SparkContext) if (timeoutCheckingTask != null) { timeoutCheckingTask.cancel(true) } - timeoutCheckingThread.shutdownNow() + eventLoopThread.shutdownNow() killExecutorThread.shutdownNow() } } From 29576e786072bd4218e10036ddfc8d367b1c1446 Mon Sep 17 00:00:00 2001 From: Burak Yavuz Date: Mon, 27 Apr 2015 23:10:14 -0700 Subject: [PATCH 087/110] [SPARK-6829] Added math functions for DataFrames Implemented almost all math functions found in scala.math (max, min and abs were already present). cc mengxr marmbrus Author: Burak Yavuz Closes #5616 from brkyvz/math-udfs and squashes the following commits: fb27153 [Burak Yavuz] reverted exception message 836a098 [Burak Yavuz] fixed test and addressed small comment e5f0d13 [Burak Yavuz] addressed code review v2.2 b26c5fb [Burak Yavuz] addressed review v2.1 2761f08 [Burak Yavuz] addressed review v2 6588a5b [Burak Yavuz] fixed merge conflicts b084e10 [Burak Yavuz] Addressed code review 029e739 [Burak Yavuz] fixed atan2 test 534cc11 [Burak Yavuz] added more tests, addressed comments fa68dbe [Burak Yavuz] added double specific test data 937d5a5 [Burak Yavuz] use doubles instead of ints 8e28fff [Burak Yavuz] Added apache header 7ec8f7f [Burak Yavuz] Added math functions for DataFrames --- .../catalyst/analysis/HiveTypeCoercion.scala | 19 + .../sql/catalyst/expressions/Expression.scala | 10 + .../expressions/mathfuncs/binary.scala | 93 +++ .../expressions/mathfuncs/unary.scala | 168 ++++++ .../ExpressionEvaluationSuite.scala | 165 +++++ .../org/apache/spark/sql/mathfunctions.scala | 562 ++++++++++++++++++ .../apache/spark/sql/JavaDataFrameSuite.java | 9 + .../spark/sql/ColumnExpressionSuite.scala | 1 - .../spark/sql/MathExpressionsSuite.scala | 233 ++++++++ 9 files changed, 1259 insertions(+), 1 deletion(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/binary.scala create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/unary.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/mathfunctions.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index 35c7f00d4e42a..73c9a1c7afdad 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -79,6 +79,7 @@ trait HiveTypeCoercion { CaseWhenCoercion :: Division :: PropagateTypes :: + ExpectedInputConversion :: Nil /** @@ -643,4 +644,22 @@ trait HiveTypeCoercion { } } + /** + * Casts types according to the expected input types for Expressions that have the trait + * `ExpectsInputTypes`. 
+ */ + object ExpectedInputConversion extends Rule[LogicalPlan] { + + def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { + // Skip nodes who's children have not been resolved yet. + case e if !e.childrenResolved => e + + case e: ExpectsInputTypes if e.children.map(_.dataType) != e.expectedChildTypes => + val newC = (e.children, e.children.map(_.dataType), e.expectedChildTypes).zipped.map { + case (child, actual, expected) => + if (actual == expected) child else Cast(child, expected) + } + e.withNewChildren(newC) + } + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 4e3bbc06a5b4c..1d71c1b4b0c7c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -109,3 +109,13 @@ case class GroupExpression(children: Seq[Expression]) extends Expression { override def foldable: Boolean = false override def dataType: DataType = throw new UnsupportedOperationException } + +/** + * Expressions that require a specific `DataType` as input should implement this trait + * so that the proper type conversions can be performed in the analyzer. + */ +trait ExpectsInputTypes { + + def expectedChildTypes: Seq[DataType] + +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/binary.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/binary.scala new file mode 100644 index 0000000000000..5b4d912a64f71 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/binary.scala @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions.mathfuncs + +import org.apache.spark.sql.catalyst.analysis.UnresolvedException +import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, BinaryExpression, Expression, Row} +import org.apache.spark.sql.types._ + +/** + * A binary expression specifically for math functions that take two `Double`s as input and returns + * a `Double`. + * @param f The math function. 
+ * @param name The short name of the function + */ +abstract class BinaryMathExpression(f: (Double, Double) => Double, name: String) + extends BinaryExpression with Serializable with ExpectsInputTypes { self: Product => + type EvaluatedType = Any + override def symbol: String = null + override def expectedChildTypes: Seq[DataType] = Seq(DoubleType, DoubleType) + + override def nullable: Boolean = left.nullable || right.nullable + override def toString: String = s"$name($left, $right)" + + override lazy val resolved = + left.resolved && right.resolved && + left.dataType == right.dataType && + !DecimalType.isFixed(left.dataType) + + override def dataType: DataType = { + if (!resolved) { + throw new UnresolvedException(this, + s"datatype. Can not resolve due to differing types ${left.dataType}, ${right.dataType}") + } + left.dataType + } + + override def eval(input: Row): Any = { + val evalE1 = left.eval(input) + if (evalE1 == null) { + null + } else { + val evalE2 = right.eval(input) + if (evalE2 == null) { + null + } else { + val result = f(evalE1.asInstanceOf[Double], evalE2.asInstanceOf[Double]) + if (result.isNaN) null else result + } + } + } +} + +case class Pow(left: Expression, right: Expression) extends BinaryMathExpression(math.pow, "POWER") + +case class Hypot( + left: Expression, + right: Expression) extends BinaryMathExpression(math.hypot, "HYPOT") + +case class Atan2( + left: Expression, + right: Expression) extends BinaryMathExpression(math.atan2, "ATAN2") { + override def eval(input: Row): Any = { + val evalE1 = left.eval(input) + if (evalE1 == null) { + null + } else { + val evalE2 = right.eval(input) + if (evalE2 == null) { + null + } else { + // With codegen, the values returned by -0.0 and 0.0 are different. Handled with +0.0 + val result = math.atan2(evalE1.asInstanceOf[Double] + 0.0, + evalE2.asInstanceOf[Double] + 0.0) + if (result.isNaN) null else result + } + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/unary.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/unary.scala new file mode 100644 index 0000000000000..96cb77d487529 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/unary.scala @@ -0,0 +1,168 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions.mathfuncs + +import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, Row, UnaryExpression} +import org.apache.spark.sql.types._ + +/** + * A unary expression specifically for math functions. Math Functions expect a specific type of + * input format, therefore these functions extend `ExpectsInputTypes`. 
+ * @param name The short name of the function + */ +abstract class MathematicalExpression(name: String) + extends UnaryExpression with Serializable with ExpectsInputTypes { + self: Product => + type EvaluatedType = Any + + override def dataType: DataType = DoubleType + override def foldable: Boolean = child.foldable + override def nullable: Boolean = true + override def toString: String = s"$name($child)" +} + +/** + * A unary expression specifically for math functions that take a `Double` as input and return + * a `Double`. + * @param f The math function. + * @param name The short name of the function + */ +abstract class MathematicalExpressionForDouble(f: Double => Double, name: String) + extends MathematicalExpression(name) { self: Product => + + override def expectedChildTypes: Seq[DataType] = Seq(DoubleType) + + override def eval(input: Row): Any = { + val evalE = child.eval(input) + if (evalE == null) { + null + } else { + val result = f(evalE.asInstanceOf[Double]) + if (result.isNaN) null else result + } + } +} + +/** + * A unary expression specifically for math functions that take an `Int` as input and return + * an `Int`. + * @param f The math function. + * @param name The short name of the function + */ +abstract class MathematicalExpressionForInt(f: Int => Int, name: String) + extends MathematicalExpression(name) { self: Product => + + override def dataType: DataType = IntegerType + override def expectedChildTypes: Seq[DataType] = Seq(IntegerType) + + override def eval(input: Row): Any = { + val evalE = child.eval(input) + if (evalE == null) null else f(evalE.asInstanceOf[Int]) + } +} + +/** + * A unary expression specifically for math functions that take a `Float` as input and return + * a `Float`. + * @param f The math function. + * @param name The short name of the function + */ +abstract class MathematicalExpressionForFloat(f: Float => Float, name: String) + extends MathematicalExpression(name) { self: Product => + + override def dataType: DataType = FloatType + override def expectedChildTypes: Seq[DataType] = Seq(FloatType) + + override def eval(input: Row): Any = { + val evalE = child.eval(input) + if (evalE == null) { + null + } else { + val result = f(evalE.asInstanceOf[Float]) + if (result.isNaN) null else result + } + } +} + +/** + * A unary expression specifically for math functions that take a `Long` as input and return + * a `Long`. + * @param f The math function. 
+ * @param name The short name of the function + */ +abstract class MathematicalExpressionForLong(f: Long => Long, name: String) + extends MathematicalExpression(name) { self: Product => + + override def dataType: DataType = LongType + override def expectedChildTypes: Seq[DataType] = Seq(LongType) + + override def eval(input: Row): Any = { + val evalE = child.eval(input) + if (evalE == null) null else f(evalE.asInstanceOf[Long]) + } +} + +case class Sin(child: Expression) extends MathematicalExpressionForDouble(math.sin, "SIN") + +case class Asin(child: Expression) extends MathematicalExpressionForDouble(math.asin, "ASIN") + +case class Sinh(child: Expression) extends MathematicalExpressionForDouble(math.sinh, "SINH") + +case class Cos(child: Expression) extends MathematicalExpressionForDouble(math.cos, "COS") + +case class Acos(child: Expression) extends MathematicalExpressionForDouble(math.acos, "ACOS") + +case class Cosh(child: Expression) extends MathematicalExpressionForDouble(math.cosh, "COSH") + +case class Tan(child: Expression) extends MathematicalExpressionForDouble(math.tan, "TAN") + +case class Atan(child: Expression) extends MathematicalExpressionForDouble(math.atan, "ATAN") + +case class Tanh(child: Expression) extends MathematicalExpressionForDouble(math.tanh, "TANH") + +case class Ceil(child: Expression) extends MathematicalExpressionForDouble(math.ceil, "CEIL") + +case class Floor(child: Expression) extends MathematicalExpressionForDouble(math.floor, "FLOOR") + +case class Rint(child: Expression) extends MathematicalExpressionForDouble(math.rint, "ROUND") + +case class Cbrt(child: Expression) extends MathematicalExpressionForDouble(math.cbrt, "CBRT") + +case class Signum(child: Expression) extends MathematicalExpressionForDouble(math.signum, "SIGNUM") + +case class ISignum(child: Expression) extends MathematicalExpressionForInt(math.signum, "ISIGNUM") + +case class FSignum(child: Expression) extends MathematicalExpressionForFloat(math.signum, "FSIGNUM") + +case class LSignum(child: Expression) extends MathematicalExpressionForLong(math.signum, "LSIGNUM") + +case class ToDegrees(child: Expression) + extends MathematicalExpressionForDouble(math.toDegrees, "DEGREES") + +case class ToRadians(child: Expression) + extends MathematicalExpressionForDouble(math.toRadians, "RADIANS") + +case class Log(child: Expression) extends MathematicalExpressionForDouble(math.log, "LOG") + +case class Log10(child: Expression) extends MathematicalExpressionForDouble(math.log10, "LOG10") + +case class Log1p(child: Expression) extends MathematicalExpressionForDouble(math.log1p, "LOG1P") + +case class Exp(child: Expression) extends MathematicalExpressionForDouble(math.exp, "EXP") + +case class Expm1(child: Expression) extends MathematicalExpressionForDouble(math.expm1, "EXPM1") diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala index 76298f03c94ae..5390ce43c6639 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala @@ -28,6 +28,7 @@ import org.scalatest.Matchers._ import org.apache.spark.sql.catalyst.CatalystTypeConverters import org.apache.spark.sql.catalyst.analysis.UnresolvedGetField import org.apache.spark.sql.catalyst.dsl.expressions._ +import 
org.apache.spark.sql.catalyst.expressions.mathfuncs._ import org.apache.spark.sql.types._ @@ -1152,6 +1153,170 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite { checkEvaluation(c1 ^ c2, 3, row) checkEvaluation(~c1, -2, row) } + + /** + * Used for testing math functions for DataFrames. + * @param c The DataFrame function + * @param f The functions in scala.math + * @param domain The set of values to run the function with + * @param expectNull Whether the given values should return null or not + * @tparam T Generic type for primitives + */ + def unaryMathFunctionEvaluation[@specialized(Int, Double, Float, Long) T]( + c: Expression => Expression, + f: T => T, + domain: Iterable[T] = (-20 to 20).map(_ * 0.1), + expectNull: Boolean = false): Unit = { + if (expectNull) { + domain.foreach { value => + checkEvaluation(c(Literal(value)), null, EmptyRow) + } + } else { + domain.foreach { value => + checkEvaluation(c(Literal(value)), f(value), EmptyRow) + } + } + checkEvaluation(c(Literal.create(null, DoubleType)), null, create_row(null)) + } + + test("sin") { + unaryMathFunctionEvaluation(Sin, math.sin) + } + + test("asin") { + unaryMathFunctionEvaluation(Asin, math.asin, (-10 to 10).map(_ * 0.1)) + unaryMathFunctionEvaluation(Asin, math.asin, (11 to 20).map(_ * 0.1), true) + } + + test("sinh") { + unaryMathFunctionEvaluation(Sinh, math.sinh) + } + + test("cos") { + unaryMathFunctionEvaluation(Cos, math.cos) + } + + test("acos") { + unaryMathFunctionEvaluation(Acos, math.acos, (-10 to 10).map(_ * 0.1)) + unaryMathFunctionEvaluation(Acos, math.acos, (11 to 20).map(_ * 0.1), true) + } + + test("cosh") { + unaryMathFunctionEvaluation(Cosh, math.cosh) + } + + test("tan") { + unaryMathFunctionEvaluation(Tan, math.tan) + } + + test("atan") { + unaryMathFunctionEvaluation(Atan, math.atan) + } + + test("tanh") { + unaryMathFunctionEvaluation(Tanh, math.tanh) + } + + test("toDeg") { + unaryMathFunctionEvaluation(ToDegrees, math.toDegrees) + } + + test("toRad") { + unaryMathFunctionEvaluation(ToRadians, math.toRadians) + } + + test("cbrt") { + unaryMathFunctionEvaluation(Cbrt, math.cbrt) + } + + test("ceil") { + unaryMathFunctionEvaluation(Ceil, math.ceil) + } + + test("floor") { + unaryMathFunctionEvaluation(Floor, math.floor) + } + + test("rint") { + unaryMathFunctionEvaluation(Rint, math.rint) + } + + test("exp") { + unaryMathFunctionEvaluation(Exp, math.exp) + } + + test("expm1") { + unaryMathFunctionEvaluation(Expm1, math.expm1) + } + + test("signum") { + unaryMathFunctionEvaluation[Double](Signum, math.signum) + } + + test("isignum") { + unaryMathFunctionEvaluation[Int](ISignum, math.signum, (-5 to 5)) + } + + test("fsignum") { + unaryMathFunctionEvaluation[Float](FSignum, math.signum, (-5 to 5).map(_.toFloat)) + } + + test("lsignum") { + unaryMathFunctionEvaluation[Long](LSignum, math.signum, (5 to 5).map(_.toLong)) + } + + test("log") { + unaryMathFunctionEvaluation(Log, math.log, (0 to 20).map(_ * 0.1)) + unaryMathFunctionEvaluation(Log, math.log, (-5 to -1).map(_ * 0.1), true) + } + + test("log10") { + unaryMathFunctionEvaluation(Log10, math.log10, (0 to 20).map(_ * 0.1)) + unaryMathFunctionEvaluation(Log10, math.log10, (-5 to -1).map(_ * 0.1), true) + } + + test("log1p") { + unaryMathFunctionEvaluation(Log1p, math.log1p, (-1 to 20).map(_ * 0.1)) + unaryMathFunctionEvaluation(Log1p, math.log1p, (-10 to -2).map(_ * 1.0), true) + } + + /** + * Used for testing math functions for DataFrames. 
+ * @param c The DataFrame function + * @param f The functions in scala.math + * @param domain The set of values to run the function with + */ + def binaryMathFunctionEvaluation( + c: (Expression, Expression) => Expression, + f: (Double, Double) => Double, + domain: Iterable[(Double, Double)] = (-20 to 20).map(v => (v * 0.1, v * -0.1)), + expectNull: Boolean = false): Unit = { + if (expectNull) { + domain.foreach { case (v1, v2) => + checkEvaluation(c(v1, v2), null, create_row(null)) + } + } else { + domain.foreach { case (v1, v2) => + checkEvaluation(c(v1, v2), f(v1 + 0.0, v2 + 0.0), EmptyRow) + checkEvaluation(c(v2, v1), f(v2 + 0.0, v1 + 0.0), EmptyRow) + } + } + checkEvaluation(c(Literal.create(null, DoubleType), 1.0), null, create_row(null)) + checkEvaluation(c(1.0, Literal.create(null, DoubleType)), null, create_row(null)) + } + + test("pow") { + binaryMathFunctionEvaluation(Pow, math.pow, (-5 to 5).map(v => (v * 1.0, v * 1.0))) + binaryMathFunctionEvaluation(Pow, math.pow, Seq((-1.0, 0.9), (-2.2, 1.7), (-2.2, -1.7)), true) + } + + test("hypot") { + binaryMathFunctionEvaluation(Hypot, math.hypot) + } + + test("atan2") { + binaryMathFunctionEvaluation(Atan2, math.atan2) + } } // TODO: Make the tests work with codegen. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/mathfunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/mathfunctions.scala new file mode 100644 index 0000000000000..84f62bf47f955 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/mathfunctions.scala @@ -0,0 +1,562 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import scala.language.implicitConversions + +import org.apache.spark.annotation.Experimental +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.mathfuncs._ +import org.apache.spark.sql.functions.lit + +/** + * :: Experimental :: + * Mathematical Functions available for [[DataFrame]]. + * + * @groupname double_funcs Functions that require DoubleType as an input + * @groupname int_funcs Functions that require IntegerType as an input + * @groupname float_funcs Functions that require FloatType as an input + * @groupname long_funcs Functions that require LongType as an input + */ +@Experimental +// scalastyle:off +object mathfunctions { +// scalastyle:on + + private[this] implicit def toColumn(expr: Expression): Column = Column(expr) + + /** + * Computes the sine of the given value. + * + * @group double_funcs + */ + def sin(e: Column): Column = Sin(e.expr) + + /** + * Computes the sine of the given column. 
+ * + * @group double_funcs + */ + def sin(columnName: String): Column = sin(Column(columnName)) + + /** + * Computes the sine inverse of the given value; the returned angle is in the range + * -pi/2 through pi/2. + * + * @group double_funcs + */ + def asin(e: Column): Column = Asin(e.expr) + + /** + * Computes the sine inverse of the given column; the returned angle is in the range + * -pi/2 through pi/2. + * + * @group double_funcs + */ + def asin(columnName: String): Column = asin(Column(columnName)) + + /** + * Computes the hyperbolic sine of the given value. + * + * @group double_funcs + */ + def sinh(e: Column): Column = Sinh(e.expr) + + /** + * Computes the hyperbolic sine of the given column. + * + * @group double_funcs + */ + def sinh(columnName: String): Column = sinh(Column(columnName)) + + /** + * Computes the cosine of the given value. + * + * @group double_funcs + */ + def cos(e: Column): Column = Cos(e.expr) + + /** + * Computes the cosine of the given column. + * + * @group double_funcs + */ + def cos(columnName: String): Column = cos(Column(columnName)) + + /** + * Computes the cosine inverse of the given value; the returned angle is in the range + * 0.0 through pi. + * + * @group double_funcs + */ + def acos(e: Column): Column = Acos(e.expr) + + /** + * Computes the cosine inverse of the given column; the returned angle is in the range + * 0.0 through pi. + * + * @group double_funcs + */ + def acos(columnName: String): Column = acos(Column(columnName)) + + /** + * Computes the hyperbolic cosine of the given value. + * + * @group double_funcs + */ + def cosh(e: Column): Column = Cosh(e.expr) + + /** + * Computes the hyperbolic cosine of the given column. + * + * @group double_funcs + */ + def cosh(columnName: String): Column = cosh(Column(columnName)) + + /** + * Computes the tangent of the given value. + * + * @group double_funcs + */ + def tan(e: Column): Column = Tan(e.expr) + + /** + * Computes the tangent of the given column. + * + * @group double_funcs + */ + def tan(columnName: String): Column = tan(Column(columnName)) + + /** + * Computes the tangent inverse of the given value. + * + * @group double_funcs + */ + def atan(e: Column): Column = Atan(e.expr) + + /** + * Computes the tangent inverse of the given column. + * + * @group double_funcs + */ + def atan(columnName: String): Column = atan(Column(columnName)) + + /** + * Computes the hyperbolic tangent of the given value. + * + * @group double_funcs + */ + def tanh(e: Column): Column = Tanh(e.expr) + + /** + * Computes the hyperbolic tangent of the given column. + * + * @group double_funcs + */ + def tanh(columnName: String): Column = tanh(Column(columnName)) + + /** + * Converts an angle measured in radians to an approximately equivalent angle measured in degrees. + * + * @group double_funcs + */ + def toDeg(e: Column): Column = ToDegrees(e.expr) + + /** + * Converts an angle measured in radians to an approximately equivalent angle measured in degrees. + * + * @group double_funcs + */ + def toDeg(columnName: String): Column = toDeg(Column(columnName)) + + /** + * Converts an angle measured in degrees to an approximately equivalent angle measured in radians. + * + * @group double_funcs + */ + def toRad(e: Column): Column = ToRadians(e.expr) + + /** + * Converts an angle measured in degrees to an approximately equivalent angle measured in radians. + * + * @group double_funcs + */ + def toRad(columnName: String): Column = toRad(Column(columnName)) + + /** + * Computes the ceiling of the given value. 
+ * + * @group double_funcs + */ + def ceil(e: Column): Column = Ceil(e.expr) + + /** + * Computes the ceiling of the given column. + * + * @group double_funcs + */ + def ceil(columnName: String): Column = ceil(Column(columnName)) + + /** + * Computes the floor of the given value. + * + * @group double_funcs + */ + def floor(e: Column): Column = Floor(e.expr) + + /** + * Computes the floor of the given column. + * + * @group double_funcs + */ + def floor(columnName: String): Column = floor(Column(columnName)) + + /** + * Returns the double value that is closest in value to the argument and + * is equal to a mathematical integer. + * + * @group double_funcs + */ + def rint(e: Column): Column = Rint(e.expr) + + /** + * Returns the double value that is closest in value to the argument and + * is equal to a mathematical integer. + * + * @group double_funcs + */ + def rint(columnName: String): Column = rint(Column(columnName)) + + /** + * Computes the cube-root of the given value. + * + * @group double_funcs + */ + def cbrt(e: Column): Column = Cbrt(e.expr) + + /** + * Computes the cube-root of the given column. + * + * @group double_funcs + */ + def cbrt(columnName: String): Column = cbrt(Column(columnName)) + + /** + * Computes the signum of the given value. + * + * @group double_funcs + */ + def signum(e: Column): Column = Signum(e.expr) + + /** + * Computes the signum of the given column. + * + * @group double_funcs + */ + def signum(columnName: String): Column = signum(Column(columnName)) + + /** + * Computes the signum of the given value. For IntegerType. + * + * @group int_funcs + */ + def isignum(e: Column): Column = ISignum(e.expr) + + /** + * Computes the signum of the given column. For IntegerType. + * + * @group int_funcs + */ + def isignum(columnName: String): Column = isignum(Column(columnName)) + + /** + * Computes the signum of the given value. For FloatType. + * + * @group float_funcs + */ + def fsignum(e: Column): Column = FSignum(e.expr) + + /** + * Computes the signum of the given column. For FloatType. + * + * @group float_funcs + */ + def fsignum(columnName: String): Column = fsignum(Column(columnName)) + + /** + * Computes the signum of the given value. For LongType. + * + * @group long_funcs + */ + def lsignum(e: Column): Column = LSignum(e.expr) + + /** + * Computes the signum of the given column. For FloatType. + * + * @group long_funcs + */ + def lsignum(columnName: String): Column = lsignum(Column(columnName)) + + /** + * Computes the natural logarithm of the given value. + * + * @group double_funcs + */ + def log(e: Column): Column = Log(e.expr) + + /** + * Computes the natural logarithm of the given column. + * + * @group double_funcs + */ + def log(columnName: String): Column = log(Column(columnName)) + + /** + * Computes the logarithm of the given value in Base 10. + * + * @group double_funcs + */ + def log10(e: Column): Column = Log10(e.expr) + + /** + * Computes the logarithm of the given value in Base 10. + * + * @group double_funcs + */ + def log10(columnName: String): Column = log10(Column(columnName)) + + /** + * Computes the natural logarithm of the given value plus one. + * + * @group double_funcs + */ + def log1p(e: Column): Column = Log1p(e.expr) + + /** + * Computes the natural logarithm of the given column plus one. + * + * @group double_funcs + */ + def log1p(columnName: String): Column = log1p(Column(columnName)) + + /** + * Computes the exponential of the given value. 
+ * + * @group double_funcs + */ + def exp(e: Column): Column = Exp(e.expr) + + /** + * Computes the exponential of the given column. + * + * @group double_funcs + */ + def exp(columnName: String): Column = exp(Column(columnName)) + + /** + * Computes the exponential of the given value minus one. + * + * @group double_funcs + */ + def expm1(e: Column): Column = Expm1(e.expr) + + /** + * Computes the exponential of the given column. + * + * @group double_funcs + */ + def expm1(columnName: String): Column = expm1(Column(columnName)) + + /** + * Returns the value of the first argument raised to the power of the second argument. + * + * @group double_funcs + */ + def pow(l: Column, r: Column): Column = Pow(l.expr, r.expr) + + /** + * Returns the value of the first argument raised to the power of the second argument. + * + * @group double_funcs + */ + def pow(l: Column, rightName: String): Column = pow(l, Column(rightName)) + + /** + * Returns the value of the first argument raised to the power of the second argument. + * + * @group double_funcs + */ + def pow(leftName: String, r: Column): Column = pow(Column(leftName), r) + + /** + * Returns the value of the first argument raised to the power of the second argument. + * + * @group double_funcs + */ + def pow(leftName: String, rightName: String): Column = pow(Column(leftName), Column(rightName)) + + /** + * Returns the value of the first argument raised to the power of the second argument. + * + * @group double_funcs + */ + def pow(l: Column, r: Double): Column = pow(l, lit(r).expr) + + /** + * Returns the value of the first argument raised to the power of the second argument. + * + * @group double_funcs + */ + def pow(leftName: String, r: Double): Column = pow(Column(leftName), r) + + /** + * Returns the value of the first argument raised to the power of the second argument. + * + * @group double_funcs + */ + def pow(l: Double, r: Column): Column = pow(lit(l).expr, r) + + /** + * Returns the value of the first argument raised to the power of the second argument. + * + * @group double_funcs + */ + def pow(l: Double, rightName: String): Column = pow(l, Column(rightName)) + + /** + * Computes `sqrt(a^2^ + b^2^)` without intermediate overflow or underflow. + * + * @group double_funcs + */ + def hypot(l: Column, r: Column): Column = Hypot(l.expr, r.expr) + + /** + * Computes `sqrt(a^2^ + b^2^)` without intermediate overflow or underflow. + * + * @group double_funcs + */ + def hypot(l: Column, rightName: String): Column = hypot(l, Column(rightName)) + + /** + * Computes `sqrt(a^2^ + b^2^)` without intermediate overflow or underflow. + * + * @group double_funcs + */ + def hypot(leftName: String, r: Column): Column = hypot(Column(leftName), r) + + /** + * Computes `sqrt(a^2^ + b^2^)` without intermediate overflow or underflow. + * + * @group double_funcs + */ + def hypot(leftName: String, rightName: String): Column = + hypot(Column(leftName), Column(rightName)) + + /** + * Computes `sqrt(a^2^ + b^2^)` without intermediate overflow or underflow. + * + * @group double_funcs + */ + def hypot(l: Column, r: Double): Column = hypot(l, lit(r).expr) + + /** + * Computes `sqrt(a^2^ + b^2^)` without intermediate overflow or underflow. + * + * @group double_funcs + */ + def hypot(leftName: String, r: Double): Column = hypot(Column(leftName), r) + + /** + * Computes `sqrt(a^2^ + b^2^)` without intermediate overflow or underflow. 
+ * + * @group double_funcs + */ + def hypot(l: Double, r: Column): Column = hypot(lit(l).expr, r) + + /** + * Computes `sqrt(a^2^ + b^2^)` without intermediate overflow or underflow. + * + * @group double_funcs + */ + def hypot(l: Double, rightName: String): Column = hypot(l, Column(rightName)) + + /** + * Returns the angle theta from the conversion of rectangular coordinates (x, y) to + * polar coordinates (r, theta). + * + * @group double_funcs + */ + def atan2(l: Column, r: Column): Column = Atan2(l.expr, r.expr) + + /** + * Returns the angle theta from the conversion of rectangular coordinates (x, y) to + * polar coordinates (r, theta). + * + * @group double_funcs + */ + def atan2(l: Column, rightName: String): Column = atan2(l, Column(rightName)) + + /** + * Returns the angle theta from the conversion of rectangular coordinates (x, y) to + * polar coordinates (r, theta). + * + * @group double_funcs + */ + def atan2(leftName: String, r: Column): Column = atan2(Column(leftName), r) + + /** + * Returns the angle theta from the conversion of rectangular coordinates (x, y) to + * polar coordinates (r, theta). + * + * @group double_funcs + */ + def atan2(leftName: String, rightName: String): Column = + atan2(Column(leftName), Column(rightName)) + + /** + * Returns the angle theta from the conversion of rectangular coordinates (x, y) to + * polar coordinates (r, theta). + * + * @group double_funcs + */ + def atan2(l: Column, r: Double): Column = atan2(l, lit(r).expr) + + /** + * Returns the angle theta from the conversion of rectangular coordinates (x, y) to + * polar coordinates (r, theta). + * + * @group double_funcs + */ + def atan2(leftName: String, r: Double): Column = atan2(Column(leftName), r) + + /** + * Returns the angle theta from the conversion of rectangular coordinates (x, y) to + * polar coordinates (r, theta). + * + * @group double_funcs + */ + def atan2(l: Double, r: Column): Column = atan2(lit(l).expr, r) + + /** + * Returns the angle theta from the conversion of rectangular coordinates (x, y) to + * polar coordinates (r, theta). 
+ * + * @group double_funcs + */ + def atan2(l: Double, rightName: String): Column = atan2(l, Column(rightName)) +} diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java index e02c84872c628..e5c9504d21042 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java @@ -41,6 +41,7 @@ import java.util.Map; import static org.apache.spark.sql.functions.*; +import static org.apache.spark.sql.mathfunctions.*; public class JavaDataFrameSuite { private transient JavaSparkContext jsc; @@ -98,6 +99,14 @@ public void testVarargMethods() { df.groupBy().agg(countDistinct("key", "value")); df.groupBy().agg(countDistinct(col("key"), col("value"))); df.select(coalesce(col("key"))); + + // Varargs with mathfunctions + DataFrame df2 = context.table("testData2"); + df2.select(exp("a"), exp("b")); + df2.select(exp(log("a"))); + df2.select(pow("a", "a"), pow("b", 2.0)); + df2.select(pow(col("a"), col("b")), exp("b")); + df2.select(sin("a"), acos("b")); } @Ignore diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala index 904073b8cb2aa..680b5c636960d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala @@ -22,7 +22,6 @@ import org.apache.spark.sql.test.TestSQLContext import org.apache.spark.sql.test.TestSQLContext.implicits._ import org.apache.spark.sql.types._ - class ColumnExpressionSuite extends QueryTest { import org.apache.spark.sql.TestData._ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala new file mode 100644 index 0000000000000..561553cc925cb --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala @@ -0,0 +1,233 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql + +import java.lang.{Double => JavaDouble} + +import org.apache.spark.sql.functions.lit +import org.apache.spark.sql.mathfunctions._ +import org.apache.spark.sql.test.TestSQLContext +import org.apache.spark.sql.test.TestSQLContext.implicits._ + +private[this] object MathExpressionsTestData { + + case class DoubleData(a: JavaDouble, b: JavaDouble) + val doubleData = TestSQLContext.sparkContext.parallelize( + (1 to 10).map(i => DoubleData(i * 0.2 - 1, i * -0.2 + 1))).toDF() + + val nnDoubleData = TestSQLContext.sparkContext.parallelize( + (1 to 10).map(i => DoubleData(i * 0.1, i * -0.1))).toDF() + + case class NullDoubles(a: JavaDouble) + val nullDoubles = + TestSQLContext.sparkContext.parallelize( + NullDoubles(1.0) :: + NullDoubles(2.0) :: + NullDoubles(3.0) :: + NullDoubles(null) :: Nil + ).toDF() +} + +class MathExpressionsSuite extends QueryTest { + + import MathExpressionsTestData._ + + def testOneToOneMathFunction[@specialized(Int, Long, Float, Double) T]( + c: Column => Column, + f: T => T): Unit = { + checkAnswer( + doubleData.select(c('a)), + (1 to 10).map(n => Row(f((n * 0.2 - 1).asInstanceOf[T]))) + ) + + checkAnswer( + doubleData.select(c('b)), + (1 to 10).map(n => Row(f((-n * 0.2 + 1).asInstanceOf[T]))) + ) + + checkAnswer( + doubleData.select(c(lit(null))), + (1 to 10).map(_ => Row(null)) + ) + } + + def testOneToOneNonNegativeMathFunction(c: Column => Column, f: Double => Double): Unit = { + checkAnswer( + nnDoubleData.select(c('a)), + (1 to 10).map(n => Row(f(n * 0.1))) + ) + + if (f(-1) === math.log1p(-1)) { + checkAnswer( + nnDoubleData.select(c('b)), + (1 to 9).map(n => Row(f(n * -0.1))) :+ Row(Double.NegativeInfinity) + ) + } else { + checkAnswer( + nnDoubleData.select(c('b)), + (1 to 10).map(n => Row(null)) + ) + } + + checkAnswer( + nnDoubleData.select(c(lit(null))), + (1 to 10).map(_ => Row(null)) + ) + } + + def testTwoToOneMathFunction( + c: (Column, Column) => Column, + d: (Column, Double) => Column, + f: (Double, Double) => Double): Unit = { + checkAnswer( + nnDoubleData.select(c('a, 'a)), + nnDoubleData.collect().toSeq.map(r => Row(f(r.getDouble(0), r.getDouble(0)))) + ) + + checkAnswer( + nnDoubleData.select(c('a, 'b)), + nnDoubleData.collect().toSeq.map(r => Row(f(r.getDouble(0), r.getDouble(1)))) + ) + + checkAnswer( + nnDoubleData.select(d('a, 2.0)), + nnDoubleData.collect().toSeq.map(r => Row(f(r.getDouble(0), 2.0))) + ) + + checkAnswer( + nnDoubleData.select(d('a, -0.5)), + nnDoubleData.collect().toSeq.map(r => Row(f(r.getDouble(0), -0.5))) + ) + + val nonNull = nullDoubles.collect().toSeq.filter(r => r.get(0) != null) + + checkAnswer( + nullDoubles.select(c('a, 'a)).orderBy('a.asc), + Row(null) +: nonNull.map(r => Row(f(r.getDouble(0), r.getDouble(0)))) + ) + } + + test("sin") { + testOneToOneMathFunction(sin, math.sin) + } + + test("asin") { + testOneToOneMathFunction(asin, math.asin) + } + + test("sinh") { + testOneToOneMathFunction(sinh, math.sinh) + } + + test("cos") { + testOneToOneMathFunction(cos, math.cos) + } + + test("acos") { + testOneToOneMathFunction(acos, math.acos) + } + + test("cosh") { + testOneToOneMathFunction(cosh, math.cosh) + } + + test("tan") { + testOneToOneMathFunction(tan, math.tan) + } + + test("atan") { + testOneToOneMathFunction(atan, math.atan) + } + + test("tanh") { + testOneToOneMathFunction(tanh, math.tanh) + } + + test("toDeg") { + testOneToOneMathFunction(toDeg, math.toDegrees) + } + + test("toRad") { + testOneToOneMathFunction(toRad, math.toRadians) + } + + test("cbrt") { + 
testOneToOneMathFunction(cbrt, math.cbrt) + } + + test("ceil") { + testOneToOneMathFunction(ceil, math.ceil) + } + + test("floor") { + testOneToOneMathFunction(floor, math.floor) + } + + test("rint") { + testOneToOneMathFunction(rint, math.rint) + } + + test("exp") { + testOneToOneMathFunction(exp, math.exp) + } + + test("expm1") { + testOneToOneMathFunction(expm1, math.expm1) + } + + test("signum") { + testOneToOneMathFunction[Double](signum, math.signum) + } + + test("isignum") { + testOneToOneMathFunction[Int](isignum, math.signum) + } + + test("fsignum") { + testOneToOneMathFunction[Float](fsignum, math.signum) + } + + test("lsignum") { + testOneToOneMathFunction[Long](lsignum, math.signum) + } + + test("pow") { + testTwoToOneMathFunction(pow, pow, math.pow) + } + + test("hypot") { + testTwoToOneMathFunction(hypot, hypot, math.hypot) + } + + test("atan2") { + testTwoToOneMathFunction(atan2, atan2, math.atan2) + } + + test("log") { + testOneToOneNonNegativeMathFunction(log, math.log) + } + + test("log10") { + testOneToOneNonNegativeMathFunction(log10, math.log10) + } + + test("log1p") { + testOneToOneNonNegativeMathFunction(log1p, math.log1p) + } + +} From 9e4e82b7bca1129bcd5e0274b9ae1b1be3fb93da Mon Sep 17 00:00:00 2001 From: jerryshao Date: Mon, 27 Apr 2015 23:48:02 -0700 Subject: [PATCH 088/110] [SPARK-5946] [STREAMING] Add Python API for direct Kafka stream Currently only added `createDirectStream` API, I'm not sure if `createRDD` is also needed, since some Java object needs to be wrapped in Python. Please help to review, thanks a lot. Author: jerryshao Author: Saisai Shao Closes #4723 from jerryshao/direct-kafka-python-api and squashes the following commits: a1fe97c [jerryshao] Fix rebase issue eebf333 [jerryshao] Address the comments da40f4e [jerryshao] Fix Python 2.6 Syntax error issue 5c0ee85 [jerryshao] Style fix 4aeac18 [jerryshao] Fix bug in example code 7146d86 [jerryshao] Add unit test bf3bdd6 [jerryshao] Add more APIs and address the comments f5b3801 [jerryshao] Small style fix 8641835 [Saisai Shao] Rebase and update the code 589c05b [Saisai Shao] Fix the style d6fcb6a [Saisai Shao] Address the comments dfda902 [Saisai Shao] Style fix 0f7d168 [Saisai Shao] Add the doc and fix some style issues 67e6880 [Saisai Shao] Fix test bug 917b0db [Saisai Shao] Add Python createRDD API for Kakfa direct stream c3fc11d [jerryshao] Modify the docs 2c00936 [Saisai Shao] address the comments 3360f44 [jerryshao] Fix code style e0e0f0d [jerryshao] Code clean and bug fix 338c41f [Saisai Shao] Add python API and example for direct kafka stream --- .../streaming/direct_kafka_wordcount.py | 55 ++++++ .../spark/streaming/kafka/KafkaUtils.scala | 92 +++++++++- python/pyspark/streaming/kafka.py | 167 +++++++++++++++++- python/pyspark/streaming/tests.py | 84 ++++++++- 4 files changed, 383 insertions(+), 15 deletions(-) create mode 100644 examples/src/main/python/streaming/direct_kafka_wordcount.py diff --git a/examples/src/main/python/streaming/direct_kafka_wordcount.py b/examples/src/main/python/streaming/direct_kafka_wordcount.py new file mode 100644 index 0000000000000..6ef188a220c51 --- /dev/null +++ b/examples/src/main/python/streaming/direct_kafka_wordcount.py @@ -0,0 +1,55 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. 
+# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" + Counts words in UTF8 encoded, '\n' delimited text directly received from Kafka in every 2 seconds. + Usage: direct_kafka_wordcount.py + + To run this on your local machine, you need to setup Kafka and create a producer first, see + http://kafka.apache.org/documentation.html#quickstart + + and then run the example + `$ bin/spark-submit --jars external/kafka-assembly/target/scala-*/\ + spark-streaming-kafka-assembly-*.jar \ + examples/src/main/python/streaming/direct_kafka_wordcount.py \ + localhost:9092 test` +""" + +import sys + +from pyspark import SparkContext +from pyspark.streaming import StreamingContext +from pyspark.streaming.kafka import KafkaUtils + +if __name__ == "__main__": + if len(sys.argv) != 3: + print >> sys.stderr, "Usage: direct_kafka_wordcount.py " + exit(-1) + + sc = SparkContext(appName="PythonStreamingDirectKafkaWordCount") + ssc = StreamingContext(sc, 2) + + brokers, topic = sys.argv[1:] + kvs = KafkaUtils.createDirectStream(ssc, [topic], {"metadata.broker.list": brokers}) + lines = kvs.map(lambda x: x[1]) + counts = lines.flatMap(lambda line: line.split(" ")) \ + .map(lambda word: (word, 1)) \ + .reduceByKey(lambda a, b: a+b) + counts.pprint() + + ssc.start() + ssc.awaitTermination() diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala index 5a9bd4214cf51..0721ddaf7055a 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala @@ -21,6 +21,7 @@ import java.lang.{Integer => JInt} import java.lang.{Long => JLong} import java.util.{Map => JMap} import java.util.{Set => JSet} +import java.util.{List => JList} import scala.reflect.ClassTag import scala.collection.JavaConversions._ @@ -234,7 +235,6 @@ object KafkaUtils { new KafkaRDD[K, V, KD, VD, R](sc, kafkaParams, offsetRanges, leaderMap, messageHandler) } - /** * Create a RDD from Kafka using offset ranges for each topic and partition. 
* @@ -558,4 +558,94 @@ private class KafkaUtilsPythonHelper { topics, storageLevel) } + + def createRDD( + jsc: JavaSparkContext, + kafkaParams: JMap[String, String], + offsetRanges: JList[OffsetRange], + leaders: JMap[TopicAndPartition, Broker]): JavaPairRDD[Array[Byte], Array[Byte]] = { + val messageHandler = new JFunction[MessageAndMetadata[Array[Byte], Array[Byte]], + (Array[Byte], Array[Byte])] { + def call(t1: MessageAndMetadata[Array[Byte], Array[Byte]]): (Array[Byte], Array[Byte]) = + (t1.key(), t1.message()) + } + + val jrdd = KafkaUtils.createRDD[ + Array[Byte], + Array[Byte], + DefaultDecoder, + DefaultDecoder, + (Array[Byte], Array[Byte])]( + jsc, + classOf[Array[Byte]], + classOf[Array[Byte]], + classOf[DefaultDecoder], + classOf[DefaultDecoder], + classOf[(Array[Byte], Array[Byte])], + kafkaParams, + offsetRanges.toArray(new Array[OffsetRange](offsetRanges.size())), + leaders, + messageHandler + ) + new JavaPairRDD(jrdd.rdd) + } + + def createDirectStream( + jssc: JavaStreamingContext, + kafkaParams: JMap[String, String], + topics: JSet[String], + fromOffsets: JMap[TopicAndPartition, JLong] + ): JavaPairInputDStream[Array[Byte], Array[Byte]] = { + + if (!fromOffsets.isEmpty) { + import scala.collection.JavaConversions._ + val topicsFromOffsets = fromOffsets.keySet().map(_.topic) + if (topicsFromOffsets != topics.toSet) { + throw new IllegalStateException(s"The specified topics: ${topics.toSet.mkString(" ")} " + + s"do not equal to the topic from offsets: ${topicsFromOffsets.mkString(" ")}") + } + } + + if (fromOffsets.isEmpty) { + KafkaUtils.createDirectStream[Array[Byte], Array[Byte], DefaultDecoder, DefaultDecoder]( + jssc, + classOf[Array[Byte]], + classOf[Array[Byte]], + classOf[DefaultDecoder], + classOf[DefaultDecoder], + kafkaParams, + topics) + } else { + val messageHandler = new JFunction[MessageAndMetadata[Array[Byte], Array[Byte]], + (Array[Byte], Array[Byte])] { + def call(t1: MessageAndMetadata[Array[Byte], Array[Byte]]): (Array[Byte], Array[Byte]) = + (t1.key(), t1.message()) + } + + val jstream = KafkaUtils.createDirectStream[ + Array[Byte], + Array[Byte], + DefaultDecoder, + DefaultDecoder, + (Array[Byte], Array[Byte])]( + jssc, + classOf[Array[Byte]], + classOf[Array[Byte]], + classOf[DefaultDecoder], + classOf[DefaultDecoder], + classOf[(Array[Byte], Array[Byte])], + kafkaParams, + fromOffsets, + messageHandler) + new JavaPairInputDStream(jstream.inputDStream) + } + } + + def createOffsetRange(topic: String, partition: JInt, fromOffset: JLong, untilOffset: JLong + ): OffsetRange = OffsetRange.create(topic, partition, fromOffset, untilOffset) + + def createTopicAndPartition(topic: String, partition: JInt): TopicAndPartition = + TopicAndPartition(topic, partition) + + def createBroker(host: String, port: JInt): Broker = Broker(host, port) } diff --git a/python/pyspark/streaming/kafka.py b/python/pyspark/streaming/kafka.py index 8d610d6569b4a..e278b29003f69 100644 --- a/python/pyspark/streaming/kafka.py +++ b/python/pyspark/streaming/kafka.py @@ -17,11 +17,12 @@ from py4j.java_gateway import Py4JJavaError +from pyspark.rdd import RDD from pyspark.storagelevel import StorageLevel from pyspark.serializers import PairDeserializer, NoOpSerializer from pyspark.streaming import DStream -__all__ = ['KafkaUtils', 'utf8_decoder'] +__all__ = ['Broker', 'KafkaUtils', 'OffsetRange', 'TopicAndPartition', 'utf8_decoder'] def utf8_decoder(s): @@ -67,7 +68,104 @@ def createStream(ssc, zkQuorum, groupId, topics, kafkaParams={}, except Py4JJavaError as e: # TODO: use --jar once 
it also work on driver if 'ClassNotFoundException' in str(e.java_exception): - print(""" + KafkaUtils._printErrorMsg(ssc.sparkContext) + raise e + ser = PairDeserializer(NoOpSerializer(), NoOpSerializer()) + stream = DStream(jstream, ssc, ser) + return stream.map(lambda k_v: (keyDecoder(k_v[0]), valueDecoder(k_v[1]))) + + @staticmethod + def createDirectStream(ssc, topics, kafkaParams, fromOffsets={}, + keyDecoder=utf8_decoder, valueDecoder=utf8_decoder): + """ + .. note:: Experimental + + Create an input stream that directly pulls messages from a Kafka Broker and specific offset. + + This is not a receiver based Kafka input stream, it directly pulls the message from Kafka + in each batch duration and processed without storing. + + This does not use Zookeeper to store offsets. The consumed offsets are tracked + by the stream itself. For interoperability with Kafka monitoring tools that depend on + Zookeeper, you have to update Kafka/Zookeeper yourself from the streaming application. + You can access the offsets used in each batch from the generated RDDs (see + + To recover from driver failures, you have to enable checkpointing in the StreamingContext. + The information on consumed offset can be recovered from the checkpoint. + See the programming guide for details (constraints, etc.). + + :param ssc: StreamingContext object. + :param topics: list of topic_name to consume. + :param kafkaParams: Additional params for Kafka. + :param fromOffsets: Per-topic/partition Kafka offsets defining the (inclusive) starting + point of the stream. + :param keyDecoder: A function used to decode key (default is utf8_decoder). + :param valueDecoder: A function used to decode value (default is utf8_decoder). + :return: A DStream object + """ + if not isinstance(topics, list): + raise TypeError("topics should be list") + if not isinstance(kafkaParams, dict): + raise TypeError("kafkaParams should be dict") + + try: + helperClass = ssc._jvm.java.lang.Thread.currentThread().getContextClassLoader() \ + .loadClass("org.apache.spark.streaming.kafka.KafkaUtilsPythonHelper") + helper = helperClass.newInstance() + + jfromOffsets = dict([(k._jTopicAndPartition(helper), + v) for (k, v) in fromOffsets.items()]) + jstream = helper.createDirectStream(ssc._jssc, kafkaParams, set(topics), jfromOffsets) + except Py4JJavaError as e: + if 'ClassNotFoundException' in str(e.java_exception): + KafkaUtils._printErrorMsg(ssc.sparkContext) + raise e + + ser = PairDeserializer(NoOpSerializer(), NoOpSerializer()) + stream = DStream(jstream, ssc, ser) + return stream.map(lambda k_v: (keyDecoder(k_v[0]), valueDecoder(k_v[1]))) + + @staticmethod + def createRDD(sc, kafkaParams, offsetRanges, leaders={}, + keyDecoder=utf8_decoder, valueDecoder=utf8_decoder): + """ + .. note:: Experimental + + Create a RDD from Kafka using offset ranges for each topic and partition. + :param sc: SparkContext object + :param kafkaParams: Additional params for Kafka + :param offsetRanges: list of offsetRange to specify topic:partition:[start, end) to consume + :param leaders: Kafka brokers for each TopicAndPartition in offsetRanges. May be an empty + map, in which case leaders will be looked up on the driver. 
+ :param keyDecoder: A function used to decode key (default is utf8_decoder) + :param valueDecoder: A function used to decode value (default is utf8_decoder) + :return: A RDD object + """ + if not isinstance(kafkaParams, dict): + raise TypeError("kafkaParams should be dict") + if not isinstance(offsetRanges, list): + raise TypeError("offsetRanges should be list") + + try: + helperClass = sc._jvm.java.lang.Thread.currentThread().getContextClassLoader() \ + .loadClass("org.apache.spark.streaming.kafka.KafkaUtilsPythonHelper") + helper = helperClass.newInstance() + joffsetRanges = [o._jOffsetRange(helper) for o in offsetRanges] + jleaders = dict([(k._jTopicAndPartition(helper), + v._jBroker(helper)) for (k, v) in leaders.items()]) + jrdd = helper.createRDD(sc._jsc, kafkaParams, joffsetRanges, jleaders) + except Py4JJavaError as e: + if 'ClassNotFoundException' in str(e.java_exception): + KafkaUtils._printErrorMsg(sc) + raise e + + ser = PairDeserializer(NoOpSerializer(), NoOpSerializer()) + rdd = RDD(jrdd, sc, ser) + return rdd.map(lambda k_v: (keyDecoder(k_v[0]), valueDecoder(k_v[1]))) + + @staticmethod + def _printErrorMsg(sc): + print(""" ________________________________________________________________________________________________ Spark Streaming's Kafka libraries not found in class path. Try one of the following. @@ -85,8 +183,63 @@ def createStream(ssc, zkQuorum, groupId, topics, kafkaParams={}, ________________________________________________________________________________________________ -""" % (ssc.sparkContext.version, ssc.sparkContext.version)) - raise e - ser = PairDeserializer(NoOpSerializer(), NoOpSerializer()) - stream = DStream(jstream, ssc, ser) - return stream.map(lambda k_v: (keyDecoder(k_v[0]), valueDecoder(k_v[1]))) +""" % (sc.version, sc.version)) + + +class OffsetRange(object): + """ + Represents a range of offsets from a single Kafka TopicAndPartition. + """ + + def __init__(self, topic, partition, fromOffset, untilOffset): + """ + Create a OffsetRange to represent range of offsets + :param topic: Kafka topic name. + :param partition: Kafka partition id. + :param fromOffset: Inclusive starting offset. + :param untilOffset: Exclusive ending offset. + """ + self._topic = topic + self._partition = partition + self._fromOffset = fromOffset + self._untilOffset = untilOffset + + def _jOffsetRange(self, helper): + return helper.createOffsetRange(self._topic, self._partition, self._fromOffset, + self._untilOffset) + + +class TopicAndPartition(object): + """ + Represents a specific top and partition for Kafka. + """ + + def __init__(self, topic, partition): + """ + Create a Python TopicAndPartition to map to the Java related object + :param topic: Kafka topic name. + :param partition: Kafka partition id. + """ + self._topic = topic + self._partition = partition + + def _jTopicAndPartition(self, helper): + return helper.createTopicAndPartition(self._topic, self._partition) + + +class Broker(object): + """ + Represent the host and port info for a Kafka broker. + """ + + def __init__(self, host, port): + """ + Create a Python Broker to map to the Java related object. + :param host: Broker's hostname. + :param port: Broker's port. 
+ """ + self._host = host + self._port = port + + def _jBroker(self, helper): + return helper.createBroker(self._host, self._port) diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py index 5fa1e5ef081ab..7c06c203455d9 100644 --- a/python/pyspark/streaming/tests.py +++ b/python/pyspark/streaming/tests.py @@ -21,6 +21,7 @@ import time import operator import tempfile +import random import struct from functools import reduce @@ -35,7 +36,7 @@ from pyspark.context import SparkConf, SparkContext, RDD from pyspark.streaming.context import StreamingContext -from pyspark.streaming.kafka import KafkaUtils +from pyspark.streaming.kafka import Broker, KafkaUtils, OffsetRange, TopicAndPartition class PySparkStreamingTestCase(unittest.TestCase): @@ -590,9 +591,27 @@ def tearDown(self): super(KafkaStreamTests, self).tearDown() + def _randomTopic(self): + return "topic-%d" % random.randint(0, 10000) + + def _validateStreamResult(self, sendData, stream): + result = {} + for i in chain.from_iterable(self._collect(stream.map(lambda x: x[1]), + sum(sendData.values()))): + result[i] = result.get(i, 0) + 1 + + self.assertEqual(sendData, result) + + def _validateRddResult(self, sendData, rdd): + result = {} + for i in rdd.map(lambda x: x[1]).collect(): + result[i] = result.get(i, 0) + 1 + + self.assertEqual(sendData, result) + def test_kafka_stream(self): """Test the Python Kafka stream API.""" - topic = "topic1" + topic = self._randomTopic() sendData = {"a": 3, "b": 5, "c": 10} self._kafkaTestUtils.createTopic(topic) @@ -601,13 +620,64 @@ def test_kafka_stream(self): stream = KafkaUtils.createStream(self.ssc, self._kafkaTestUtils.zkAddress(), "test-streaming-consumer", {topic: 1}, {"auto.offset.reset": "smallest"}) + self._validateStreamResult(sendData, stream) - result = {} - for i in chain.from_iterable(self._collect(stream.map(lambda x: x[1]), - sum(sendData.values()))): - result[i] = result.get(i, 0) + 1 + def test_kafka_direct_stream(self): + """Test the Python direct Kafka stream API.""" + topic = self._randomTopic() + sendData = {"a": 1, "b": 2, "c": 3} + kafkaParams = {"metadata.broker.list": self._kafkaTestUtils.brokerAddress(), + "auto.offset.reset": "smallest"} - self.assertEqual(sendData, result) + self._kafkaTestUtils.createTopic(topic) + self._kafkaTestUtils.sendMessages(topic, sendData) + + stream = KafkaUtils.createDirectStream(self.ssc, [topic], kafkaParams) + self._validateStreamResult(sendData, stream) + + @unittest.skipIf(sys.version >= "3", "long type not support") + def test_kafka_direct_stream_from_offset(self): + """Test the Python direct Kafka stream API with start offset specified.""" + topic = self._randomTopic() + sendData = {"a": 1, "b": 2, "c": 3} + fromOffsets = {TopicAndPartition(topic, 0): long(0)} + kafkaParams = {"metadata.broker.list": self._kafkaTestUtils.brokerAddress()} + + self._kafkaTestUtils.createTopic(topic) + self._kafkaTestUtils.sendMessages(topic, sendData) + + stream = KafkaUtils.createDirectStream(self.ssc, [topic], kafkaParams, fromOffsets) + self._validateStreamResult(sendData, stream) + + @unittest.skipIf(sys.version >= "3", "long type not support") + def test_kafka_rdd(self): + """Test the Python direct Kafka RDD API.""" + topic = self._randomTopic() + sendData = {"a": 1, "b": 2} + offsetRanges = [OffsetRange(topic, 0, long(0), long(sum(sendData.values())))] + kafkaParams = {"metadata.broker.list": self._kafkaTestUtils.brokerAddress()} + + self._kafkaTestUtils.createTopic(topic) + self._kafkaTestUtils.sendMessages(topic, 
sendData) + + rdd = KafkaUtils.createRDD(self.sc, kafkaParams, offsetRanges) + self._validateRddResult(sendData, rdd) + + @unittest.skipIf(sys.version >= "3", "long type not support") + def test_kafka_rdd_with_leaders(self): + """Test the Python direct Kafka RDD API with leaders.""" + topic = self._randomTopic() + sendData = {"a": 1, "b": 2, "c": 3} + offsetRanges = [OffsetRange(topic, 0, long(0), long(sum(sendData.values())))] + kafkaParams = {"metadata.broker.list": self._kafkaTestUtils.brokerAddress()} + address = self._kafkaTestUtils.brokerAddress().split(":") + leaders = {TopicAndPartition(topic, 0): Broker(address[0], int(address[1]))} + + self._kafkaTestUtils.createTopic(topic) + self._kafkaTestUtils.sendMessages(topic, sendData) + + rdd = KafkaUtils.createRDD(self.sc, kafkaParams, offsetRanges, leaders) + self._validateRddResult(sendData, rdd) if __name__ == "__main__": unittest.main() From bf35edd9d4b8b11df9f47b6ff43831bc95f06322 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Tue, 28 Apr 2015 00:38:14 -0700 Subject: [PATCH 089/110] [SPARK-7187] SerializationDebugger should not crash user code rxin Author: Andrew Or Closes #5734 from andrewor14/ser-deb and squashes the following commits: e8aad6c [Andrew Or] NonFatal 57d0ef4 [Andrew Or] try catch improveException --- .../spark/serializer/SerializationDebugger.scala | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/serializer/SerializationDebugger.scala b/core/src/main/scala/org/apache/spark/serializer/SerializationDebugger.scala index cecb992579655..5abfa467c0ec8 100644 --- a/core/src/main/scala/org/apache/spark/serializer/SerializationDebugger.scala +++ b/core/src/main/scala/org/apache/spark/serializer/SerializationDebugger.scala @@ -23,6 +23,7 @@ import java.security.AccessController import scala.annotation.tailrec import scala.collection.mutable +import scala.util.control.NonFatal import org.apache.spark.Logging @@ -35,8 +36,15 @@ private[serializer] object SerializationDebugger extends Logging { */ def improveException(obj: Any, e: NotSerializableException): NotSerializableException = { if (enableDebugging && reflect != null) { - new NotSerializableException( - e.getMessage + "\nSerialization stack:\n" + find(obj).map("\t- " + _).mkString("\n")) + try { + new NotSerializableException( + e.getMessage + "\nSerialization stack:\n" + find(obj).map("\t- " + _).mkString("\n")) + } catch { + case NonFatal(t) => + // Fall back to old exception + logWarning("Exception in serialization debugger", t) + e + } } else { e } From d94cd1a733d5715792e6c4eac87f0d5c81aebbe2 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Tue, 28 Apr 2015 00:39:08 -0700 Subject: [PATCH 090/110] [SPARK-7135][SQL] DataFrame expression for monotonically increasing IDs. Author: Reynold Xin Closes #5709 from rxin/inc-id and squashes the following commits: 7853611 [Reynold Xin] private sql. a9fda0d [Reynold Xin] Missed a few numbers. 343d896 [Reynold Xin] Self review feedback. a7136cb [Reynold Xin] [SPARK-7135][SQL] DataFrame expression for monotonically increasing IDs. 
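Before the diff itself, a minimal Scala usage sketch of the new expression may help. It assumes a spark-shell style session with an existing SparkContext `sc` and a SQLContext val `sqlContext` whose implicits are imported; the DataFrame and column names are illustrative only and are not part of this patch.

    // Tag each row with a 64-bit ID using the new monotonicallyIncreasingId() expression.
    // Assumes `sc` and `sqlContext` already exist (e.g. in spark-shell); names are examples.
    import org.apache.spark.sql.functions.monotonicallyIncreasingId
    import sqlContext.implicits._

    val df = sc.parallelize(Seq(Tuple1("a"), Tuple1("b"), Tuple1("c")), 2).toDF("value")
    // IDs are unique and increasing but not consecutive: partition ID in the upper 31 bits,
    // per-partition record number in the lower 33 bits.
    df.withColumn("id", monotonicallyIncreasingId()).show()

The per-partition layout of the bits is why the Python doctest in the diff below jumps from 2 to 8589934592 (1L << 33) at the partition boundary.
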
--- python/pyspark/sql/functions.py | 22 +++++++- .../MonotonicallyIncreasingID.scala | 53 +++++++++++++++++++ .../expressions/SparkPartitionID.scala | 6 +-- .../org/apache/spark/sql/functions.scala | 16 ++++++ .../spark/sql/ColumnExpressionSuite.scala | 11 ++++ 5 files changed, 103 insertions(+), 5 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/MonotonicallyIncreasingID.scala diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index f48b7b5d10af7..7b86655d9c82f 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -103,8 +103,28 @@ def countDistinct(col, *cols): return Column(jc) +def monotonicallyIncreasingId(): + """A column that generates monotonically increasing 64-bit integers. + + The generated ID is guaranteed to be monotonically increasing and unique, but not consecutive. + The current implementation puts the partition ID in the upper 31 bits, and the record number + within each partition in the lower 33 bits. The assumption is that the data frame has + less than 1 billion partitions, and each partition has less than 8 billion records. + + As an example, consider a [[DataFrame]] with two partitions, each with 3 records. + This expression would return the following IDs: + 0, 1, 2, 8589934592 (1L << 33), 8589934593, 8589934594. + + >>> df0 = sc.parallelize(range(2), 2).mapPartitions(lambda x: [(1,), (2,), (3,)]).toDF(['col1']) + >>> df0.select(monotonicallyIncreasingId().alias('id')).collect() + [Row(id=0), Row(id=1), Row(id=2), Row(id=8589934592), Row(id=8589934593), Row(id=8589934594)] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.monotonicallyIncreasingId()) + + def sparkPartitionId(): - """Returns a column for partition ID of the Spark task. + """A column for partition ID of the Spark task. Note that this is indeterministic because it depends on data partitioning and task scheduling. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/MonotonicallyIncreasingID.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/MonotonicallyIncreasingID.scala new file mode 100644 index 0000000000000..9ac732b55b188 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/MonotonicallyIncreasingID.scala @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.expressions + +import org.apache.spark.TaskContext +import org.apache.spark.sql.catalyst.expressions.{Row, LeafExpression} +import org.apache.spark.sql.types.{LongType, DataType} + +/** + * Returns monotonically increasing 64-bit integers. 
+ * + * The generated ID is guaranteed to be monotonically increasing and unique, but not consecutive. + * The current implementation puts the partition ID in the upper 31 bits, and the lower 33 bits + * represent the record number within each partition. The assumption is that the data frame has + * less than 1 billion partitions, and each partition has less than 8 billion records. + * + * Since this expression is stateful, it cannot be a case object. + */ +private[sql] case class MonotonicallyIncreasingID() extends LeafExpression { + + /** + * Record ID within each partition. By being transient, count's value is reset to 0 every time + * we serialize and deserialize it. + */ + @transient private[this] var count: Long = 0L + + override type EvaluatedType = Long + + override def nullable: Boolean = false + + override def dataType: DataType = LongType + + override def eval(input: Row): Long = { + val currentCount = count + count += 1 + (TaskContext.get().partitionId().toLong << 33) + currentCount + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/SparkPartitionID.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/SparkPartitionID.scala index fe7607c6ac340..c2c6cbd491598 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/SparkPartitionID.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/SparkPartitionID.scala @@ -18,16 +18,14 @@ package org.apache.spark.sql.execution.expressions import org.apache.spark.TaskContext -import org.apache.spark.sql.catalyst.expressions.{Row, Expression} -import org.apache.spark.sql.catalyst.trees +import org.apache.spark.sql.catalyst.expressions.{LeafExpression, Row} import org.apache.spark.sql.types.{IntegerType, DataType} /** * Expression that returns the current partition id of the Spark task. */ -case object SparkPartitionID extends Expression with trees.LeafNode[Expression] { - self: Product => +private[sql] case object SparkPartitionID extends LeafExpression { override type EvaluatedType = Int diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 9738fd4f93bad..aa31d04a0cbe4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -301,6 +301,22 @@ object functions { */ def lower(e: Column): Column = Lower(e.expr) + /** + * A column expression that generates monotonically increasing 64-bit integers. + * + * The generated ID is guaranteed to be monotonically increasing and unique, but not consecutive. + * The current implementation puts the partition ID in the upper 31 bits, and the record number + * within each partition in the lower 33 bits. The assumption is that the data frame has + * less than 1 billion partitions, and each partition has less than 8 billion records. + * + * As an example, consider a [[DataFrame]] with two partitions, each with 3 records. + * This expression would return the following IDs: + * 0, 1, 2, 8589934592 (1L << 33), 8589934593, 8589934594. + * + * @group normal_funcs + */ + def monotonicallyIncreasingId(): Column = execution.expressions.MonotonicallyIncreasingID() + /** * Unary minus, i.e. negate the expression. 
* {{{ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala index 680b5c636960d..2ba5fc21ff57c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala @@ -309,6 +309,17 @@ class ColumnExpressionSuite extends QueryTest { ) } + test("monotonicallyIncreasingId") { + // Make sure we have 2 partitions, each with 2 records. + val df = TestSQLContext.sparkContext.parallelize(1 to 2, 2).mapPartitions { iter => + Iterator(Tuple1(1), Tuple1(2)) + }.toDF("a") + checkAnswer( + df.select(monotonicallyIncreasingId()), + Row(0L) :: Row(1L) :: Row((1L << 33) + 0L) :: Row((1L << 33) + 1L) :: Nil + ) + } + test("sparkPartitionId") { val df = TestSQLContext.sparkContext.parallelize(1 to 1, 1).map(i => (i, i)).toDF("a", "b") checkAnswer( From e13cd86567a43672297bb488088dd8f40ec799bf Mon Sep 17 00:00:00 2001 From: Pei-Lun Lee Date: Tue, 28 Apr 2015 16:50:18 +0800 Subject: [PATCH 091/110] [SPARK-6352] [SQL] Custom parquet output committer Add new config "spark.sql.parquet.output.committer.class" to allow custom parquet output committer and an output committer class specific to use on s3. Fix compilation error introduced by https://github.com/apache/spark/pull/5042. Respect ParquetOutputFormat.ENABLE_JOB_SUMMARY flag. Author: Pei-Lun Lee Closes #5525 from ypcat/spark-6352 and squashes the following commits: 54c6b15 [Pei-Lun Lee] error handling 472870e [Pei-Lun Lee] add back custom parquet output committer ddd0f69 [Pei-Lun Lee] Merge branch 'master' of https://github.com/apache/spark into spark-6352 9ece5c5 [Pei-Lun Lee] compatibility with hadoop 1.x 8413fcd [Pei-Lun Lee] Merge branch 'master' of https://github.com/apache/spark into spark-6352 fe65915 [Pei-Lun Lee] add support for parquet config parquet.enable.summary-metadata e17bf47 [Pei-Lun Lee] Merge branch 'master' of https://github.com/apache/spark into spark-6352 9ae7545 [Pei-Lun Lee] [SPARL-6352] [SQL] Change to allow custom parquet output committer. 0d540b9 [Pei-Lun Lee] [SPARK-6352] [SQL] add license c42468c [Pei-Lun Lee] [SPARK-6352] [SQL] add test case 0fc03ca [Pei-Lun Lee] [SPARK-6532] [SQL] hide class DirectParquetOutputCommitter 769bd67 [Pei-Lun Lee] DirectParquetOutputCommitter f75e261 [Pei-Lun Lee] DirectParquetOutputCommitter --- .../DirectParquetOutputCommitter.scala | 73 +++++++++++++++++++ .../sql/parquet/ParquetTableOperations.scala | 21 ++++++ .../spark/sql/parquet/ParquetIOSuite.scala | 22 ++++++ 3 files changed, 116 insertions(+) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/parquet/DirectParquetOutputCommitter.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/DirectParquetOutputCommitter.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/DirectParquetOutputCommitter.scala new file mode 100644 index 0000000000000..f5ce2718bec4a --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/DirectParquetOutputCommitter.scala @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.parquet + +import org.apache.hadoop.fs.Path +import org.apache.hadoop.mapreduce.{JobContext, TaskAttemptContext} +import org.apache.hadoop.mapreduce.lib.output.FileOutputCommitter + +import parquet.Log +import parquet.hadoop.util.ContextUtil +import parquet.hadoop.{ParquetFileReader, ParquetFileWriter, ParquetOutputCommitter, ParquetOutputFormat} + +private[parquet] class DirectParquetOutputCommitter(outputPath: Path, context: TaskAttemptContext) + extends ParquetOutputCommitter(outputPath, context) { + val LOG = Log.getLog(classOf[ParquetOutputCommitter]) + + override def getWorkPath(): Path = outputPath + override def abortTask(taskContext: TaskAttemptContext): Unit = {} + override def commitTask(taskContext: TaskAttemptContext): Unit = {} + override def needsTaskCommit(taskContext: TaskAttemptContext): Boolean = true + override def setupJob(jobContext: JobContext): Unit = {} + override def setupTask(taskContext: TaskAttemptContext): Unit = {} + + override def commitJob(jobContext: JobContext) { + val configuration = ContextUtil.getConfiguration(jobContext) + val fileSystem = outputPath.getFileSystem(configuration) + + if (configuration.getBoolean(ParquetOutputFormat.ENABLE_JOB_SUMMARY, true)) { + try { + val outputStatus = fileSystem.getFileStatus(outputPath) + val footers = ParquetFileReader.readAllFootersInParallel(configuration, outputStatus) + try { + ParquetFileWriter.writeMetadataFile(configuration, outputPath, footers) + } catch { + case e: Exception => { + LOG.warn("could not write summary file for " + outputPath, e) + val metadataPath = new Path(outputPath, ParquetFileWriter.PARQUET_METADATA_FILE) + if (fileSystem.exists(metadataPath)) { + fileSystem.delete(metadataPath, true) + } + } + } + } catch { + case e: Exception => LOG.warn("could not write summary file for " + outputPath, e) + } + } + + if (configuration.getBoolean("mapreduce.fileoutputcommitter.marksuccessfuljobs", true)) { + try { + val successPath = new Path(outputPath, FileOutputCommitter.SUCCEEDED_FILE_NAME) + fileSystem.create(successPath).close() + } catch { + case e: Exception => LOG.warn("could not write success file for " + outputPath, e) + } + } + } +} + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala index a938b77578686..aded126ea0615 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala @@ -381,6 +381,7 @@ private[parquet] class AppendingParquetOutputFormat(offset: Int) extends parquet.hadoop.ParquetOutputFormat[Row] { // override to accept existing directories as valid output directory override def checkOutputSpecs(job: JobContext): Unit = {} + var committer: OutputCommitter = null // override to choose output filename so not overwrite existing ones override def getDefaultWorkFile(context: TaskAttemptContext, extension: String): Path = { @@ -403,6 +404,26 @@ private[parquet] class 
AppendingParquetOutputFormat(offset: Int) private def getTaskAttemptID(context: TaskAttemptContext): TaskAttemptID = { context.getClass.getMethod("getTaskAttemptID").invoke(context).asInstanceOf[TaskAttemptID] } + + // override to create output committer from configuration + override def getOutputCommitter(context: TaskAttemptContext): OutputCommitter = { + if (committer == null) { + val output = getOutputPath(context) + val cls = context.getConfiguration.getClass("spark.sql.parquet.output.committer.class", + classOf[ParquetOutputCommitter], classOf[ParquetOutputCommitter]) + val ctor = cls.getDeclaredConstructor(classOf[Path], classOf[TaskAttemptContext]) + committer = ctor.newInstance(output, context).asInstanceOf[ParquetOutputCommitter] + } + committer + } + + // FileOutputFormat.getOutputPath takes JobConf in hadoop-1 but JobContext in hadoop-2 + private def getOutputPath(context: TaskAttemptContext): Path = { + context.getConfiguration().get("mapred.output.dir") match { + case null => null + case name => new Path(name) + } + } } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala index 97c0f439acf13..b504842053690 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala @@ -381,6 +381,28 @@ class ParquetIOSuiteBase extends QueryTest with ParquetTest { } } } + + test("SPARK-6352 DirectParquetOutputCommitter") { + // Write to a parquet file and let it fail. + // _temporary should be missing if direct output committer works. + try { + configuration.set("spark.sql.parquet.output.committer.class", + "org.apache.spark.sql.parquet.DirectParquetOutputCommitter") + sqlContext.udf.register("div0", (x: Int) => x / 0) + withTempPath { dir => + intercept[org.apache.spark.SparkException] { + sqlContext.sql("select div0(1)").saveAsParquetFile(dir.getCanonicalPath) + } + val path = new Path(dir.getCanonicalPath, "_temporary") + val fs = path.getFileSystem(configuration) + assert(!fs.exists(path)) + } + } + finally { + configuration.set("spark.sql.parquet.output.committer.class", + "parquet.hadoop.ParquetOutputCommitter") + } + } } class ParquetDataSourceOnIOSuite extends ParquetIOSuiteBase with BeforeAndAfterAll { From 7f3b3b7eb7d14767124a28ec0062c4d60d6c16fc Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Tue, 28 Apr 2015 07:48:34 -0400 Subject: [PATCH 092/110] [SPARK-7168] [BUILD] Update plugin versions in Maven build and centralize versions Update Maven build plugin versions and centralize plugin version management Author: Sean Owen Closes #5720 from srowen/SPARK-7168 and squashes the following commits: 98a8947 [Sean Owen] Make install, deploy plugin versions explicit 4ecf3b2 [Sean Owen] Update Maven build plugin versions and centralize plugin version management --- assembly/pom.xml | 1 - core/pom.xml | 1 - network/common/pom.xml | 1 - pom.xml | 44 ++++++++++++++++++++++++++++++------------ sql/hive/pom.xml | 1 - 5 files changed, 32 insertions(+), 16 deletions(-) diff --git a/assembly/pom.xml b/assembly/pom.xml index 20593e710dedb..2b4d0a990bf22 100644 --- a/assembly/pom.xml +++ b/assembly/pom.xml @@ -194,7 +194,6 @@ org.apache.maven.plugins maven-assembly-plugin - 2.4 dist diff --git a/core/pom.xml b/core/pom.xml index 5e89d548cd47f..459ef66712c36 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -478,7 +478,6 @@ org.codehaus.mojo exec-maven-plugin - 1.3.2 sparkr-pkg diff --git 
a/network/common/pom.xml b/network/common/pom.xml index 22c738bde6d42..0c3147761cfc5 100644 --- a/network/common/pom.xml +++ b/network/common/pom.xml @@ -95,7 +95,6 @@ org.apache.maven.plugins maven-jar-plugin - 2.2 test-jar-on-test-compile diff --git a/pom.xml b/pom.xml index 9fbce1d639d8b..928f5d0f5efad 100644 --- a/pom.xml +++ b/pom.xml @@ -1082,7 +1082,7 @@ org.apache.maven.plugins maven-enforcer-plugin - 1.3.1 + 1.4 enforce-versions @@ -1105,7 +1105,7 @@ org.codehaus.mojo build-helper-maven-plugin - 1.8 + 1.9.1 net.alchim31.maven @@ -1176,7 +1176,7 @@ org.apache.maven.plugins maven-compiler-plugin - 3.1 + 3.3 ${java.version} ${java.version} @@ -1189,7 +1189,7 @@ org.apache.maven.plugins maven-surefire-plugin - 2.18 + 2.18.1 @@ -1260,17 +1260,17 @@ org.apache.maven.plugins maven-jar-plugin - 2.4 + 2.6 org.apache.maven.plugins maven-antrun-plugin - 1.7 + 1.8 org.apache.maven.plugins maven-source-plugin - 2.2.1 + 2.4 true @@ -1287,7 +1287,7 @@ org.apache.maven.plugins maven-clean-plugin - 2.5 + 2.6.1 @@ -1305,7 +1305,27 @@ org.apache.maven.plugins maven-javadoc-plugin - 2.10.1 + 2.10.3 + + + org.codehaus.mojo + exec-maven-plugin + 1.4.0 + + + org.apache.maven.plugins + maven-assembly-plugin + 2.5.3 + + + org.apache.maven.plugins + maven-install-plugin + 2.5.2 + + + org.apache.maven.plugins + maven-deploy-plugin + 2.8.2 @@ -1315,7 +1335,7 @@ org.apache.maven.plugins maven-dependency-plugin - 2.9 + 2.10 test-compile @@ -1334,7 +1354,7 @@ org.codehaus.gmavenplus gmavenplus-plugin - 1.2 + 1.5 process-test-classes @@ -1359,7 +1379,7 @@ org.apache.maven.plugins maven-shade-plugin - 2.2 + 2.3 false diff --git a/sql/hive/pom.xml b/sql/hive/pom.xml index 21dce8d8a565a..e322340094e6f 100644 --- a/sql/hive/pom.xml +++ b/sql/hive/pom.xml @@ -183,7 +183,6 @@ org.apache.maven.plugins maven-dependency-plugin - 2.4 copy-dependencies From 75905c57cd57bc5b650ac5f486580ef8a229b260 Mon Sep 17 00:00:00 2001 From: Jim Carroll Date: Tue, 28 Apr 2015 07:51:02 -0400 Subject: [PATCH 093/110] [SPARK-7100] [MLLIB] Fix persisted RDD leak in GradientBoostTrees This fixes a leak of a persisted RDD where GradientBoostTrees can call persist but never unpersists. 
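The fix shown in the diff below follows a simple cache-guard pattern: persist the input only when the caller has not already cached it, remember that decision, and unpersist at the end only if this code was the one that persisted it. A minimal Scala sketch of the same idea follows; the helper name `withCachedInput` is hypothetical, and the actual patch records a boolean flag and unpersists at the end of the boosting loop rather than wrapping the body in try/finally.

    import org.apache.spark.rdd.RDD
    import org.apache.spark.storage.StorageLevel

    // Hypothetical helper illustrating the cache-guard pattern: cache the input only
    // when the caller has not, run the body, then release only what we cached ourselves.
    def withCachedInput[T, R](input: RDD[T])(body: RDD[T] => R): R = {
      val persistedHere = input.getStorageLevel == StorageLevel.NONE
      if (persistedHere) input.persist(StorageLevel.MEMORY_AND_DISK)
      try body(input) finally {
        if (persistedHere) input.unpersist()
      }
    }
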
Jira: https://issues.apache.org/jira/browse/SPARK-7100 Discussion: http://apache-spark-developers-list.1001551.n3.nabble.com/GradientBoostTrees-leaks-a-persisted-RDD-td11750.html Author: Jim Carroll Closes #5669 from jimfcarroll/gb-unpersist-fix and squashes the following commits: 45f4b03 [Jim Carroll] [SPARK-7100][MLLib] Fix persisted RDD leak in GradientBoostTrees --- .../apache/spark/mllib/tree/GradientBoostedTrees.scala | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala index 0e31c7ed58df8..deac390130128 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala @@ -177,9 +177,10 @@ object GradientBoostedTrees extends Logging { treeStrategy.assertValid() // Cache input - if (input.getStorageLevel == StorageLevel.NONE) { + val persistedInput = if (input.getStorageLevel == StorageLevel.NONE) { input.persist(StorageLevel.MEMORY_AND_DISK) - } + true + } else false timer.stop("init") @@ -265,6 +266,9 @@ object GradientBoostedTrees extends Logging { logInfo("Internal timing for DecisionTree:") logInfo(s"$timer") + + if (persistedInput) input.unpersist() + if (validate) { new GradientBoostedTreesModel( boostingStrategy.treeStrategy.algo, From 268c419f1586110b90e68f98cd000a782d18828c Mon Sep 17 00:00:00 2001 From: Masayoshi TSUZUKI Date: Tue, 28 Apr 2015 07:55:21 -0400 Subject: [PATCH 094/110] [SPARK-6435] spark-shell --jars option does not add all jars to classpath Modified to accept double-quotated args properly in spark-shell.cmd. Author: Masayoshi TSUZUKI Closes #5227 from tsudukim/feature/SPARK-6435-2 and squashes the following commits: ac55787 [Masayoshi TSUZUKI] removed unnecessary argument. 60789a7 [Masayoshi TSUZUKI] Merge branch 'master' of https://github.com/apache/spark into feature/SPARK-6435-2 1fee420 [Masayoshi TSUZUKI] fixed test code for escaping '='. 0d4dc41 [Masayoshi TSUZUKI] - escaped comman and semicolon in CommandBuilderUtils.java - added random string to the temporary filename - double-quotation followed by `cmd /c` did not worked properly - no need to escape `=` by `^` - if double-quoted string ended with `\` like classpath, the last `\` is parsed as the escape charactor and the closing `"` didn't work properly 2a332e5 [Masayoshi TSUZUKI] Merge branch 'master' into feature/SPARK-6435-2 04f4291 [Masayoshi TSUZUKI] [SPARK-6435] spark-shell --jars option does not add all jars to classpath --- bin/spark-class2.cmd | 5 ++++- .../org/apache/spark/launcher/CommandBuilderUtils.java | 9 ++++----- .../src/main/java/org/apache/spark/launcher/Main.java | 6 +----- .../apache/spark/launcher/CommandBuilderUtilsSuite.java | 5 ++++- 4 files changed, 13 insertions(+), 12 deletions(-) diff --git a/bin/spark-class2.cmd b/bin/spark-class2.cmd index 3d068dd3a2739..db09fa27e51a6 100644 --- a/bin/spark-class2.cmd +++ b/bin/spark-class2.cmd @@ -61,7 +61,10 @@ if not "x%JAVA_HOME%"=="x" set RUNNER=%JAVA_HOME%\bin\java rem The launcher library prints the command to be executed in a single line suitable for being rem executed by the batch interpreter. So read all the output of the launcher into a variable. 
-for /f "tokens=*" %%i in ('cmd /C ""%RUNNER%" -cp %LAUNCH_CLASSPATH% org.apache.spark.launcher.Main %*"') do ( +set LAUNCHER_OUTPUT=%temp%\spark-class-launcher-output-%RANDOM%.txt +"%RUNNER%" -cp %LAUNCH_CLASSPATH% org.apache.spark.launcher.Main %* > %LAUNCHER_OUTPUT% +for /f "tokens=*" %%i in (%LAUNCHER_OUTPUT%) do ( set SPARK_CMD=%%i ) +del %LAUNCHER_OUTPUT% %SPARK_CMD% diff --git a/launcher/src/main/java/org/apache/spark/launcher/CommandBuilderUtils.java b/launcher/src/main/java/org/apache/spark/launcher/CommandBuilderUtils.java index 8028e42ffb483..261402856ac5e 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/CommandBuilderUtils.java +++ b/launcher/src/main/java/org/apache/spark/launcher/CommandBuilderUtils.java @@ -244,7 +244,7 @@ static String quoteForBatchScript(String arg) { boolean needsQuotes = false; for (int i = 0; i < arg.length(); i++) { int c = arg.codePointAt(i); - if (Character.isWhitespace(c) || c == '"' || c == '=') { + if (Character.isWhitespace(c) || c == '"' || c == '=' || c == ',' || c == ';') { needsQuotes = true; break; } @@ -261,15 +261,14 @@ static String quoteForBatchScript(String arg) { quoted.append('"'); break; - case '=': - quoted.append('^'); - break; - default: break; } quoted.appendCodePoint(cp); } + if (arg.codePointAt(arg.length() - 1) == '\\') { + quoted.append("\\"); + } quoted.append("\""); return quoted.toString(); } diff --git a/launcher/src/main/java/org/apache/spark/launcher/Main.java b/launcher/src/main/java/org/apache/spark/launcher/Main.java index 206acfb514d86..929b29a49ed70 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/Main.java +++ b/launcher/src/main/java/org/apache/spark/launcher/Main.java @@ -101,12 +101,9 @@ public static void main(String[] argsArray) throws Exception { * The method quotes all arguments so that spaces are handled as expected. Quotes within arguments * are "double quoted" (which is batch for escaping a quote). This page has more details about * quoting and other batch script fun stuff: http://ss64.com/nt/syntax-esc.html - * - * The command is executed using "cmd /c" and formatted in single line, since that's the - * easiest way to consume this from a batch script (see spark-class2.cmd). */ private static String prepareWindowsCommand(List cmd, Map childEnv) { - StringBuilder cmdline = new StringBuilder("cmd /c \""); + StringBuilder cmdline = new StringBuilder(); for (Map.Entry e : childEnv.entrySet()) { cmdline.append(String.format("set %s=%s", e.getKey(), e.getValue())); cmdline.append(" && "); @@ -115,7 +112,6 @@ private static String prepareWindowsCommand(List cmd, Map Date: Tue, 28 Apr 2015 09:46:08 -0700 Subject: [PATCH 095/110] [SPARK-5253] [ML] LinearRegression with L1/L2 (ElasticNet) using OWLQN Author: DB Tsai Author: DB Tsai Closes #4259 from dbtsai/lir and squashes the following commits: a81c201 [DB Tsai] add import org.apache.spark.util.Utils back 9fc48ed [DB Tsai] rebase 2178b63 [DB Tsai] add comments 9988ca8 [DB Tsai] addressed feedback and fixed a bug. TODO: documentation and build another synthetic dataset which can catch the bug fixed in this commit. 
fcbaefe [DB Tsai] Refactoring 4eb078d [DB Tsai] first commit --- .../ml/param/shared/SharedParamsCodeGen.scala | 4 +- .../spark/ml/param/shared/sharedParams.scala | 34 ++ .../ml/regression/LinearRegression.scala | 304 ++++++++++++++++-- .../apache/spark/mllib/linalg/Vectors.scala | 8 +- .../spark/mllib/optimization/Gradient.scala | 6 +- .../spark/mllib/optimization/LBFGS.scala | 15 +- .../mllib/util/LinearDataGenerator.scala | 43 ++- .../ml/regression/LinearRegressionSuite.scala | 158 +++++++-- 8 files changed, 508 insertions(+), 64 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala index e88c48741e99f..3f7e8f5a0b22c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala @@ -46,7 +46,9 @@ private[shared] object SharedParamsCodeGen { ParamDesc[String]("outputCol", "output column name"), ParamDesc[Int]("checkpointInterval", "checkpoint interval"), ParamDesc[Boolean]("fitIntercept", "whether to fit an intercept term", Some("true")), - ParamDesc[Long]("seed", "random seed", Some("Utils.random.nextLong()"))) + ParamDesc[Long]("seed", "random seed", Some("Utils.random.nextLong()")), + ParamDesc[Double]("elasticNetParam", "the ElasticNet mixing parameter"), + ParamDesc[Double]("tol", "the convergence tolerance for iterative algorithms")) val code = genSharedParams(params) val file = "src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala" diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala index a860b8834cff9..7d2c76d6c62c8 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala @@ -276,4 +276,38 @@ trait HasSeed extends Params { /** @group getParam */ final def getSeed: Long = getOrDefault(seed) } + +/** + * :: DeveloperApi :: + * Trait for shared param elasticNetParam. + */ +@DeveloperApi +trait HasElasticNetParam extends Params { + + /** + * Param for the ElasticNet mixing parameter. + * @group param + */ + final val elasticNetParam: DoubleParam = new DoubleParam(this, "elasticNetParam", "the ElasticNet mixing parameter") + + /** @group getParam */ + final def getElasticNetParam: Double = getOrDefault(elasticNetParam) +} + +/** + * :: DeveloperApi :: + * Trait for shared param tol. + */ +@DeveloperApi +trait HasTol extends Params { + + /** + * Param for the convergence tolerance for iterative algorithms. 
+ * @group param + */ + final val tol: DoubleParam = new DoubleParam(this, "tol", "the convergence tolerance for iterative algorithms") + + /** @group getParam */ + final def getTol: Double = getOrDefault(tol) +} // scalastyle:on diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index 26ca7459c4fdf..f92c6816eb54c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -17,21 +17,29 @@ package org.apache.spark.ml.regression +import scala.collection.mutable + +import breeze.linalg.{norm => brzNorm, DenseVector => BDV} +import breeze.optimize.{LBFGS => BreezeLBFGS, OWLQN => BreezeOWLQN} +import breeze.optimize.{CachedDiffFunction, DiffFunction} + import org.apache.spark.annotation.AlphaComponent import org.apache.spark.ml.param.{Params, ParamMap} -import org.apache.spark.ml.param.shared._ -import org.apache.spark.mllib.linalg.{BLAS, Vector} -import org.apache.spark.mllib.regression.LinearRegressionWithSGD +import org.apache.spark.ml.param.shared.{HasElasticNetParam, HasMaxIter, HasRegParam, HasTol} +import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer +import org.apache.spark.mllib.linalg.{Vector, Vectors} +import org.apache.spark.mllib.linalg.BLAS._ +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.rdd.RDD import org.apache.spark.sql.DataFrame import org.apache.spark.storage.StorageLevel - +import org.apache.spark.util.StatCounter /** * Params for linear regression. */ private[regression] trait LinearRegressionParams extends RegressorParams - with HasRegParam with HasMaxIter - + with HasRegParam with HasElasticNetParam with HasMaxIter with HasTol /** * :: AlphaComponent :: @@ -42,34 +50,119 @@ private[regression] trait LinearRegressionParams extends RegressorParams class LinearRegression extends Regressor[Vector, LinearRegression, LinearRegressionModel] with LinearRegressionParams { - setDefault(regParam -> 0.1, maxIter -> 100) - - /** @group setParam */ + /** + * Set the regularization parameter. + * Default is 0.0. + * @group setParam + */ def setRegParam(value: Double): this.type = set(regParam, value) + setDefault(regParam -> 0.0) + + /** + * Set the ElasticNet mixing parameter. + * For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty. + * For 0 < alpha < 1, the penalty is a combination of L1 and L2. + * Default is 0.0 which is an L2 penalty. + * @group setParam + */ + def setElasticNetParam(value: Double): this.type = set(elasticNetParam, value) + setDefault(elasticNetParam -> 0.0) - /** @group setParam */ + /** + * Set the maximal number of iterations. + * Default is 100. + * @group setParam + */ def setMaxIter(value: Int): this.type = set(maxIter, value) + setDefault(maxIter -> 100) + + /** + * Set the convergence tolerance of iterations. + * Smaller value will lead to higher accuracy with the cost of more iterations. + * Default is 1E-6. + * @group setParam + */ + def setTol(value: Double): this.type = set(tol, value) + setDefault(tol -> 1E-6) override protected def train(dataset: DataFrame, paramMap: ParamMap): LinearRegressionModel = { - // Extract columns from data. If dataset is persisted, do not persist oldDataset. - val oldDataset = extractLabeledPoints(dataset, paramMap) + // Extract columns from data. If dataset is persisted, do not persist instances. 
+ val instances = extractLabeledPoints(dataset, paramMap).map { + case LabeledPoint(label: Double, features: Vector) => (label, features) + } val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE if (handlePersistence) { - oldDataset.persist(StorageLevel.MEMORY_AND_DISK) + instances.persist(StorageLevel.MEMORY_AND_DISK) + } + + val (summarizer, statCounter) = instances.treeAggregate( + (new MultivariateOnlineSummarizer, new StatCounter))( { + case ((summarizer: MultivariateOnlineSummarizer, statCounter: StatCounter), + (label: Double, features: Vector)) => + (summarizer.add(features), statCounter.merge(label)) + }, { + case ((summarizer1: MultivariateOnlineSummarizer, statCounter1: StatCounter), + (summarizer2: MultivariateOnlineSummarizer, statCounter2: StatCounter)) => + (summarizer1.merge(summarizer2), statCounter1.merge(statCounter2)) + }) + + val numFeatures = summarizer.mean.size + val yMean = statCounter.mean + val yStd = math.sqrt(statCounter.variance) + + val featuresMean = summarizer.mean.toArray + val featuresStd = summarizer.variance.toArray.map(math.sqrt) + + // Since we implicitly do the feature scaling when we compute the cost function + // to improve the convergence, the effective regParam will be changed. + val effectiveRegParam = paramMap(regParam) / yStd + val effectiveL1RegParam = paramMap(elasticNetParam) * effectiveRegParam + val effectiveL2RegParam = (1.0 - paramMap(elasticNetParam)) * effectiveRegParam + + val costFun = new LeastSquaresCostFun(instances, yStd, yMean, + featuresStd, featuresMean, effectiveL2RegParam) + + val optimizer = if (paramMap(elasticNetParam) == 0.0 || effectiveRegParam == 0.0) { + new BreezeLBFGS[BDV[Double]](paramMap(maxIter), 10, paramMap(tol)) + } else { + new BreezeOWLQN[Int, BDV[Double]](paramMap(maxIter), 10, effectiveL1RegParam, paramMap(tol)) + } + + val initialWeights = Vectors.zeros(numFeatures) + val states = + optimizer.iterations(new CachedDiffFunction(costFun), initialWeights.toBreeze.toDenseVector) + + var state = states.next() + val lossHistory = mutable.ArrayBuilder.make[Double] + + while (states.hasNext) { + lossHistory += state.value + state = states.next() + } + lossHistory += state.value + + // TODO: Based on the sparsity of weights, we may convert the weights to the sparse vector. + // The weights are trained in the scaled space; we're converting them back to + // the original space. + val weights = { + val rawWeights = state.x.toArray.clone() + var i = 0 + while (i < rawWeights.length) { + rawWeights(i) *= { if (featuresStd(i) != 0.0) yStd / featuresStd(i) else 0.0 } + i += 1 + } + Vectors.dense(rawWeights) } - // Train model - val lr = new LinearRegressionWithSGD() - lr.optimizer - .setRegParam(paramMap(regParam)) - .setNumIterations(paramMap(maxIter)) - val model = lr.run(oldDataset) - val lrm = new LinearRegressionModel(this, paramMap, model.weights, model.intercept) + // The intercept in R's GLMNET is computed using closed form after the coefficients are + // converged. See the following discussion for detail. 
+ // http://stats.stackexchange.com/questions/13617/how-is-the-intercept-computed-in-glmnet + val intercept = yMean - dot(weights, Vectors.dense(featuresMean)) if (handlePersistence) { - oldDataset.unpersist() + instances.unpersist() } - lrm + new LinearRegressionModel(this, paramMap, weights, intercept) } } @@ -88,7 +181,7 @@ class LinearRegressionModel private[ml] ( with LinearRegressionParams { override protected def predict(features: Vector): Double = { - BLAS.dot(features, weights) + intercept + dot(features, weights) + intercept } override protected def copy(): LinearRegressionModel = { @@ -97,3 +190,168 @@ class LinearRegressionModel private[ml] ( m } } + +/** + * LeastSquaresAggregator computes the gradient and loss for a Least-squared loss function, + * as used in linear regression for samples in sparse or dense vector in a online fashion. + * + * Two LeastSquaresAggregator can be merged together to have a summary of loss and gradient of + * the corresponding joint dataset. + * + + * * Compute gradient and loss for a Least-squared loss function, as used in linear regression. + * This is correct for the averaged least squares loss function (mean squared error) + * L = 1/2n ||A weights-y||^2 + * See also the documentation for the precise formulation. + * + * @param weights weights/coefficients corresponding to features + * + * @param updater Updater to be used to update weights after every iteration. + */ +private class LeastSquaresAggregator( + weights: Vector, + labelStd: Double, + labelMean: Double, + featuresStd: Array[Double], + featuresMean: Array[Double]) extends Serializable { + + private var totalCnt: Long = 0 + private var lossSum = 0.0 + private var diffSum = 0.0 + + private val (effectiveWeightsArray: Array[Double], offset: Double, dim: Int) = { + val weightsArray = weights.toArray.clone() + var sum = 0.0 + var i = 0 + while (i < weightsArray.length) { + if (featuresStd(i) != 0.0) { + weightsArray(i) /= featuresStd(i) + sum += weightsArray(i) * featuresMean(i) + } else { + weightsArray(i) = 0.0 + } + i += 1 + } + (weightsArray, -sum + labelMean / labelStd, weightsArray.length) + } + private val effectiveWeightsVector = Vectors.dense(effectiveWeightsArray) + + private val gradientSumArray: Array[Double] = Array.ofDim[Double](dim) + + /** + * Add a new training data to this LeastSquaresAggregator, and update the loss and gradient + * of the objective function. + * + * @param label The label for this data point. + * @param data The features for one data point in dense/sparse vector format to be added + * into this aggregator. + * @return This LeastSquaresAggregator object. + */ + def add(label: Double, data: Vector): this.type = { + require(dim == data.size, s"Dimensions mismatch when adding new sample." + + s" Expecting $dim but got ${data.size}.") + + val diff = dot(data, effectiveWeightsVector) - label / labelStd + offset + + if (diff != 0) { + val localGradientSumArray = gradientSumArray + data.foreachActive { (index, value) => + if (featuresStd(index) != 0.0 && value != 0.0) { + localGradientSumArray(index) += diff * value / featuresStd(index) + } + } + lossSum += diff * diff / 2.0 + diffSum += diff + } + + totalCnt += 1 + this + } + + /** + * Merge another LeastSquaresAggregator, and update the loss and gradient + * of the objective function. + * (Note that it's in place merging; as a result, `this` object will be modified.) + * + * @param other The other LeastSquaresAggregator to be merged. + * @return This LeastSquaresAggregator object. 
+ */ + def merge(other: LeastSquaresAggregator): this.type = { + require(dim == other.dim, s"Dimensions mismatch when merging with another " + + s"LeastSquaresAggregator. Expecting $dim but got ${other.dim}.") + + if (other.totalCnt != 0) { + totalCnt += other.totalCnt + lossSum += other.lossSum + diffSum += other.diffSum + + var i = 0 + val localThisGradientSumArray = this.gradientSumArray + val localOtherGradientSumArray = other.gradientSumArray + while (i < dim) { + localThisGradientSumArray(i) += localOtherGradientSumArray(i) + i += 1 + } + } + this + } + + def count: Long = totalCnt + + def loss: Double = lossSum / totalCnt + + def gradient: Vector = { + val result = Vectors.dense(gradientSumArray.clone()) + + val correction = { + val temp = effectiveWeightsArray.clone() + var i = 0 + while (i < temp.length) { + temp(i) *= featuresMean(i) + i += 1 + } + Vectors.dense(temp) + } + + axpy(-diffSum, correction, result) + scal(1.0 / totalCnt, result) + result + } +} + +/** + * LeastSquaresCostFun implements Breeze's DiffFunction[T] for Least Squares cost. + * It returns the loss and gradient with L2 regularization at a particular point (weights). + * It's used in Breeze's convex optimization routines. + */ +private class LeastSquaresCostFun( + data: RDD[(Double, Vector)], + labelStd: Double, + labelMean: Double, + featuresStd: Array[Double], + featuresMean: Array[Double], + effectiveL2regParam: Double) extends DiffFunction[BDV[Double]] { + + override def calculate(weights: BDV[Double]): (Double, BDV[Double]) = { + val w = Vectors.fromBreeze(weights) + + val leastSquaresAggregator = data.treeAggregate(new LeastSquaresAggregator(w, labelStd, + labelMean, featuresStd, featuresMean))( + seqOp = (c, v) => (c, v) match { + case (aggregator, (label, features)) => aggregator.add(label, features) + }, + combOp = (c1, c2) => (c1, c2) match { + case (aggregator1, aggregator2) => aggregator1.merge(aggregator2) + }) + + // regVal is the sum of weight squares for L2 regularization + val norm = brzNorm(weights, 2.0) + val regVal = 0.5 * effectiveL2regParam * norm * norm + + val loss = leastSquaresAggregator.loss + regVal + val gradient = leastSquaresAggregator.gradient + axpy(effectiveL2regParam, w, gradient) + + (loss, gradient.toBreeze.asInstanceOf[BDV[Double]]) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala index 166c00cff634d..af0cfe22ca10d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala @@ -85,7 +85,7 @@ sealed trait Vector extends Serializable { /** * Converts the instance to a breeze vector. */ - private[mllib] def toBreeze: BV[Double] + private[spark] def toBreeze: BV[Double] /** * Gets the value of the ith element. @@ -284,7 +284,7 @@ object Vectors { /** * Creates a vector instance from a breeze vector. 
*/ - private[mllib] def fromBreeze(breezeVector: BV[Double]): Vector = { + private[spark] def fromBreeze(breezeVector: BV[Double]): Vector = { breezeVector match { case v: BDV[Double] => if (v.offset == 0 && v.stride == 1 && v.length == v.data.length) { @@ -483,7 +483,7 @@ class DenseVector(val values: Array[Double]) extends Vector { override def toArray: Array[Double] = values - private[mllib] override def toBreeze: BV[Double] = new BDV[Double](values) + private[spark] override def toBreeze: BV[Double] = new BDV[Double](values) override def apply(i: Int): Double = values(i) @@ -543,7 +543,7 @@ class SparseVector( new SparseVector(size, indices.clone(), values.clone()) } - private[mllib] override def toBreeze: BV[Double] = new BSV[Double](indices, values, size) + private[spark] override def toBreeze: BV[Double] = new BSV[Double](indices, values, size) private[spark] override def foreachActive(f: (Int, Double) => Unit) = { var i = 0 diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/Gradient.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/Gradient.scala index 8bfa0d2b64995..240baeb5a158b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/Gradient.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/Gradient.scala @@ -37,7 +37,11 @@ abstract class Gradient extends Serializable { * * @return (gradient: Vector, loss: Double) */ - def compute(data: Vector, label: Double, weights: Vector): (Vector, Double) + def compute(data: Vector, label: Double, weights: Vector): (Vector, Double) = { + val gradient = Vectors.zeros(weights.size) + val loss = compute(data, label, weights, gradient) + (gradient, loss) + } /** * Compute the gradient and loss given the features of a single data point, diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala index ef6eccd90711a..efedc112d380e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala @@ -17,6 +17,7 @@ package org.apache.spark.mllib.optimization +import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import breeze.linalg.{DenseVector => BDV} @@ -164,7 +165,7 @@ object LBFGS extends Logging { regParam: Double, initialWeights: Vector): (Vector, Array[Double]) = { - val lossHistory = new ArrayBuffer[Double](maxNumIterations) + val lossHistory = mutable.ArrayBuilder.make[Double] val numExamples = data.count() @@ -181,17 +182,19 @@ object LBFGS extends Logging { * and regVal is the regularization value computed in the previous iteration as well. */ var state = states.next() - while(states.hasNext) { - lossHistory.append(state.value) + while (states.hasNext) { + lossHistory += state.value state = states.next() } - lossHistory.append(state.value) + lossHistory += state.value val weights = Vectors.fromBreeze(state.x) + val lossHistoryArray = lossHistory.result() + logInfo("LBFGS.runLBFGS finished. 
Last 10 losses %s".format( - lossHistory.takeRight(10).mkString(", "))) + lossHistoryArray.takeRight(10).mkString(", "))) - (weights, lossHistory.toArray) + (weights, lossHistoryArray) } /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/LinearDataGenerator.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/LinearDataGenerator.scala index c9d33787b0bb5..d7bb943e84f53 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/LinearDataGenerator.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/LinearDataGenerator.scala @@ -56,6 +56,10 @@ object LinearDataGenerator { } /** + * For compatibility, the generated data without specifying the mean and variance + * will have zero mean and variance of (1.0/3.0) since the original output range is + * [-1, 1] with uniform distribution, and the variance of uniform distribution + * is (b - a)^2^ / 12 which will be (1.0/3.0) * * @param intercept Data intercept * @param weights Weights to be applied. @@ -70,10 +74,47 @@ object LinearDataGenerator { nPoints: Int, seed: Int, eps: Double = 0.1): Seq[LabeledPoint] = { + generateLinearInput(intercept, weights, + Array.fill[Double](weights.size)(0.0), + Array.fill[Double](weights.size)(1.0 / 3.0), + nPoints, seed, eps)} + + /** + * + * @param intercept Data intercept + * @param weights Weights to be applied. + * @param xMean the mean of the generated features. Lots of time, if the features are not properly + * standardized, the algorithm with poor implementation will have difficulty + * to converge. + * @param xVariance the variance of the generated features. + * @param nPoints Number of points in sample. + * @param seed Random seed + * @param eps Epsilon scaling factor. + * @return Seq of input. + */ + def generateLinearInput( + intercept: Double, + weights: Array[Double], + xMean: Array[Double], + xVariance: Array[Double], + nPoints: Int, + seed: Int, + eps: Double): Seq[LabeledPoint] = { val rnd = new Random(seed) val x = Array.fill[Array[Double]](nPoints)( - Array.fill[Double](weights.length)(2 * rnd.nextDouble - 1.0)) + Array.fill[Double](weights.length)(rnd.nextDouble)) + + x.map(vector => { + // This doesn't work if `vector` is a sparse vector. 
+ val vectorArray = vector.toArray + var i = 0 + while (i < vectorArray.size) { + vectorArray(i) = (vectorArray(i) - 0.5) * math.sqrt(12.0 * xVariance(i)) + xMean(i) + i += 1 + } + }) + val y = x.map { xi => blas.ddot(weights.length, xi, 1, weights, 1) + intercept + eps * rnd.nextGaussian() } diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala index bbb44c3e2dfc2..80323ef5201a6 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala @@ -19,47 +19,149 @@ package org.apache.spark.ml.regression import org.scalatest.FunSuite -import org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInput -import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.sql.{DataFrame, SQLContext} +import org.apache.spark.mllib.linalg.DenseVector +import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext} +import org.apache.spark.mllib.util.TestingUtils._ +import org.apache.spark.sql.{Row, SQLContext, DataFrame} class LinearRegressionSuite extends FunSuite with MLlibTestSparkContext { @transient var sqlContext: SQLContext = _ @transient var dataset: DataFrame = _ + /** + * In `LinearRegressionSuite`, we will make sure that the model trained by SparkML + * is the same as the one trained by R's glmnet package. The following instruction + * describes how to reproduce the data in R. + * + * import org.apache.spark.mllib.util.LinearDataGenerator + * val data = + * sc.parallelize(LinearDataGenerator.generateLinearInput(6.3, Array(4.7, 7.2), 10000, 42), 2) + * data.map(x=> x.label + ", " + x.features(0) + ", " + x.features(1)).saveAsTextFile("path") + */ override def beforeAll(): Unit = { super.beforeAll() sqlContext = new SQLContext(sc) dataset = sqlContext.createDataFrame( - sc.parallelize(generateLogisticInput(1.0, 1.0, nPoints = 100, seed = 42), 2)) + sc.parallelize(LinearDataGenerator.generateLinearInput( + 6.3, Array(4.7, 7.2), Array(0.9, -1.3), Array(0.7, 1.2), 10000, 42, 0.1), 2)) } - test("linear regression: default params") { - val lr = new LinearRegression - assert(lr.getLabelCol == "label") - val model = lr.fit(dataset) - model.transform(dataset) - .select("label", "prediction") - .collect() - // Check defaults - assert(model.getFeaturesCol == "features") - assert(model.getPredictionCol == "prediction") + test("linear regression with intercept without regularization") { + val trainer = new LinearRegression + val model = trainer.fit(dataset) + + /** + * Using the following R code to load the data and train the model using glmnet package. + * + * library("glmnet") + * data <- read.csv("path", header=FALSE, stringsAsFactors=FALSE) + * features <- as.matrix(data.frame(as.numeric(data$V2), as.numeric(data$V3))) + * label <- as.numeric(data$V1) + * weights <- coef(glmnet(features, label, family="gaussian", alpha = 0, lambda = 0)) + * > weights + * 3 x 1 sparse Matrix of class "dgCMatrix" + * s0 + * (Intercept) 6.300528 + * as.numeric.data.V2. 4.701024 + * as.numeric.data.V3. 
7.198257 + */ + val interceptR = 6.298698 + val weightsR = Array(4.700706, 7.199082) + + assert(model.intercept ~== interceptR relTol 1E-3) + assert(model.weights(0) ~== weightsR(0) relTol 1E-3) + assert(model.weights(1) ~== weightsR(1) relTol 1E-3) + + model.transform(dataset).select("features", "prediction").collect().foreach { + case Row(features: DenseVector, prediction1: Double) => + val prediction2 = + features(0) * model.weights(0) + features(1) * model.weights(1) + model.intercept + assert(prediction1 ~== prediction2 relTol 1E-5) + } + } + + test("linear regression with intercept with L1 regularization") { + val trainer = (new LinearRegression).setElasticNetParam(1.0).setRegParam(0.57) + val model = trainer.fit(dataset) + + /** + * weights <- coef(glmnet(features, label, family="gaussian", alpha = 1.0, lambda = 0.57)) + * > weights + * 3 x 1 sparse Matrix of class "dgCMatrix" + * s0 + * (Intercept) 6.311546 + * as.numeric.data.V2. 2.123522 + * as.numeric.data.V3. 4.605651 + */ + val interceptR = 6.243000 + val weightsR = Array(4.024821, 6.679841) + + assert(model.intercept ~== interceptR relTol 1E-3) + assert(model.weights(0) ~== weightsR(0) relTol 1E-3) + assert(model.weights(1) ~== weightsR(1) relTol 1E-3) + + model.transform(dataset).select("features", "prediction").collect().foreach { + case Row(features: DenseVector, prediction1: Double) => + val prediction2 = + features(0) * model.weights(0) + features(1) * model.weights(1) + model.intercept + assert(prediction1 ~== prediction2 relTol 1E-5) + } } - test("linear regression with setters") { - // Set params, train, and check as many as we can. - val lr = new LinearRegression() - .setMaxIter(10) - .setRegParam(1.0) - val model = lr.fit(dataset) - assert(model.fittingParamMap.get(lr.maxIter).get === 10) - assert(model.fittingParamMap.get(lr.regParam).get === 1.0) - - // Call fit() with new params, and check as many as we can. - val model2 = lr.fit(dataset, lr.maxIter -> 5, lr.regParam -> 0.1, lr.predictionCol -> "thePred") - assert(model2.fittingParamMap.get(lr.maxIter).get === 5) - assert(model2.fittingParamMap.get(lr.regParam).get === 0.1) - assert(model2.getPredictionCol == "thePred") + test("linear regression with intercept with L2 regularization") { + val trainer = (new LinearRegression).setElasticNetParam(0.0).setRegParam(2.3) + val model = trainer.fit(dataset) + + /** + * weights <- coef(glmnet(features, label, family="gaussian", alpha = 0.0, lambda = 2.3)) + * > weights + * 3 x 1 sparse Matrix of class "dgCMatrix" + * s0 + * (Intercept) 6.328062 + * as.numeric.data.V2. 3.222034 + * as.numeric.data.V3. 
4.926260 + */ + val interceptR = 5.269376 + val weightsR = Array(3.736216, 5.712356) + + assert(model.intercept ~== interceptR relTol 1E-3) + assert(model.weights(0) ~== weightsR(0) relTol 1E-3) + assert(model.weights(1) ~== weightsR(1) relTol 1E-3) + + model.transform(dataset).select("features", "prediction").collect().foreach { + case Row(features: DenseVector, prediction1: Double) => + val prediction2 = + features(0) * model.weights(0) + features(1) * model.weights(1) + model.intercept + assert(prediction1 ~== prediction2 relTol 1E-5) + } + } + + test("linear regression with intercept with ElasticNet regularization") { + val trainer = (new LinearRegression).setElasticNetParam(0.3).setRegParam(1.6) + val model = trainer.fit(dataset) + + /** + * weights <- coef(glmnet(features, label, family="gaussian", alpha = 0.3, lambda = 1.6)) + * > weights + * 3 x 1 sparse Matrix of class "dgCMatrix" + * s0 + * (Intercept) 6.324108 + * as.numeric.data.V2. 3.168435 + * as.numeric.data.V3. 5.200403 + */ + val interceptR = 5.696056 + val weightsR = Array(3.670489, 6.001122) + + assert(model.intercept ~== interceptR relTol 1E-3) + assert(model.weights(0) ~== weightsR(0) relTol 1E-3) + assert(model.weights(1) ~== weightsR(1) relTol 1E-3) + + model.transform(dataset).select("features", "prediction").collect().foreach { + case Row(features: DenseVector, prediction1: Double) => + val prediction2 = + features(0) * model.weights(0) + features(1) * model.weights(1) + model.intercept + assert(prediction1 ~== prediction2 relTol 1E-5) + } } } From b14cd2364932e504695bcc49486ffb4518fdf33d Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Tue, 28 Apr 2015 09:59:36 -0700 Subject: [PATCH 096/110] [SPARK-7140] [MLLIB] only scan the first 16 entries in Vector.hashCode The Python SerDe calls `Object.hashCode`, which is very expensive for Vectors. It is not necessary to scan the whole vector, especially for large ones. In this PR, we only scan the first 16 nonzeros. srowen Author: Xiangrui Meng Closes #5697 from mengxr/SPARK-7140 and squashes the following commits: 2abc86d [Xiangrui Meng] typo 8fb7d74 [Xiangrui Meng] update impl 1ebad60 [Xiangrui Meng] only scan the first 16 nonzeros in Vector.hashCode --- .../apache/spark/mllib/linalg/Vectors.scala | 88 ++++++++++++++----- 1 file changed, 67 insertions(+), 21 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala index af0cfe22ca10d..34833e90d4af0 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala @@ -52,7 +52,7 @@ sealed trait Vector extends Serializable { override def equals(other: Any): Boolean = { other match { - case v2: Vector => { + case v2: Vector => if (this.size != v2.size) return false (this, v2) match { case (s1: SparseVector, s2: SparseVector) => @@ -63,20 +63,28 @@ sealed trait Vector extends Serializable { Vectors.equals(0 until d1.size, d1.values, s1.indices, s1.values) case (_, _) => util.Arrays.equals(this.toArray, v2.toArray) } - } case _ => false } } + /** + * Returns a hash code value for the vector. The hash code is based on its size and its nonzeros + * in the first 16 entries, using a hash algorithm similar to [[java.util.Arrays.hashCode]]. 
+ */ override def hashCode(): Int = { - var result: Int = size + 31 - this.foreachActive { case (index, value) => - // ignore explict 0 for comparison between sparse and dense - if (value != 0) { - result = 31 * result + index - // refer to {@link java.util.Arrays.equals} for hash algorithm - val bits = java.lang.Double.doubleToLongBits(value) - result = 31 * result + (bits ^ (bits >>> 32)).toInt + // This is a reference implementation. It calls return in foreachActive, which is slow. + // Subclasses should override it with optimized implementation. + var result: Int = 31 + size + this.foreachActive { (index, value) => + if (index < 16) { + // ignore explicit 0 for comparison between sparse and dense + if (value != 0) { + result = 31 * result + index + val bits = java.lang.Double.doubleToLongBits(value) + result = 31 * result + (bits ^ (bits >>> 32)).toInt + } + } else { + return result } } result @@ -317,7 +325,7 @@ object Vectors { case SparseVector(n, ids, vs) => vs case v => throw new IllegalArgumentException("Do not support vector type " + v.getClass) } - val size = values.size + val size = values.length if (p == 1) { var sum = 0.0 @@ -371,8 +379,8 @@ object Vectors { val v1Indices = v1.indices val v2Values = v2.values val v2Indices = v2.indices - val nnzv1 = v1Indices.size - val nnzv2 = v2Indices.size + val nnzv1 = v1Indices.length + val nnzv2 = v2Indices.length var kv1 = 0 var kv2 = 0 @@ -401,7 +409,7 @@ object Vectors { case (DenseVector(vv1), DenseVector(vv2)) => var kv = 0 - val sz = vv1.size + val sz = vv1.length while (kv < sz) { val score = vv1(kv) - vv2(kv) squaredDistance += score * score @@ -422,7 +430,7 @@ object Vectors { var kv2 = 0 val indices = v1.indices var squaredDistance = 0.0 - val nnzv1 = indices.size + val nnzv1 = indices.length val nnzv2 = v2.size var iv1 = if (nnzv1 > 0) indices(kv1) else -1 @@ -451,8 +459,8 @@ object Vectors { v1Values: Array[Double], v2Indices: IndexedSeq[Int], v2Values: Array[Double]): Boolean = { - val v1Size = v1Values.size - val v2Size = v2Values.size + val v1Size = v1Values.length + val v2Size = v2Values.length var k1 = 0 var k2 = 0 var allEqual = true @@ -493,7 +501,7 @@ class DenseVector(val values: Array[Double]) extends Vector { private[spark] override def foreachActive(f: (Int, Double) => Unit) = { var i = 0 - val localValuesSize = values.size + val localValuesSize = values.length val localValues = values while (i < localValuesSize) { @@ -501,6 +509,22 @@ class DenseVector(val values: Array[Double]) extends Vector { i += 1 } } + + override def hashCode(): Int = { + var result: Int = 31 + size + var i = 0 + val end = math.min(values.length, 16) + while (i < end) { + val v = values(i) + if (v != 0.0) { + result = 31 * result + i + val bits = java.lang.Double.doubleToLongBits(values(i)) + result = 31 * result + (bits ^ (bits >>> 32)).toInt + } + i += 1 + } + result + } } object DenseVector { @@ -522,8 +546,8 @@ class SparseVector( val values: Array[Double]) extends Vector { require(indices.length == values.length, "Sparse vectors require that the dimension of the" + - s" indices match the dimension of the values. You provided ${indices.size} indices and " + - s" ${values.size} values.") + s" indices match the dimension of the values. 
You provided ${indices.length} indices and " + + s" ${values.length} values.") override def toString: String = s"($size,${indices.mkString("[", ",", "]")},${values.mkString("[", ",", "]")})" @@ -547,7 +571,7 @@ class SparseVector( private[spark] override def foreachActive(f: (Int, Double) => Unit) = { var i = 0 - val localValuesSize = values.size + val localValuesSize = values.length val localIndices = indices val localValues = values @@ -556,6 +580,28 @@ class SparseVector( i += 1 } } + + override def hashCode(): Int = { + var result: Int = 31 + size + val end = values.length + var continue = true + var k = 0 + while ((k < end) & continue) { + val i = indices(k) + if (i < 16) { + val v = values(k) + if (v != 0.0) { + result = 31 * result + i + val bits = java.lang.Double.doubleToLongBits(v) + result = 31 * result + (bits ^ (bits >>> 32)).toInt + } + } else { + continue = false + } + k += 1 + } + result + } } object SparseVector { From 52ccf1d3739694826915cdf01642bab02958eb78 Mon Sep 17 00:00:00 2001 From: "Zhang, Liye" Date: Tue, 28 Apr 2015 10:24:00 -0700 Subject: [PATCH 097/110] [Core][test][minor] replace try finally block with tryWithSafeFinally Author: Zhang, Liye Closes #5739 from liyezhang556520/trySafeFinally and squashes the following commits: 55683e5 [Zhang, Liye] replace try finally block with tryWithSafeFinally --- .../apache/spark/deploy/history/FsHistoryProviderSuite.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala index fcae603c7d18e..9e367a0d9af0d 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala @@ -224,9 +224,9 @@ class FsHistoryProviderSuite extends FunSuite with BeforeAndAfter with Matchers EventLoggingListener.initEventLog(new FileOutputStream(file)) } val writer = new OutputStreamWriter(bstream, "UTF-8") - try { + Utils.tryWithSafeFinally { events.foreach(e => writer.write(compact(render(JsonProtocol.sparkEventToJson(e))) + "\n")) - } finally { + } { writer.close() } } From 8aab94d8984e9d12194dbda47b2e7d9dbc036889 Mon Sep 17 00:00:00 2001 From: Iulian Dragos Date: Tue, 28 Apr 2015 12:08:18 -0700 Subject: [PATCH 098/110] [SPARK-4286] Add an external shuffle service that can be run as a daemon. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This allows Mesos deployments to use the shuffle service (and implicitly dynamic allocation). It does so by adding a new "main" class and two corresponding scripts in `sbin`: - `sbin/start-shuffle-service.sh` - `sbin/stop-shuffle-service.sh` Specific options can be passed in `SPARK_SHUFFLE_OPTS`. This is picking up work from #3861 /cc tnachen Author: Iulian Dragos Closes #4990 from dragos/feature/external-shuffle-service and squashes the following commits: 6c2b148 [Iulian Dragos] Import order and wrong name fixup. 07804ad [Iulian Dragos] Moved ExternalShuffleService to the `deploy` package + other minor tweaks. 4dc1f91 [Iulian Dragos] Reviewer’s comments: 8145429 [Iulian Dragos] Add an external shuffle service that can be run as a daemon. 
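A minimal client-side sketch of how an application pairs with the new daemon (assuming a service already started on each node via sbin/start-shuffle-service.sh; spark.dynamicAllocation.enabled is the usual dynamic-allocation switch and is not introduced by this patch):

    import org.apache.spark.SparkConf

    // Configuration for an application that reads shuffle files through the external service;
    // the daemon itself is started separately with sbin/start-shuffle-service.sh.
    val conf = new SparkConf()
      .setAppName("uses-external-shuffle-service")
      .set("spark.shuffle.service.enabled", "true")    // fetch shuffle blocks via the daemon
      .set("spark.dynamicAllocation.enabled", "true")  // the feature this service enables on Mesos
    // ... pass `conf` to SparkContext as usual.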
--- conf/spark-env.sh.template | 3 +- ...ice.scala => ExternalShuffleService.scala} | 59 ++++++++++++++++--- .../apache/spark/deploy/worker/Worker.scala | 13 ++-- docs/job-scheduling.md | 2 +- .../launcher/SparkClassCommandBuilder.java | 4 ++ sbin/start-shuffle-service.sh | 33 +++++++++++ sbin/stop-shuffle-service.sh | 25 ++++++++ 7 files changed, 124 insertions(+), 15 deletions(-) rename core/src/main/scala/org/apache/spark/deploy/{worker/StandaloneWorkerShuffleService.scala => ExternalShuffleService.scala} (59%) create mode 100755 sbin/start-shuffle-service.sh create mode 100755 sbin/stop-shuffle-service.sh diff --git a/conf/spark-env.sh.template b/conf/spark-env.sh.template index 67f81d33361e1..43c4288912b18 100755 --- a/conf/spark-env.sh.template +++ b/conf/spark-env.sh.template @@ -3,7 +3,7 @@ # This file is sourced when running various Spark programs. # Copy it as spark-env.sh and edit that to configure Spark for your site. -# Options read when launching programs locally with +# Options read when launching programs locally with # ./bin/run-example or ./bin/spark-submit # - HADOOP_CONF_DIR, to point Spark towards Hadoop configuration files # - SPARK_LOCAL_IP, to set the IP address Spark binds to on this node @@ -39,6 +39,7 @@ # - SPARK_WORKER_DIR, to set the working directory of worker processes # - SPARK_WORKER_OPTS, to set config properties only for the worker (e.g. "-Dx=y") # - SPARK_HISTORY_OPTS, to set config properties only for the history server (e.g. "-Dx=y") +# - SPARK_SHUFFLE_OPTS, to set config properties only for the external shuffle service (e.g. "-Dx=y") # - SPARK_DAEMON_JAVA_OPTS, to set config properties for all daemons (e.g. "-Dx=y") # - SPARK_PUBLIC_DNS, to set the public dns name of the master or workers diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/StandaloneWorkerShuffleService.scala b/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala similarity index 59% rename from core/src/main/scala/org/apache/spark/deploy/worker/StandaloneWorkerShuffleService.scala rename to core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala index b9798963bab0a..cd16f992a3c0a 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/StandaloneWorkerShuffleService.scala +++ b/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala @@ -15,7 +15,9 @@ * limitations under the License. */ -package org.apache.spark.deploy.worker +package org.apache.spark.deploy + +import java.util.concurrent.CountDownLatch import org.apache.spark.{Logging, SparkConf, SecurityManager} import org.apache.spark.network.TransportContext @@ -23,6 +25,7 @@ import org.apache.spark.network.netty.SparkTransportConf import org.apache.spark.network.sasl.SaslRpcHandler import org.apache.spark.network.server.TransportServer import org.apache.spark.network.shuffle.ExternalShuffleBlockHandler +import org.apache.spark.util.Utils /** * Provides a server from which Executors can read shuffle files (rather than reading directly from @@ -31,8 +34,8 @@ import org.apache.spark.network.shuffle.ExternalShuffleBlockHandler * * Optionally requires SASL authentication in order to read. See [[SecurityManager]]. 
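 * After this change the service can run in two ways: embedded in the standalone Worker
 * (via startIfEnabled, guarded by spark.shuffle.service.enabled) or as a standalone daemon
 * through the companion object's main(), which is what sbin/start-shuffle-service.sh launches.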
*/ -private[worker] -class StandaloneWorkerShuffleService(sparkConf: SparkConf, securityManager: SecurityManager) +private[deploy] +class ExternalShuffleService(sparkConf: SparkConf, securityManager: SecurityManager) extends Logging { private val enabled = sparkConf.getBoolean("spark.shuffle.service.enabled", false) @@ -51,16 +54,58 @@ class StandaloneWorkerShuffleService(sparkConf: SparkConf, securityManager: Secu /** Starts the external shuffle service if the user has configured us to. */ def startIfEnabled() { if (enabled) { - require(server == null, "Shuffle server already started") - logInfo(s"Starting shuffle service on port $port with useSasl = $useSasl") - server = transportContext.createServer(port) + start() } } + /** Start the external shuffle service */ + def start() { + require(server == null, "Shuffle server already started") + logInfo(s"Starting shuffle service on port $port with useSasl = $useSasl") + server = transportContext.createServer(port) + } + def stop() { - if (enabled && server != null) { + if (server != null) { server.close() server = null } } } + +/** + * A main class for running the external shuffle service. + */ +object ExternalShuffleService extends Logging { + @volatile + private var server: ExternalShuffleService = _ + + private val barrier = new CountDownLatch(1) + + def main(args: Array[String]): Unit = { + val sparkConf = new SparkConf + Utils.loadDefaultSparkProperties(sparkConf) + val securityManager = new SecurityManager(sparkConf) + + // we override this value since this service is started from the command line + // and we assume the user really wants it to be running + sparkConf.set("spark.shuffle.service.enabled", "true") + server = new ExternalShuffleService(sparkConf, securityManager) + server.start() + + installShutdownHook() + + // keep running until the process is terminated + barrier.await() + } + + private def installShutdownHook(): Unit = { + Runtime.getRuntime.addShutdownHook(new Thread("External Shuffle Service shutdown thread") { + override def run() { + logInfo("Shutting down shuffle service.") + server.stop() + barrier.countDown() + } + }) + } +} diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala index 3ee2eb69e8a4e..8f3cc54051048 100755 --- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala @@ -34,6 +34,7 @@ import akka.remote.{DisassociatedEvent, RemotingLifecycleEvent} import org.apache.spark.{Logging, SecurityManager, SparkConf} import org.apache.spark.deploy.{Command, ExecutorDescription, ExecutorState} import org.apache.spark.deploy.DeployMessages._ +import org.apache.spark.deploy.ExternalShuffleService import org.apache.spark.deploy.master.{DriverState, Master} import org.apache.spark.deploy.worker.ui.WorkerWebUI import org.apache.spark.metrics.MetricsSystem @@ -61,7 +62,7 @@ private[worker] class Worker( assert (port > 0) // For worker and executor IDs - private def createDateFormat = new SimpleDateFormat("yyyyMMddHHmmss") + private def createDateFormat = new SimpleDateFormat("yyyyMMddHHmmss") // Send a heartbeat every (heartbeat timeout) / 4 milliseconds private val HEARTBEAT_MILLIS = conf.getLong("spark.worker.timeout", 60) * 1000 / 4 @@ -85,10 +86,10 @@ private[worker] class Worker( private val CLEANUP_ENABLED = conf.getBoolean("spark.worker.cleanup.enabled", false) // How often worker will clean up old app folders - private val 
CLEANUP_INTERVAL_MILLIS = + private val CLEANUP_INTERVAL_MILLIS = conf.getLong("spark.worker.cleanup.interval", 60 * 30) * 1000 // TTL for app folders/data; after TTL expires it will be cleaned up - private val APP_DATA_RETENTION_SECS = + private val APP_DATA_RETENTION_SECS = conf.getLong("spark.worker.cleanup.appDataTtl", 7 * 24 * 3600) private val testing: Boolean = sys.props.contains("spark.testing") @@ -112,7 +113,7 @@ private[worker] class Worker( } else { new File(sys.env.get("SPARK_HOME").getOrElse(".")) } - + var workDir: File = null val finishedExecutors = new HashMap[String, ExecutorRunner] val drivers = new HashMap[String, DriverRunner] @@ -122,7 +123,7 @@ private[worker] class Worker( val finishedApps = new HashSet[String] // The shuffle service is not actually started unless configured. - private val shuffleService = new StandaloneWorkerShuffleService(conf, securityMgr) + private val shuffleService = new ExternalShuffleService(conf, securityMgr) private val publicAddress = { val envVar = conf.getenv("SPARK_PUBLIC_DNS") @@ -134,7 +135,7 @@ private[worker] class Worker( private val metricsSystem = MetricsSystem.createMetricsSystem("worker", conf, securityMgr) private val workerSource = new WorkerSource(this) - + private var registrationRetryTimer: Option[Cancellable] = None var coresUsed = 0 diff --git a/docs/job-scheduling.md b/docs/job-scheduling.md index 963e88a3e1d8f..8d9c2ba2041b2 100644 --- a/docs/job-scheduling.md +++ b/docs/job-scheduling.md @@ -32,7 +32,7 @@ Resource allocation can be configured as follows, based on the cluster type: * **Standalone mode:** By default, applications submitted to the standalone mode cluster will run in FIFO (first-in-first-out) order, and each application will try to use all available nodes. You can limit the number of nodes an application uses by setting the `spark.cores.max` configuration property in it, - or change the default for applications that don't set this setting through `spark.deploy.defaultCores`. + or change the default for applications that don't set this setting through `spark.deploy.defaultCores`. Finally, in addition to controlling cores, each application's `spark.executor.memory` setting controls its memory use. 
* **Mesos:** To use static partitioning on Mesos, set the `spark.mesos.coarse` configuration property to `true`, diff --git a/launcher/src/main/java/org/apache/spark/launcher/SparkClassCommandBuilder.java b/launcher/src/main/java/org/apache/spark/launcher/SparkClassCommandBuilder.java index e601a0a19f368..d80abf2a8676e 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/SparkClassCommandBuilder.java +++ b/launcher/src/main/java/org/apache/spark/launcher/SparkClassCommandBuilder.java @@ -69,6 +69,10 @@ public List buildCommand(Map env) throws IOException { } else if (className.equals("org.apache.spark.executor.MesosExecutorBackend")) { javaOptsKeys.add("SPARK_EXECUTOR_OPTS"); memKey = "SPARK_EXECUTOR_MEMORY"; + } else if (className.equals("org.apache.spark.deploy.ExternalShuffleService")) { + javaOptsKeys.add("SPARK_DAEMON_JAVA_OPTS"); + javaOptsKeys.add("SPARK_SHUFFLE_OPTS"); + memKey = "SPARK_DAEMON_MEMORY"; } else if (className.startsWith("org.apache.spark.tools.")) { String sparkHome = getSparkHome(); File toolsDir = new File(join(File.separator, sparkHome, "tools", "target", diff --git a/sbin/start-shuffle-service.sh b/sbin/start-shuffle-service.sh new file mode 100755 index 0000000000000..4fddcf7f95d40 --- /dev/null +++ b/sbin/start-shuffle-service.sh @@ -0,0 +1,33 @@ +#!/usr/bin/env bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Starts the external shuffle server on the machine this script is executed on. +# +# Usage: start-shuffle-server.sh +# +# Use the SPARK_SHUFFLE_OPTS environment variable to set shuffle server configuration. +# + +sbin="`dirname "$0"`" +sbin="`cd "$sbin"; pwd`" + +. "$sbin/spark-config.sh" +. "$SPARK_PREFIX/bin/load-spark-env.sh" + +exec "$sbin"/spark-daemon.sh start org.apache.spark.deploy.ExternalShuffleService 1 diff --git a/sbin/stop-shuffle-service.sh b/sbin/stop-shuffle-service.sh new file mode 100755 index 0000000000000..4cb6891ae27fa --- /dev/null +++ b/sbin/stop-shuffle-service.sh @@ -0,0 +1,25 @@ +#!/usr/bin/env bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +# Stops the external shuffle service on the machine this script is executed on. + +sbin="`dirname "$0"`" +sbin="`cd "$sbin"; pwd`" + +"$sbin"/spark-daemon.sh stop org.apache.spark.deploy.ExternalShuffleService 1 From 2d222fb39dd978e5a33cde6ceb59307cbdf7b171 Mon Sep 17 00:00:00 2001 From: Ilya Ganelin Date: Tue, 28 Apr 2015 12:18:55 -0700 Subject: [PATCH 099/110] [SPARK-5932] [CORE] Use consistent naming for size properties I've added an interface to JavaUtils to do byte conversion and added hooks within Utils.scala to handle conversion within Spark code (like for time strings). I've added matching tests for size conversion, and then updated all deprecated configs and documentation as per SPARK-5933. Author: Ilya Ganelin Closes #5574 from ilganeli/SPARK-5932 and squashes the following commits: 11f6999 [Ilya Ganelin] Nit fixes 49a8720 [Ilya Ganelin] Whitespace fix 2ab886b [Ilya Ganelin] Scala style fc85733 [Ilya Ganelin] Got rid of floating point math 852a407 [Ilya Ganelin] [SPARK-5932] Added much improved overflow handling. Can now handle sizes up to Long.MAX_VALUE Petabytes instead of being capped at Long.MAX_VALUE Bytes 9ee779c [Ilya Ganelin] Simplified fraction matches 22413b1 [Ilya Ganelin] Made MAX private 3dfae96 [Ilya Ganelin] Fixed some nits. Added automatic conversion of old paramter for kryoserializer.mb to new values. e428049 [Ilya Ganelin] resolving merge conflict 8b43748 [Ilya Ganelin] Fixed error in pattern matching for doubles 84a2581 [Ilya Ganelin] Added smoother handling of fractional values for size parameters. This now throws an exception and added a warning for old spark.kryoserializer.buffer d3d09b6 [Ilya Ganelin] [SPARK-5932] Fixing error in KryoSerializer fe286b4 [Ilya Ganelin] Resolved merge conflict c7803cd [Ilya Ganelin] Empty lines 54b78b4 [Ilya Ganelin] Simplified byteUnit class 69e2f20 [Ilya Ganelin] Updates to code f32bc01 [Ilya Ganelin] [SPARK-5932] Fixed error in API in SparkConf.scala where Kb conversion wasn't being done properly (was Mb). Added test cases for both timeUnit and ByteUnit conversion f15f209 [Ilya Ganelin] Fixed conversion of kryo buffer size 0f4443e [Ilya Ganelin] Merge remote-tracking branch 'upstream/master' into SPARK-5932 35a7fa7 [Ilya Ganelin] Minor formatting 928469e [Ilya Ganelin] [SPARK-5932] Converted some longs to ints 5d29f90 [Ilya Ganelin] [SPARK-5932] Finished documentation updates 7a6c847 [Ilya Ganelin] [SPARK-5932] Updated spark.shuffle.file.buffer afc9a38 [Ilya Ganelin] [SPARK-5932] Updated spark.broadcast.blockSize and spark.storage.memoryMapThreshold ae7e9f6 [Ilya Ganelin] [SPARK-5932] Updated spark.io.compression.snappy.block.size 2d15681 [Ilya Ganelin] [SPARK-5932] Updated spark.executor.logs.rolling.size.maxBytes 1fbd435 [Ilya Ganelin] [SPARK-5932] Updated spark.broadcast.blockSize eba4de6 [Ilya Ganelin] [SPARK-5932] Updated spark.shuffle.file.buffer.kb b809a78 [Ilya Ganelin] [SPARK-5932] Updated spark.kryoserializer.buffer.max 0cdff35 [Ilya Ganelin] [SPARK-5932] Updated to use bibibytes in method names. Updated spark.kryoserializer.buffer.mb and spark.reducer.maxMbInFlight 475370a [Ilya Ganelin] [SPARK-5932] Simplified ByteUnit code, switched to using longs. 
Updated docs to clarify that we use kibi, mebi etc instead of kilo, mega 851d691 [Ilya Ganelin] [SPARK-5932] Updated memoryStringToMb to use new interfaces a9f4fcf [Ilya Ganelin] [SPARK-5932] Added unit tests for unit conversion 747393a [Ilya Ganelin] [SPARK-5932] Added unit tests for ByteString conversion 09ea450 [Ilya Ganelin] [SPARK-5932] Added byte string conversion to Jav utils 5390fd9 [Ilya Ganelin] Merge remote-tracking branch 'upstream/master' into SPARK-5932 db9a963 [Ilya Ganelin] Closing second spark context 1dc0444 [Ilya Ganelin] Added ref equality check 8c884fa [Ilya Ganelin] Made getOrCreate synchronized cb0c6b7 [Ilya Ganelin] Doc updates and code cleanup 270cfe3 [Ilya Ganelin] [SPARK-6703] Documentation fixes 15e8dea [Ilya Ganelin] Updated comments and added MiMa Exclude 0e1567c [Ilya Ganelin] Got rid of unecessary option for AtomicReference dfec4da [Ilya Ganelin] Changed activeContext to AtomicReference 733ec9f [Ilya Ganelin] Fixed some bugs in test code 8be2f83 [Ilya Ganelin] Replaced match with if e92caf7 [Ilya Ganelin] [SPARK-6703] Added test to ensure that getOrCreate both allows creation, retrieval, and a second context if desired a99032f [Ilya Ganelin] Spacing fix d7a06b8 [Ilya Ganelin] Updated SparkConf class to add getOrCreate method. Started test suite implementation --- .../scala/org/apache/spark/SparkConf.scala | 90 ++++++++++++++- .../spark/broadcast/TorrentBroadcast.scala | 3 +- .../apache/spark/io/CompressionCodec.scala | 8 +- .../spark/serializer/KryoSerializer.scala | 17 +-- .../shuffle/FileShuffleBlockManager.scala | 3 +- .../hash/BlockStoreShuffleFetcher.scala | 3 +- .../org/apache/spark/storage/DiskStore.scala | 3 +- .../scala/org/apache/spark/util/Utils.scala | 53 ++++++--- .../collection/ExternalAppendOnlyMap.scala | 6 +- .../util/collection/ExternalSorter.scala | 4 +- .../util/logging/RollingFileAppender.scala | 2 +- .../org/apache/spark/DistributedSuite.scala | 2 +- .../org/apache/spark/SparkConfSuite.scala | 19 ++++ .../KryoSerializerResizableOutputSuite.scala | 8 +- .../serializer/KryoSerializerSuite.scala | 2 +- .../BlockManagerReplicationSuite.scala | 2 +- .../spark/storage/BlockManagerSuite.scala | 6 +- .../org/apache/spark/util/UtilsSuite.scala | 100 +++++++++++++++- docs/configuration.md | 60 ++++++---- docs/tuning.md | 2 +- .../spark/examples/mllib/MovieLensALS.scala | 2 +- .../apache/spark/network/util/ByteUnit.java | 67 +++++++++++ .../apache/spark/network/util/JavaUtils.java | 107 ++++++++++++++++-- 23 files changed, 488 insertions(+), 81 deletions(-) create mode 100644 network/common/src/main/java/org/apache/spark/network/util/ByteUnit.java diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala index c1996e08756a6..a8fc90ad2050e 100644 --- a/core/src/main/scala/org/apache/spark/SparkConf.scala +++ b/core/src/main/scala/org/apache/spark/SparkConf.scala @@ -211,7 +211,74 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging { Utils.timeStringAsMs(get(key, defaultValue)) } + /** + * Get a size parameter as bytes; throws a NoSuchElementException if it's not set. If no + * suffix is provided then bytes are assumed. + * @throws NoSuchElementException + */ + def getSizeAsBytes(key: String): Long = { + Utils.byteStringAsBytes(get(key)) + } + + /** + * Get a size parameter as bytes, falling back to a default if not set. If no + * suffix is provided then bytes are assumed. 
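   * For example, a value of "1g" yields 1073741824, "1024k" yields 1048576, and a bare "4096"
   * is read as 4096 bytes.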
+ */ + def getSizeAsBytes(key: String, defaultValue: String): Long = { + Utils.byteStringAsBytes(get(key, defaultValue)) + } + + /** + * Get a size parameter as Kibibytes; throws a NoSuchElementException if it's not set. If no + * suffix is provided then Kibibytes are assumed. + * @throws NoSuchElementException + */ + def getSizeAsKb(key: String): Long = { + Utils.byteStringAsKb(get(key)) + } + + /** + * Get a size parameter as Kibibytes, falling back to a default if not set. If no + * suffix is provided then Kibibytes are assumed. + */ + def getSizeAsKb(key: String, defaultValue: String): Long = { + Utils.byteStringAsKb(get(key, defaultValue)) + } + + /** + * Get a size parameter as Mebibytes; throws a NoSuchElementException if it's not set. If no + * suffix is provided then Mebibytes are assumed. + * @throws NoSuchElementException + */ + def getSizeAsMb(key: String): Long = { + Utils.byteStringAsMb(get(key)) + } + + /** + * Get a size parameter as Mebibytes, falling back to a default if not set. If no + * suffix is provided then Mebibytes are assumed. + */ + def getSizeAsMb(key: String, defaultValue: String): Long = { + Utils.byteStringAsMb(get(key, defaultValue)) + } + + /** + * Get a size parameter as Gibibytes; throws a NoSuchElementException if it's not set. If no + * suffix is provided then Gibibytes are assumed. + * @throws NoSuchElementException + */ + def getSizeAsGb(key: String): Long = { + Utils.byteStringAsGb(get(key)) + } + /** + * Get a size parameter as Gibibytes, falling back to a default if not set. If no + * suffix is provided then Gibibytes are assumed. + */ + def getSizeAsGb(key: String, defaultValue: String): Long = { + Utils.byteStringAsGb(get(key, defaultValue)) + } + /** Get a parameter as an Option */ def getOption(key: String): Option[String] = { Option(settings.get(key)).orElse(getDeprecatedConfig(key, this)) @@ -407,7 +474,13 @@ private[spark] object SparkConf extends Logging { "The spark.cache.class property is no longer being used! Specify storage levels using " + "the RDD.persist() method instead."), DeprecatedConfig("spark.yarn.user.classpath.first", "1.3", - "Please use spark.{driver,executor}.userClassPathFirst instead.")) + "Please use spark.{driver,executor}.userClassPathFirst instead."), + DeprecatedConfig("spark.kryoserializer.buffer.mb", "1.4", + "Please use spark.kryoserializer.buffer instead. The default value for " + + "spark.kryoserializer.buffer.mb was previously specified as '0.064'. Fractional values " + + "are no longer accepted. To specify the equivalent now, one may use '64k'.") + ) + Map(configs.map { cfg => (cfg.key -> cfg) }:_*) } @@ -432,6 +505,21 @@ private[spark] object SparkConf extends Logging { AlternateConfig("spark.yarn.applicationMaster.waitTries", "1.3", // Translate old value to a duration, with 10s wait time per try. 
translation = s => s"${s.toLong * 10}s")), + "spark.reducer.maxSizeInFlight" -> Seq( + AlternateConfig("spark.reducer.maxMbInFlight", "1.4")), + "spark.kryoserializer.buffer" -> + Seq(AlternateConfig("spark.kryoserializer.buffer.mb", "1.4", + translation = s => s"${s.toDouble * 1000}k")), + "spark.kryoserializer.buffer.max" -> Seq( + AlternateConfig("spark.kryoserializer.buffer.max.mb", "1.4")), + "spark.shuffle.file.buffer" -> Seq( + AlternateConfig("spark.shuffle.file.buffer.kb", "1.4")), + "spark.executor.logs.rolling.maxSize" -> Seq( + AlternateConfig("spark.executor.logs.rolling.size.maxBytes", "1.4")), + "spark.io.compression.snappy.blockSize" -> Seq( + AlternateConfig("spark.io.compression.snappy.block.size", "1.4")), + "spark.io.compression.lz4.blockSize" -> Seq( + AlternateConfig("spark.io.compression.lz4.block.size", "1.4")), "spark.rpc.numRetries" -> Seq( AlternateConfig("spark.akka.num.retries", "1.4")), "spark.rpc.retry.wait" -> Seq( diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala index 23b02e60338fb..a0c9b5e63c744 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala @@ -74,7 +74,8 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long) } else { None } - blockSize = conf.getInt("spark.broadcast.blockSize", 4096) * 1024 + // Note: use getSizeAsKb (not bytes) to maintain compatiblity if no units are provided + blockSize = conf.getSizeAsKb("spark.broadcast.blockSize", "4m").toInt * 1024 } setConf(SparkEnv.get.conf) diff --git a/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala b/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala index 0709b6d689e86..0756cdb2ed8e6 100644 --- a/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala +++ b/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala @@ -97,7 +97,7 @@ private[spark] object CompressionCodec { /** * :: DeveloperApi :: * LZ4 implementation of [[org.apache.spark.io.CompressionCodec]]. - * Block size can be configured by `spark.io.compression.lz4.block.size`. + * Block size can be configured by `spark.io.compression.lz4.blockSize`. * * Note: The wire protocol for this codec is not guaranteed to be compatible across versions * of Spark. This is intended for use as an internal compression utility within a single Spark @@ -107,7 +107,7 @@ private[spark] object CompressionCodec { class LZ4CompressionCodec(conf: SparkConf) extends CompressionCodec { override def compressedOutputStream(s: OutputStream): OutputStream = { - val blockSize = conf.getInt("spark.io.compression.lz4.block.size", 32768) + val blockSize = conf.getSizeAsBytes("spark.io.compression.lz4.blockSize", "32k").toInt new LZ4BlockOutputStream(s, blockSize) } @@ -137,7 +137,7 @@ class LZFCompressionCodec(conf: SparkConf) extends CompressionCodec { /** * :: DeveloperApi :: * Snappy implementation of [[org.apache.spark.io.CompressionCodec]]. - * Block size can be configured by `spark.io.compression.snappy.block.size`. + * Block size can be configured by `spark.io.compression.snappy.blockSize`. * * Note: The wire protocol for this codec is not guaranteed to be compatible across versions * of Spark. 
This is intended for use as an internal compression utility within a single Spark @@ -153,7 +153,7 @@ class SnappyCompressionCodec(conf: SparkConf) extends CompressionCodec { } override def compressedOutputStream(s: OutputStream): OutputStream = { - val blockSize = conf.getInt("spark.io.compression.snappy.block.size", 32768) + val blockSize = conf.getSizeAsBytes("spark.io.compression.snappy.blockSize", "32k").toInt new SnappyOutputStream(s, blockSize) } diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala index 579fb6624e692..754832b8a4ca7 100644 --- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala @@ -49,16 +49,17 @@ class KryoSerializer(conf: SparkConf) with Logging with Serializable { - private val bufferSizeMb = conf.getDouble("spark.kryoserializer.buffer.mb", 0.064) - if (bufferSizeMb >= 2048) { - throw new IllegalArgumentException("spark.kryoserializer.buffer.mb must be less than " + - s"2048 mb, got: + $bufferSizeMb mb.") + private val bufferSizeKb = conf.getSizeAsKb("spark.kryoserializer.buffer", "64k") + + if (bufferSizeKb >= 2048) { + throw new IllegalArgumentException("spark.kryoserializer.buffer must be less than " + + s"2048 mb, got: + $bufferSizeKb mb.") } - private val bufferSize = (bufferSizeMb * 1024 * 1024).toInt + private val bufferSize = (bufferSizeKb * 1024).toInt - val maxBufferSizeMb = conf.getInt("spark.kryoserializer.buffer.max.mb", 64) + val maxBufferSizeMb = conf.getSizeAsMb("spark.kryoserializer.buffer.max", "64m").toInt if (maxBufferSizeMb >= 2048) { - throw new IllegalArgumentException("spark.kryoserializer.buffer.max.mb must be less than " + + throw new IllegalArgumentException("spark.kryoserializer.buffer.max must be less than " + s"2048 mb, got: + $maxBufferSizeMb mb.") } private val maxBufferSize = maxBufferSizeMb * 1024 * 1024 @@ -173,7 +174,7 @@ private[spark] class KryoSerializerInstance(ks: KryoSerializer) extends Serializ } catch { case e: KryoException if e.getMessage.startsWith("Buffer overflow") => throw new SparkException(s"Kryo serialization failed: ${e.getMessage}. To avoid this, " + - "increase spark.kryoserializer.buffer.max.mb value.") + "increase spark.kryoserializer.buffer.max value.") } ByteBuffer.wrap(output.toBytes) } diff --git a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala index 538e150ead05a..e9b4e2b955dc8 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala @@ -78,7 +78,8 @@ class FileShuffleBlockManager(conf: SparkConf) private val consolidateShuffleFiles = conf.getBoolean("spark.shuffle.consolidateFiles", false) - private val bufferSize = conf.getInt("spark.shuffle.file.buffer.kb", 32) * 1024 + // Use getSizeAsKb (not bytes) to maintain backwards compatibility of on units are provided + private val bufferSize = conf.getSizeAsKb("spark.shuffle.file.buffer", "32k").toInt * 1024 /** * Contains all the state related to a particular shuffle. 
This includes a pool of unused diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala index 7a2c5ae32d98b..80374adc44296 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala @@ -79,7 +79,8 @@ private[hash] object BlockStoreShuffleFetcher extends Logging { blockManager, blocksByAddress, serializer, - SparkEnv.get.conf.getLong("spark.reducer.maxMbInFlight", 48) * 1024 * 1024) + // Note: we use getSizeAsMb when no suffix is provided for backwards compatibility + SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024) val itr = blockFetcherItr.flatMap(unpackBlock) val completionIter = CompletionIterator[T, Iterator[T]](itr, { diff --git a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala index 4b232ae7d3180..1f45956282166 100644 --- a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala @@ -31,8 +31,7 @@ import org.apache.spark.util.Utils private[spark] class DiskStore(blockManager: BlockManager, diskManager: DiskBlockManager) extends BlockStore(blockManager) with Logging { - val minMemoryMapBytes = blockManager.conf.getLong( - "spark.storage.memoryMapThreshold", 2 * 1024L * 1024L) + val minMemoryMapBytes = blockManager.conf.getSizeAsBytes("spark.storage.memoryMapThreshold", "2m") override def getSize(blockId: BlockId): Long = { diskManager.getFile(blockId.name).length diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 342bc9a06db47..4c028c06a5138 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -1020,21 +1020,48 @@ private[spark] object Utils extends Logging { } /** - * Convert a Java memory parameter passed to -Xmx (such as 300m or 1g) to a number of megabytes. + * Convert a passed byte string (e.g. 50b, 100k, or 250m) to bytes for internal use. + * + * If no suffix is provided, the passed number is assumed to be in bytes. + */ + def byteStringAsBytes(str: String): Long = { + JavaUtils.byteStringAsBytes(str) + } + + /** + * Convert a passed byte string (e.g. 50b, 100k, or 250m) to kibibytes for internal use. + * + * If no suffix is provided, the passed number is assumed to be in kibibytes. + */ + def byteStringAsKb(str: String): Long = { + JavaUtils.byteStringAsKb(str) + } + + /** + * Convert a passed byte string (e.g. 50b, 100k, or 250m) to mebibytes for internal use. + * + * If no suffix is provided, the passed number is assumed to be in mebibytes. + */ + def byteStringAsMb(str: String): Long = { + JavaUtils.byteStringAsMb(str) + } + + /** + * Convert a passed byte string (e.g. 50b, 100k, or 250m, 500g) to gibibytes for internal use. + * + * If no suffix is provided, the passed number is assumed to be in gibibytes. + */ + def byteStringAsGb(str: String): Long = { + JavaUtils.byteStringAsGb(str) + } + + /** + * Convert a Java memory parameter passed to -Xmx (such as 300m or 1g) to a number of mebibytes. 
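   * For example, memoryStringToMb("300m") == 300 and memoryStringToMb("1g") == 1024, while a
   * bare number such as "2097152" is treated as bytes and yields 2.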
*/ def memoryStringToMb(str: String): Int = { - val lower = str.toLowerCase - if (lower.endsWith("k")) { - (lower.substring(0, lower.length-1).toLong / 1024).toInt - } else if (lower.endsWith("m")) { - lower.substring(0, lower.length-1).toInt - } else if (lower.endsWith("g")) { - lower.substring(0, lower.length-1).toInt * 1024 - } else if (lower.endsWith("t")) { - lower.substring(0, lower.length-1).toInt * 1024 * 1024 - } else {// no suffix, so it's just a number in bytes - (lower.toLong / 1024 / 1024).toInt - } + // Convert to bytes, rather than directly to MB, because when no units are specified the unit + // is assumed to be bytes + (JavaUtils.byteStringAsBytes(str) / 1024 / 1024).toInt } /** diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala index 30dd7f22e494f..f912049563906 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala @@ -89,8 +89,10 @@ class ExternalAppendOnlyMap[K, V, C]( // Number of bytes spilled in total private var _diskBytesSpilled = 0L - - private val fileBufferSize = sparkConf.getInt("spark.shuffle.file.buffer.kb", 32) * 1024 + + // Use getSizeAsKb (not bytes) to maintain backwards compatibility of on units are provided + private val fileBufferSize = + sparkConf.getSizeAsKb("spark.shuffle.file.buffer", "32k").toInt * 1024 // Write metrics for current spill private var curWriteMetrics: ShuffleWriteMetrics = _ diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala index 79a695fb62086..ef3cac622505e 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala @@ -108,7 +108,9 @@ private[spark] class ExternalSorter[K, V, C]( private val conf = SparkEnv.get.conf private val spillingEnabled = conf.getBoolean("spark.shuffle.spill", true) - private val fileBufferSize = conf.getInt("spark.shuffle.file.buffer.kb", 32) * 1024 + + // Use getSizeAsKb (not bytes) to maintain backwards compatibility of on units are provided + private val fileBufferSize = conf.getSizeAsKb("spark.shuffle.file.buffer", "32k").toInt * 1024 private val transferToEnabled = conf.getBoolean("spark.file.transferTo", true) // Size of object batches when reading/writing from serializers. 
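The call sites above move from getInt on the old ".kb" keys to getSizeAsKb on the new suffixed keys; a minimal sketch of the backwards-compatibility behaviour, assuming only the SparkConf changes in this patch:

    import org.apache.spark.SparkConf

    // New style: explicit binary-unit suffix.
    val conf = new SparkConf(loadDefaults = false)
    conf.set("spark.shuffle.file.buffer", "32k")
    assert(conf.getSizeAsKb("spark.shuffle.file.buffer", "32k") == 32)

    // Old style: a bare number under the deprecated key (which always meant KiB) is remapped
    // through configsWithAlternatives and still resolves to 32 KiB.
    val oldConf = new SparkConf(loadDefaults = false)
    oldConf.set("spark.shuffle.file.buffer.kb", "32")
    assert(oldConf.getSizeAsKb("spark.shuffle.file.buffer", "32k") == 32)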
diff --git a/core/src/main/scala/org/apache/spark/util/logging/RollingFileAppender.scala b/core/src/main/scala/org/apache/spark/util/logging/RollingFileAppender.scala index e579421676343..7138b4b8e4533 100644 --- a/core/src/main/scala/org/apache/spark/util/logging/RollingFileAppender.scala +++ b/core/src/main/scala/org/apache/spark/util/logging/RollingFileAppender.scala @@ -138,7 +138,7 @@ private[spark] object RollingFileAppender { val STRATEGY_DEFAULT = "" val INTERVAL_PROPERTY = "spark.executor.logs.rolling.time.interval" val INTERVAL_DEFAULT = "daily" - val SIZE_PROPERTY = "spark.executor.logs.rolling.size.maxBytes" + val SIZE_PROPERTY = "spark.executor.logs.rolling.maxSize" val SIZE_DEFAULT = (1024 * 1024).toString val RETAINED_FILES_PROPERTY = "spark.executor.logs.rolling.maxRetainedFiles" val DEFAULT_BUFFER_SIZE = 8192 diff --git a/core/src/test/scala/org/apache/spark/DistributedSuite.scala b/core/src/test/scala/org/apache/spark/DistributedSuite.scala index 97ea3578aa8ba..96a9c207ad022 100644 --- a/core/src/test/scala/org/apache/spark/DistributedSuite.scala +++ b/core/src/test/scala/org/apache/spark/DistributedSuite.scala @@ -77,7 +77,7 @@ class DistributedSuite extends FunSuite with Matchers with LocalSparkContext { } test("groupByKey where map output sizes exceed maxMbInFlight") { - val conf = new SparkConf().set("spark.reducer.maxMbInFlight", "1") + val conf = new SparkConf().set("spark.reducer.maxSizeInFlight", "1m") sc = new SparkContext(clusterUrl, "test", conf) // This data should be around 20 MB, so even with 4 mappers and 2 reducers, each map output // file should be about 2.5 MB diff --git a/core/src/test/scala/org/apache/spark/SparkConfSuite.scala b/core/src/test/scala/org/apache/spark/SparkConfSuite.scala index 272e6af0514e4..68d08e32f9aa4 100644 --- a/core/src/test/scala/org/apache/spark/SparkConfSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkConfSuite.scala @@ -24,11 +24,30 @@ import scala.language.postfixOps import scala.util.{Try, Random} import org.scalatest.FunSuite +import org.apache.spark.network.util.ByteUnit import org.apache.spark.serializer.{KryoRegistrator, KryoSerializer} import org.apache.spark.util.{RpcUtils, ResetSystemProperties} import com.esotericsoftware.kryo.Kryo class SparkConfSuite extends FunSuite with LocalSparkContext with ResetSystemProperties { + test("Test byteString conversion") { + val conf = new SparkConf() + // Simply exercise the API, we don't need a complete conversion test since that's handled in + // UtilsSuite.scala + assert(conf.getSizeAsBytes("fake","1k") === ByteUnit.KiB.toBytes(1)) + assert(conf.getSizeAsKb("fake","1k") === ByteUnit.KiB.toKiB(1)) + assert(conf.getSizeAsMb("fake","1k") === ByteUnit.KiB.toMiB(1)) + assert(conf.getSizeAsGb("fake","1k") === ByteUnit.KiB.toGiB(1)) + } + + test("Test timeString conversion") { + val conf = new SparkConf() + // Simply exercise the API, we don't need a complete conversion test since that's handled in + // UtilsSuite.scala + assert(conf.getTimeAsMs("fake","1ms") === TimeUnit.MILLISECONDS.toMillis(1)) + assert(conf.getTimeAsSeconds("fake","1000ms") === TimeUnit.MILLISECONDS.toSeconds(1000)) + } + test("loading from system properties") { System.setProperty("spark.test.testProperty", "2") val conf = new SparkConf() diff --git a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerResizableOutputSuite.scala b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerResizableOutputSuite.scala index 967c9e9899c9d..da98d09184735 100644 --- 
a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerResizableOutputSuite.scala +++ b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerResizableOutputSuite.scala @@ -33,8 +33,8 @@ class KryoSerializerResizableOutputSuite extends FunSuite { test("kryo without resizable output buffer should fail on large array") { val conf = new SparkConf(false) conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") - conf.set("spark.kryoserializer.buffer.mb", "1") - conf.set("spark.kryoserializer.buffer.max.mb", "1") + conf.set("spark.kryoserializer.buffer", "1m") + conf.set("spark.kryoserializer.buffer.max", "1m") val sc = new SparkContext("local", "test", conf) intercept[SparkException](sc.parallelize(x).collect()) LocalSparkContext.stop(sc) @@ -43,8 +43,8 @@ class KryoSerializerResizableOutputSuite extends FunSuite { test("kryo with resizable output buffer should succeed on large array") { val conf = new SparkConf(false) conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") - conf.set("spark.kryoserializer.buffer.mb", "1") - conf.set("spark.kryoserializer.buffer.max.mb", "2") + conf.set("spark.kryoserializer.buffer", "1m") + conf.set("spark.kryoserializer.buffer.max", "2m") val sc = new SparkContext("local", "test", conf) assert(sc.parallelize(x).collect() === x) LocalSparkContext.stop(sc) diff --git a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala index b070a54aa989b..1b13559e77cb8 100644 --- a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala +++ b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala @@ -269,7 +269,7 @@ class KryoSerializerSuite extends FunSuite with SharedSparkContext { test("serialization buffer overflow reporting") { import org.apache.spark.SparkException - val kryoBufferMaxProperty = "spark.kryoserializer.buffer.max.mb" + val kryoBufferMaxProperty = "spark.kryoserializer.buffer.max" val largeObject = (1 to 1000000).toArray diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala index ffa5162a31841..f647200402ecb 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala @@ -50,7 +50,7 @@ class BlockManagerReplicationSuite extends FunSuite with Matchers with BeforeAnd val allStores = new ArrayBuffer[BlockManager] // Reuse a serializer across tests to avoid creating a new thread-local buffer on each test - conf.set("spark.kryoserializer.buffer.mb", "1") + conf.set("spark.kryoserializer.buffer", "1m") val serializer = new KryoSerializer(conf) // Implicitly convert strings to BlockIds for test clarity. 
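For reference, the renamed Kryo buffer settings exercised by the tests above, with the defaults now used by KryoSerializer (a sketch, not an exhaustive list of the affected keys):

    import org.apache.spark.SparkConf

    val conf = new SparkConf(loadDefaults = false)
      .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
      .set("spark.kryoserializer.buffer", "64k")      // was spark.kryoserializer.buffer.mb (default 0.064)
      .set("spark.kryoserializer.buffer.max", "64m")  // was spark.kryoserializer.buffer.max.mb (default 64)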
diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index 7d82a7c66ad1a..6957bc72e9903 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -55,7 +55,7 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfterEach val shuffleManager = new HashShuffleManager(conf) // Reuse a serializer across tests to avoid creating a new thread-local buffer on each test - conf.set("spark.kryoserializer.buffer.mb", "1") + conf.set("spark.kryoserializer.buffer", "1m") val serializer = new KryoSerializer(conf) // Implicitly convert strings to BlockIds for test clarity. @@ -814,14 +814,14 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfterEach // be nice to refactor classes involved in disk storage in a way that // allows for easier testing. val blockManager = mock(classOf[BlockManager]) - when(blockManager.conf).thenReturn(conf.clone.set(confKey, 0.toString)) + when(blockManager.conf).thenReturn(conf.clone.set(confKey, "0")) val diskBlockManager = new DiskBlockManager(blockManager, conf) val diskStoreMapped = new DiskStore(blockManager, diskBlockManager) diskStoreMapped.putBytes(blockId, byteBuffer, StorageLevel.DISK_ONLY) val mapped = diskStoreMapped.getBytes(blockId).get - when(blockManager.conf).thenReturn(conf.clone.set(confKey, (1000 * 1000).toString)) + when(blockManager.conf).thenReturn(conf.clone.set(confKey, "1m")) val diskStoreNotMapped = new DiskStore(blockManager, diskBlockManager) diskStoreNotMapped.putBytes(blockId, byteBuffer, StorageLevel.DISK_ONLY) val notMapped = diskStoreNotMapped.getBytes(blockId).get diff --git a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala index 1ba99803f5a0e..62a3cbcdf69ea 100644 --- a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala @@ -23,7 +23,6 @@ import java.nio.{ByteBuffer, ByteOrder} import java.text.DecimalFormatSymbols import java.util.concurrent.TimeUnit import java.util.Locale -import java.util.PriorityQueue import scala.collection.mutable.ListBuffer import scala.util.Random @@ -35,6 +34,7 @@ import org.scalatest.FunSuite import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path +import org.apache.spark.network.util.ByteUnit import org.apache.spark.SparkConf class UtilsSuite extends FunSuite with ResetSystemProperties { @@ -65,6 +65,10 @@ class UtilsSuite extends FunSuite with ResetSystemProperties { assert(Utils.timeStringAsMs("1d") === TimeUnit.DAYS.toMillis(1)) // Test invalid strings + intercept[NumberFormatException] { + Utils.timeStringAsMs("600l") + } + intercept[NumberFormatException] { Utils.timeStringAsMs("This breaks 600s") } @@ -82,6 +86,100 @@ class UtilsSuite extends FunSuite with ResetSystemProperties { } } + test("Test byteString conversion") { + // Test zero + assert(Utils.byteStringAsBytes("0") === 0) + + assert(Utils.byteStringAsGb("1") === 1) + assert(Utils.byteStringAsGb("1g") === 1) + assert(Utils.byteStringAsGb("1023m") === 0) + assert(Utils.byteStringAsGb("1024m") === 1) + assert(Utils.byteStringAsGb("1048575k") === 0) + assert(Utils.byteStringAsGb("1048576k") === 1) + assert(Utils.byteStringAsGb("1k") === 0) + assert(Utils.byteStringAsGb("1t") === ByteUnit.TiB.toGiB(1)) + assert(Utils.byteStringAsGb("1p") === 
ByteUnit.PiB.toGiB(1)) + + assert(Utils.byteStringAsMb("1") === 1) + assert(Utils.byteStringAsMb("1m") === 1) + assert(Utils.byteStringAsMb("1048575b") === 0) + assert(Utils.byteStringAsMb("1048576b") === 1) + assert(Utils.byteStringAsMb("1023k") === 0) + assert(Utils.byteStringAsMb("1024k") === 1) + assert(Utils.byteStringAsMb("3645k") === 3) + assert(Utils.byteStringAsMb("1024gb") === 1048576) + assert(Utils.byteStringAsMb("1g") === ByteUnit.GiB.toMiB(1)) + assert(Utils.byteStringAsMb("1t") === ByteUnit.TiB.toMiB(1)) + assert(Utils.byteStringAsMb("1p") === ByteUnit.PiB.toMiB(1)) + + assert(Utils.byteStringAsKb("1") === 1) + assert(Utils.byteStringAsKb("1k") === 1) + assert(Utils.byteStringAsKb("1m") === ByteUnit.MiB.toKiB(1)) + assert(Utils.byteStringAsKb("1g") === ByteUnit.GiB.toKiB(1)) + assert(Utils.byteStringAsKb("1t") === ByteUnit.TiB.toKiB(1)) + assert(Utils.byteStringAsKb("1p") === ByteUnit.PiB.toKiB(1)) + + assert(Utils.byteStringAsBytes("1") === 1) + assert(Utils.byteStringAsBytes("1k") === ByteUnit.KiB.toBytes(1)) + assert(Utils.byteStringAsBytes("1m") === ByteUnit.MiB.toBytes(1)) + assert(Utils.byteStringAsBytes("1g") === ByteUnit.GiB.toBytes(1)) + assert(Utils.byteStringAsBytes("1t") === ByteUnit.TiB.toBytes(1)) + assert(Utils.byteStringAsBytes("1p") === ByteUnit.PiB.toBytes(1)) + + // Overflow handling, 1073741824p exceeds Long.MAX_VALUE if converted straight to Bytes + // This demonstrates that we can have e.g 1024^3 PB without overflowing. + assert(Utils.byteStringAsGb("1073741824p") === ByteUnit.PiB.toGiB(1073741824)) + assert(Utils.byteStringAsMb("1073741824p") === ByteUnit.PiB.toMiB(1073741824)) + + // Run this to confirm it doesn't throw an exception + assert(Utils.byteStringAsBytes("9223372036854775807") === 9223372036854775807L) + assert(ByteUnit.PiB.toPiB(9223372036854775807L) === 9223372036854775807L) + + // Test overflow exception + intercept[IllegalArgumentException] { + // This value exceeds Long.MAX when converted to bytes + Utils.byteStringAsBytes("9223372036854775808") + } + + // Test overflow exception + intercept[IllegalArgumentException] { + // This value exceeds Long.MAX when converted to TB + ByteUnit.PiB.toTiB(9223372036854775807L) + } + + // Test fractional string + intercept[NumberFormatException] { + Utils.byteStringAsMb("0.064") + } + + // Test fractional string + intercept[NumberFormatException] { + Utils.byteStringAsMb("0.064m") + } + + // Test invalid strings + intercept[NumberFormatException] { + Utils.byteStringAsBytes("500ub") + } + + // Test invalid strings + intercept[NumberFormatException] { + Utils.byteStringAsBytes("This breaks 600b") + } + + intercept[NumberFormatException] { + Utils.byteStringAsBytes("This breaks 600") + } + + intercept[NumberFormatException] { + Utils.byteStringAsBytes("600gb This breaks") + } + + intercept[NumberFormatException] { + Utils.byteStringAsBytes("This 123mb breaks") + } + } + test("bytesToString") { assert(Utils.bytesToString(10) === "10.0 B") assert(Utils.bytesToString(1500) === "1500.0 B") diff --git a/docs/configuration.md b/docs/configuration.md index d587b91124cb8..72105feba4919 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -48,6 +48,17 @@ The following format is accepted: 5d (days) 1y (years) + +Properties that specify a byte size should be configured with a unit of size. 
+The following format is accepted: + + 1b (bytes) + 1k or 1kb (kibibytes = 1024 bytes) + 1m or 1mb (mebibytes = 1024 kibibytes) + 1g or 1gb (gibibytes = 1024 mebibytes) + 1t or 1tb (tebibytes = 1024 gibibytes) + 1p or 1pb (pebibytes = 1024 tebibytes) + ## Dynamically Loading Spark Properties In some cases, you may want to avoid hard-coding certain configurations in a `SparkConf`. For instance, if you'd like to run the same application with different masters or different @@ -272,12 +283,11 @@ Apart from these, the following properties are also available, and may be useful - + @@ -366,10 +376,10 @@ Apart from these, the following properties are also available, and may be useful
spark.executor.logs.rolling.size.maxBytes spark.executor.logs.rolling.maxSize (none) Set the max size of the file by which the executor logs will be rolled over. - Rolling is disabled by default. Value is set in terms of bytes. - See spark.executor.logs.rolling.maxRetainedFiles + Rolling is disabled by default. See spark.executor.logs.rolling.maxRetainedFiles for automatic cleaning of old logs.
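As a rough illustration of the byte-string parsing that the tests and docs above describe (a hedged sketch, not part of the patch; Utils is Spark-internal, so these calls only compile from Spark's own code):

    import org.apache.spark.util.Utils

    // Suffixes follow the list above: k, m, g, t and p are binary (1024-based) units.
    Utils.byteStringAsBytes("1k")  // 1024
    Utils.byteStringAsMb("1024k")  // 1
    Utils.byteStringAsMb("2g")     // 2048
    Utils.byteStringAsGb("1t")     // 1024
    // With no suffix, the number is read in the caller's default unit
    // (bytes for byteStringAsBytes, mebibytes for byteStringAsMb, and so on).
    Utils.byteStringAsBytes("500") // 500
    // Fractional values ("0.5m") and unknown suffixes ("500ub") throw NumberFormatException.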
- - + + @@ -403,10 +413,10 @@ Apart from these, the following properties are also available, and may be useful - - + + @@ -582,18 +592,18 @@ Apart from these, the following properties are also available, and may be useful - - + + - - + + @@ -641,19 +651,19 @@ Apart from these, the following properties are also available, and may be useful - - + + - - + + @@ -698,9 +708,9 @@ Apart from these, the following properties are also available, and may be useful - + @@ -816,9 +826,9 @@ Apart from these, the following properties are also available, and may be useful - + diff --git a/docs/tuning.md b/docs/tuning.md index cbd227868b248..1cb223e74f382 100644 --- a/docs/tuning.md +++ b/docs/tuning.md @@ -60,7 +60,7 @@ val sc = new SparkContext(conf) The [Kryo documentation](https://github.com/EsotericSoftware/kryo) describes more advanced registration options, such as adding custom serialization code. -If your objects are large, you may also need to increase the `spark.kryoserializer.buffer.mb` +If your objects are large, you may also need to increase the `spark.kryoserializer.buffer` config property. The default is 2, but this value needs to be large enough to hold the *largest* object you will serialize. diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala index 0bc36ea65e1ab..99588b0984ab2 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala @@ -100,7 +100,7 @@ object MovieLensALS { val conf = new SparkConf().setAppName(s"MovieLensALS with $params") if (params.kryo) { conf.registerKryoClasses(Array(classOf[mutable.BitSet], classOf[Rating])) - .set("spark.kryoserializer.buffer.mb", "8") + .set("spark.kryoserializer.buffer", "8m") } val sc = new SparkContext(conf) diff --git a/network/common/src/main/java/org/apache/spark/network/util/ByteUnit.java b/network/common/src/main/java/org/apache/spark/network/util/ByteUnit.java new file mode 100644 index 0000000000000..36d655017fb0d --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/util/ByteUnit.java @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.network.util; + +public enum ByteUnit { + BYTE (1), + KiB (1024L), + MiB ((long) Math.pow(1024L, 2L)), + GiB ((long) Math.pow(1024L, 3L)), + TiB ((long) Math.pow(1024L, 4L)), + PiB ((long) Math.pow(1024L, 5L)); + + private ByteUnit(long multiplier) { + this.multiplier = multiplier; + } + + // Interpret the provided number (d) with suffix (u) as this unit type. + // E.g. 
KiB.interpret(1, MiB) interprets 1MiB as its KiB representation = 1024k + public long convertFrom(long d, ByteUnit u) { + return u.convertTo(d, this); + } + + // Convert the provided number (d) interpreted as this unit type to unit type (u). + public long convertTo(long d, ByteUnit u) { + if (multiplier > u.multiplier) { + long ratio = multiplier / u.multiplier; + if (Long.MAX_VALUE / ratio < d) { + throw new IllegalArgumentException("Conversion of " + d + " exceeds Long.MAX_VALUE in " + + name() + ". Try a larger unit (e.g. MiB instead of KiB)"); + } + return d * ratio; + } else { + // Perform operations in this order to avoid potential overflow + // when computing d * multiplier + return d / (u.multiplier / multiplier); + } + } + + public double toBytes(long d) { + if (d < 0) { + throw new IllegalArgumentException("Negative size value. Size must be positive: " + d); + } + return d * multiplier; + } + + public long toKiB(long d) { return convertTo(d, KiB); } + public long toMiB(long d) { return convertTo(d, MiB); } + public long toGiB(long d) { return convertTo(d, GiB); } + public long toTiB(long d) { return convertTo(d, TiB); } + public long toPiB(long d) { return convertTo(d, PiB); } + + private final long multiplier; +} diff --git a/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java b/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java index b6fbace509a0e..6b514aaa1290d 100644 --- a/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java +++ b/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java @@ -126,7 +126,7 @@ private static boolean isSymlink(File file) throws IOException { return !fileInCanonicalDir.getCanonicalFile().equals(fileInCanonicalDir.getAbsoluteFile()); } - private static ImmutableMap timeSuffixes = + private static final ImmutableMap timeSuffixes = ImmutableMap.builder() .put("us", TimeUnit.MICROSECONDS) .put("ms", TimeUnit.MILLISECONDS) @@ -137,6 +137,21 @@ private static boolean isSymlink(File file) throws IOException { .put("d", TimeUnit.DAYS) .build(); + private static final ImmutableMap byteSuffixes = + ImmutableMap.builder() + .put("b", ByteUnit.BYTE) + .put("k", ByteUnit.KiB) + .put("kb", ByteUnit.KiB) + .put("m", ByteUnit.MiB) + .put("mb", ByteUnit.MiB) + .put("g", ByteUnit.GiB) + .put("gb", ByteUnit.GiB) + .put("t", ByteUnit.TiB) + .put("tb", ByteUnit.TiB) + .put("p", ByteUnit.PiB) + .put("pb", ByteUnit.PiB) + .build(); + /** * Convert a passed time string (e.g. 50s, 100ms, or 250us) to a time count for * internal use. If no suffix is provided a direct conversion is attempted. @@ -145,16 +160,14 @@ private static long parseTimeString(String str, TimeUnit unit) { String lower = str.toLowerCase().trim(); try { - String suffix; - long val; Matcher m = Pattern.compile("(-?[0-9]+)([a-z]+)?").matcher(lower); - if (m.matches()) { - val = Long.parseLong(m.group(1)); - suffix = m.group(2); - } else { + if (!m.matches()) { throw new NumberFormatException("Failed to parse time string: " + str); } + long val = Long.parseLong(m.group(1)); + String suffix = m.group(2); + // Check for invalid suffixes if (suffix != null && !timeSuffixes.containsKey(suffix)) { throw new NumberFormatException("Invalid suffix: \"" + suffix + "\""); @@ -164,7 +177,7 @@ private static long parseTimeString(String str, TimeUnit unit) { return unit.convert(val, suffix != null ? 
timeSuffixes.get(suffix) : unit); } catch (NumberFormatException e) { String timeError = "Time must be specified as seconds (s), " + - "milliseconds (ms), microseconds (us), minutes (m or min) hour (h), or day (d). " + + "milliseconds (ms), microseconds (us), minutes (m or min), hour (h), or day (d). " + "E.g. 50s, 100ms, or 250us."; throw new NumberFormatException(timeError + "\n" + e.getMessage()); @@ -186,5 +199,83 @@ public static long timeStringAsMs(String str) { public static long timeStringAsSec(String str) { return parseTimeString(str, TimeUnit.SECONDS); } + + /** + * Convert a passed byte string (e.g. 50b, 100kb, or 250mb) to a ByteUnit for + * internal use. If no suffix is provided a direct conversion of the provided default is + * attempted. + */ + private static long parseByteString(String str, ByteUnit unit) { + String lower = str.toLowerCase().trim(); + + try { + Matcher m = Pattern.compile("([0-9]+)([a-z]+)?").matcher(lower); + Matcher fractionMatcher = Pattern.compile("([0-9]+\\.[0-9]+)([a-z]+)?").matcher(lower); + + if (m.matches()) { + long val = Long.parseLong(m.group(1)); + String suffix = m.group(2); + + // Check for invalid suffixes + if (suffix != null && !byteSuffixes.containsKey(suffix)) { + throw new NumberFormatException("Invalid suffix: \"" + suffix + "\""); + } + + // If suffix is valid use that, otherwise none was provided and use the default passed + return unit.convertFrom(val, suffix != null ? byteSuffixes.get(suffix) : unit); + } else if (fractionMatcher.matches()) { + throw new NumberFormatException("Fractional values are not supported. Input was: " + + fractionMatcher.group(1)); + } else { + throw new NumberFormatException("Failed to parse byte string: " + str); + } + + } catch (NumberFormatException e) { + String timeError = "Size must be specified as bytes (b), " + + "kibibytes (k), mebibytes (m), gibibytes (g), tebibytes (t), or pebibytes(p). " + + "E.g. 50b, 100k, or 250m."; + throw new NumberFormatException(timeError + "\n" + e.getMessage()); + } + } + + /** + * Convert a passed byte string (e.g. 50b, 100k, or 250m) to bytes for + * internal use. + * + * If no suffix is provided, the passed number is assumed to be in bytes. + */ + public static long byteStringAsBytes(String str) { + return parseByteString(str, ByteUnit.BYTE); + } + + /** + * Convert a passed byte string (e.g. 50b, 100k, or 250m) to kibibytes for + * internal use. + * + * If no suffix is provided, the passed number is assumed to be in kibibytes. + */ + public static long byteStringAsKb(String str) { + return parseByteString(str, ByteUnit.KiB); + } + + /** + * Convert a passed byte string (e.g. 50b, 100k, or 250m) to mebibytes for + * internal use. + * + * If no suffix is provided, the passed number is assumed to be in mebibytes. + */ + public static long byteStringAsMb(String str) { + return parseByteString(str, ByteUnit.MiB); + } + + /** + * Convert a passed byte string (e.g. 50b, 100k, or 250m) to gibibytes for + * internal use. + * + * If no suffix is provided, the passed number is assumed to be in gibibytes. + */ + public static long byteStringAsGb(String str) { + return parseByteString(str, ByteUnit.GiB); + } } From 80098109d908b738b43d397e024756ff617d0af4 Mon Sep 17 00:00:00 2001 From: "Zhang, Liye" Date: Tue, 28 Apr 2015 12:33:48 -0700 Subject: [PATCH 100/110] [SPARK-6314] [CORE] handle JsonParseException for history server This is handled in the same way with [SPARK-6197](https://issues.apache.org/jira/browse/SPARK-6197). 
The result of this PR is that the exception shown in the history server log will be replaced by a warning, and an application with an incomplete history log file will be listed on the history server web UI Author: Zhang, Liye Closes #5736 from liyezhang556520/SPARK-6314 and squashes the following commits: b8d2d88 [Zhang, Liye] handle JsonParseException for history server --- .../org/apache/spark/deploy/history/FsHistoryProvider.scala | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala index a94ebf6e53750..fb2cbbcccc54b 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala @@ -333,8 +333,9 @@ private[history] class FsHistoryProvider(conf: SparkConf) extends ApplicationHis } try { val appListener = new ApplicationEventListener + val appCompleted = isApplicationCompleted(eventLog) bus.addListener(appListener) - bus.replay(logInput, logPath.toString) + bus.replay(logInput, logPath.toString, !appCompleted) new FsApplicationHistoryInfo( logPath.getName(), appListener.appId.getOrElse(logPath.getName()), @@ -343,7 +344,7 @@ private[history] class FsHistoryProvider(conf: SparkConf) extends ApplicationHis appListener.endTime.getOrElse(-1L), getModificationTime(eventLog).get, appListener.sparkUser.getOrElse(NOT_STARTED), - isApplicationCompleted(eventLog)) + appCompleted) } finally { logInput.close() } From 53befacced828bbac53c6e3a4976ec3f036bae9e Mon Sep 17 00:00:00 2001 From: Timothy Chen Date: Tue, 28 Apr 2015 13:31:08 -0700 Subject: [PATCH 101/110] [SPARK-5338] [MESOS] Add cluster mode support for Mesos This patch adds support for cluster mode on Mesos. It introduces a new Mesos framework dedicated to launching new apps/drivers, which can be used through the spark-submit script by pointing the --master flag at the cluster mode REST interface instead of the Mesos master. Example: ./bin/spark-submit --deploy-mode cluster --class org.apache.spark.examples.SparkPi --master mesos://10.0.0.206:8077 --executor-memory 1G --total-executor-cores 100 examples/target/spark-examples_2.10-1.3.0-SNAPSHOT.jar 30 Part of this patch is also to abstract the StandaloneRestServer so it can have different implementations of the REST endpoints. Features of the cluster mode in this PR: - Supports supervise mode, where the scheduler keeps trying to reschedule an exited job. - Adds a new UI for the cluster mode scheduler to see all the running jobs, finished jobs, and supervised jobs waiting to be retried - Supports state persistence to ZK, so when the cluster scheduler fails over it can pick up all the queued and running jobs Author: Timothy Chen Author: Luc Bourlier Closes #5144 from tnachen/mesos_cluster_mode and squashes the following commits: 069e946 [Timothy Chen] Fix rebase. e24b512 [Timothy Chen] Persist submitted driver. 390c491 [Timothy Chen] Fix zk conf key for mesos zk engine. e324ac1 [Timothy Chen] Fix merge. fd5259d [Timothy Chen] Address review comments. 1553230 [Timothy Chen] Address review comments. c6c6b73 [Timothy Chen] Pass spark properties to mesos cluster tasks. f7d8046 [Timothy Chen] Change app name to spark cluster. 17f93a2 [Timothy Chen] Fix head of line blocking in scheduling drivers. 6ff8e5c [Timothy Chen] Address comments and add logging. df355cd [Timothy Chen] Add metrics to mesos cluster scheduler.
20f7284 [Timothy Chen] Address review comments 7252612 [Timothy Chen] Fix tests. a46ad66 [Timothy Chen] Allow zk cli param override. 920fc4b [Timothy Chen] Fix scala style issues. 862b5b5 [Timothy Chen] Support asking driver status when it's retrying. 7f214c2 [Timothy Chen] Fix RetryState visibility e0f33f7 [Timothy Chen] Add supervise support and persist retries. 371ce65 [Timothy Chen] Handle cluster mode recovery and state persistence. 3d4dfa1 [Luc Bourlier] Adds support to kill submissions febfaba [Timothy Chen] Bound the finished drivers in memory 543a98d [Timothy Chen] Schedule multiple jobs 6887e5e [Timothy Chen] Support looking at SPARK_EXECUTOR_URI env variable in schedulers 8ec76bc [Timothy Chen] Fix Mesos dispatcher UI. d57d77d [Timothy Chen] Add documentation 825afa0 [Luc Bourlier] Supports more spark-submit parameters b8e7181 [Luc Bourlier] Adds a shutdown latch to keep the deamon running 0fa7780 [Luc Bourlier] Launch task through the mesos scheduler 5b7a12b [Timothy Chen] WIP: Making a cluster mode a mesos framework. 4b2f5ef [Timothy Chen] Specify user jar in command to be replaced with local. e775001 [Timothy Chen] Support fetching remote uris in driver runner. 7179495 [Timothy Chen] Change Driver page output and add logging 880bc27 [Timothy Chen] Add Mesos Cluster UI to display driver results 9986731 [Timothy Chen] Kill drivers when shutdown 67cbc18 [Timothy Chen] Rename StandaloneRestClient to RestClient and add sbin scripts e3facdd [Timothy Chen] Add Mesos Cluster dispatcher --- .../spark/deploy/FaultToleranceTest.scala | 2 +- .../{master => }/SparkCuratorUtil.scala | 10 +- .../org/apache/spark/deploy/SparkSubmit.scala | 48 +- .../spark/deploy/SparkSubmitArguments.scala | 11 +- .../apache/spark/deploy/master/Master.scala | 2 +- .../master/ZooKeeperLeaderElectionAgent.scala | 1 + .../master/ZooKeeperPersistenceEngine.scala | 1 + .../deploy/mesos/MesosClusterDispatcher.scala | 116 ++++ .../MesosClusterDispatcherArguments.scala | 101 +++ .../deploy/mesos/MesosDriverDescription.scala | 65 ++ .../deploy/mesos/ui/MesosClusterPage.scala | 114 ++++ .../deploy/mesos/ui/MesosClusterUI.scala | 48 ++ ...lient.scala => RestSubmissionClient.scala} | 35 +- .../deploy/rest/RestSubmissionServer.scala | 318 +++++++++ .../deploy/rest/StandaloneRestServer.scala | 344 ++-------- .../rest/SubmitRestProtocolRequest.scala | 2 +- .../rest/SubmitRestProtocolResponse.scala | 6 +- .../deploy/rest/mesos/MesosRestServer.scala | 158 +++++ .../mesos/CoarseMesosSchedulerBackend.scala | 82 +-- .../mesos/MesosClusterPersistenceEngine.scala | 134 ++++ .../cluster/mesos/MesosClusterScheduler.scala | 608 ++++++++++++++++++ .../mesos/MesosClusterSchedulerSource.scala | 40 ++ .../cluster/mesos/MesosSchedulerBackend.scala | 85 +-- .../cluster/mesos/MesosSchedulerUtils.scala | 95 +++ .../spark/deploy/SparkSubmitSuite.scala | 2 +- .../rest/StandaloneRestSubmitSuite.scala | 46 +- .../mesos/MesosClusterSchedulerSuite.scala | 76 +++ docs/running-on-mesos.md | 23 +- sbin/start-mesos-dispatcher.sh | 40 ++ sbin/stop-mesos-dispatcher.sh | 27 + 30 files changed, 2147 insertions(+), 493 deletions(-) rename core/src/main/scala/org/apache/spark/deploy/{master => }/SparkCuratorUtil.scala (89%) create mode 100644 core/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcher.scala create mode 100644 core/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcherArguments.scala create mode 100644 core/src/main/scala/org/apache/spark/deploy/mesos/MesosDriverDescription.scala create mode 100644 
core/src/main/scala/org/apache/spark/deploy/mesos/ui/MesosClusterPage.scala create mode 100644 core/src/main/scala/org/apache/spark/deploy/mesos/ui/MesosClusterUI.scala rename core/src/main/scala/org/apache/spark/deploy/rest/{StandaloneRestClient.scala => RestSubmissionClient.scala} (93%) create mode 100644 core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionServer.scala create mode 100644 core/src/main/scala/org/apache/spark/deploy/rest/mesos/MesosRestServer.scala create mode 100644 core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterPersistenceEngine.scala create mode 100644 core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala create mode 100644 core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterSchedulerSource.scala create mode 100644 core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala create mode 100644 core/src/test/scala/org/apache/spark/scheduler/mesos/MesosClusterSchedulerSuite.scala create mode 100755 sbin/start-mesos-dispatcher.sh create mode 100755 sbin/stop-mesos-dispatcher.sh diff --git a/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala b/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala index a7c89276a045e..c048b78910f38 100644 --- a/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala +++ b/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala @@ -32,7 +32,7 @@ import org.json4s._ import org.json4s.jackson.JsonMethods import org.apache.spark.{Logging, SparkConf, SparkContext} -import org.apache.spark.deploy.master.{RecoveryState, SparkCuratorUtil} +import org.apache.spark.deploy.master.RecoveryState import org.apache.spark.util.Utils /** diff --git a/core/src/main/scala/org/apache/spark/deploy/master/SparkCuratorUtil.scala b/core/src/main/scala/org/apache/spark/deploy/SparkCuratorUtil.scala similarity index 89% rename from core/src/main/scala/org/apache/spark/deploy/master/SparkCuratorUtil.scala rename to core/src/main/scala/org/apache/spark/deploy/SparkCuratorUtil.scala index 5b22481ea8c5f..b8d3993540220 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/SparkCuratorUtil.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkCuratorUtil.scala @@ -15,7 +15,7 @@ * limitations under the License. 
*/ -package org.apache.spark.deploy.master +package org.apache.spark.deploy import scala.collection.JavaConversions._ @@ -25,15 +25,17 @@ import org.apache.zookeeper.KeeperException import org.apache.spark.{Logging, SparkConf} -private[deploy] object SparkCuratorUtil extends Logging { +private[spark] object SparkCuratorUtil extends Logging { private val ZK_CONNECTION_TIMEOUT_MILLIS = 15000 private val ZK_SESSION_TIMEOUT_MILLIS = 60000 private val RETRY_WAIT_MILLIS = 5000 private val MAX_RECONNECT_ATTEMPTS = 3 - def newClient(conf: SparkConf): CuratorFramework = { - val ZK_URL = conf.get("spark.deploy.zookeeper.url") + def newClient( + conf: SparkConf, + zkUrlConf: String = "spark.deploy.zookeeper.url"): CuratorFramework = { + val ZK_URL = conf.get(zkUrlConf) val zk = CuratorFrameworkFactory.newClient(ZK_URL, ZK_SESSION_TIMEOUT_MILLIS, ZK_CONNECTION_TIMEOUT_MILLIS, new ExponentialBackoffRetry(RETRY_WAIT_MILLIS, MAX_RECONNECT_ATTEMPTS)) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index 296a0764b8baf..f4f572e1e256e 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -36,11 +36,11 @@ import org.apache.ivy.core.retrieve.RetrieveOptions import org.apache.ivy.core.settings.IvySettings import org.apache.ivy.plugins.matcher.GlobPatternMatcher import org.apache.ivy.plugins.resolver.{ChainResolver, IBiblioResolver} - import org.apache.spark.SPARK_VERSION import org.apache.spark.deploy.rest._ import org.apache.spark.util.{ChildFirstURLClassLoader, MutableURLClassLoader, Utils} + /** * Whether to submit, kill, or request the status of an application. * The latter two operations are currently supported only for standalone cluster mode. @@ -114,18 +114,20 @@ object SparkSubmit { } } - /** Kill an existing submission using the REST protocol. Standalone cluster mode only. */ + /** + * Kill an existing submission using the REST protocol. Standalone and Mesos cluster mode only. + */ private def kill(args: SparkSubmitArguments): Unit = { - new StandaloneRestClient() + new RestSubmissionClient() .killSubmission(args.master, args.submissionToKill) } /** * Request the status of an existing submission using the REST protocol. - * Standalone cluster mode only. + * Standalone and Mesos cluster mode only. */ private def requestStatus(args: SparkSubmitArguments): Unit = { - new StandaloneRestClient() + new RestSubmissionClient() .requestSubmissionStatus(args.master, args.submissionToRequestStatusFor) } @@ -252,6 +254,7 @@ object SparkSubmit { } val isYarnCluster = clusterManager == YARN && deployMode == CLUSTER + val isMesosCluster = clusterManager == MESOS && deployMode == CLUSTER // Resolve maven dependencies if there are any and add classpath to jars. 
Add them to py-files // too for packages that include Python code @@ -294,8 +297,9 @@ object SparkSubmit { // The following modes are not supported or applicable (clusterManager, deployMode) match { - case (MESOS, CLUSTER) => - printErrorAndExit("Cluster deploy mode is currently not supported for Mesos clusters.") + case (MESOS, CLUSTER) if args.isPython => + printErrorAndExit("Cluster deploy mode is currently not supported for python " + + "applications on Mesos clusters.") case (STANDALONE, CLUSTER) if args.isPython => printErrorAndExit("Cluster deploy mode is currently not supported for python " + "applications on standalone clusters.") @@ -377,15 +381,6 @@ object SparkSubmit { OptionAssigner(args.driverExtraLibraryPath, ALL_CLUSTER_MGRS, ALL_DEPLOY_MODES, sysProp = "spark.driver.extraLibraryPath"), - // Standalone cluster only - // Do not set CL arguments here because there are multiple possibilities for the main class - OptionAssigner(args.jars, STANDALONE, CLUSTER, sysProp = "spark.jars"), - OptionAssigner(args.ivyRepoPath, STANDALONE, CLUSTER, sysProp = "spark.jars.ivy"), - OptionAssigner(args.driverMemory, STANDALONE, CLUSTER, sysProp = "spark.driver.memory"), - OptionAssigner(args.driverCores, STANDALONE, CLUSTER, sysProp = "spark.driver.cores"), - OptionAssigner(args.supervise.toString, STANDALONE, CLUSTER, - sysProp = "spark.driver.supervise"), - // Yarn client only OptionAssigner(args.queue, YARN, CLIENT, sysProp = "spark.yarn.queue"), OptionAssigner(args.numExecutors, YARN, CLIENT, sysProp = "spark.executor.instances"), @@ -413,7 +408,15 @@ object SparkSubmit { OptionAssigner(args.totalExecutorCores, STANDALONE | MESOS, ALL_DEPLOY_MODES, sysProp = "spark.cores.max"), OptionAssigner(args.files, LOCAL | STANDALONE | MESOS, ALL_DEPLOY_MODES, - sysProp = "spark.files") + sysProp = "spark.files"), + OptionAssigner(args.jars, STANDALONE | MESOS, CLUSTER, sysProp = "spark.jars"), + OptionAssigner(args.driverMemory, STANDALONE | MESOS, CLUSTER, + sysProp = "spark.driver.memory"), + OptionAssigner(args.driverCores, STANDALONE | MESOS, CLUSTER, + sysProp = "spark.driver.cores"), + OptionAssigner(args.supervise.toString, STANDALONE | MESOS, CLUSTER, + sysProp = "spark.driver.supervise"), + OptionAssigner(args.ivyRepoPath, STANDALONE, CLUSTER, sysProp = "spark.jars.ivy") ) // In client mode, launch the application main class directly @@ -452,7 +455,7 @@ object SparkSubmit { // All Spark parameters are expected to be passed to the client through system properties. 
if (args.isStandaloneCluster) { if (args.useRest) { - childMainClass = "org.apache.spark.deploy.rest.StandaloneRestClient" + childMainClass = "org.apache.spark.deploy.rest.RestSubmissionClient" childArgs += (args.primaryResource, args.mainClass) } else { // In legacy standalone cluster mode, use Client as a wrapper around the user class @@ -496,6 +499,15 @@ object SparkSubmit { } } + if (isMesosCluster) { + assert(args.useRest, "Mesos cluster mode is only supported through the REST submission API") + childMainClass = "org.apache.spark.deploy.rest.RestSubmissionClient" + childArgs += (args.primaryResource, args.mainClass) + if (args.childArgs != null) { + childArgs ++= args.childArgs + } + } + // Load any properties specified through --conf and the default properties file for ((k, v) <- args.sparkProperties) { sysProps.getOrElseUpdate(k, v) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala index c896842943f2b..c621b8fc86f94 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala @@ -241,8 +241,9 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S } private def validateKillArguments(): Unit = { - if (!master.startsWith("spark://")) { - SparkSubmit.printErrorAndExit("Killing submissions is only supported in standalone mode!") + if (!master.startsWith("spark://") && !master.startsWith("mesos://")) { + SparkSubmit.printErrorAndExit( + "Killing submissions is only supported in standalone or Mesos mode!") } if (submissionToKill == null) { SparkSubmit.printErrorAndExit("Please specify a submission to kill.") @@ -250,9 +251,9 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S } private def validateStatusRequestArguments(): Unit = { - if (!master.startsWith("spark://")) { + if (!master.startsWith("spark://") && !master.startsWith("mesos://")) { SparkSubmit.printErrorAndExit( - "Requesting submission statuses is only supported in standalone mode!") + "Requesting submission statuses is only supported in standalone or Mesos mode!") } if (submissionToRequestStatusFor == null) { SparkSubmit.printErrorAndExit("Please specify a submission to request status for.") @@ -485,6 +486,8 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S | | Spark standalone with cluster deploy mode only: | --driver-cores NUM Cores for driver (Default: 1). + | + | Spark standalone or Mesos with cluster deploy mode only: | --supervise If given, restarts the driver on failure. | --kill SUBMISSION_ID If given, kills the driver specified. | --status SUBMISSION_ID If given, requests the status of the driver specified. 
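To make the submission flow concrete, a hedged sketch (not part of the patch) of driving the renamed REST client directly; the method signatures come from the RestSubmissionClient diff later in this patch, while the master URL, jar location, application properties, and submission id are made-up placeholders:

    import org.apache.spark.deploy.rest.RestSubmissionClient

    // Placeholder dispatcher endpoint; 7077 is the MesosClusterDispatcher default port.
    val master = "mesos://10.0.0.206:7077"
    val client = new RestSubmissionClient()  // private[spark]: usable from Spark-internal code

    // Submit a driver: build a CreateSubmissionRequest and send it to the dispatcher's
    // /v1/submissions/create endpoint.
    val request = client.constructSubmitRequest(
      "http://example.com/spark-examples.jar",                        // appResource (placeholder)
      "org.apache.spark.examples.SparkPi",                            // mainClass
      Array("30"),                                                    // appArgs
      Map("spark.app.name" -> "SparkPi", "spark.master" -> master),   // sparkProperties
      Map.empty[String, String])                                      // environmentVariables
    client.createSubmission(master, request)

    // Kill or query a running driver by submission id; this is what
    // spark-submit --kill / --status do when the master is a mesos:// URL.
    client.killSubmission(master, "driver-20150428123456-0001")       // placeholder id
    client.requestSubmissionStatus(master, "driver-20150428123456-0001")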
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index ff2eed6dee70a..1c21c179562ac 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -130,7 +130,7 @@ private[master] class Master( private val restServer = if (restServerEnabled) { val port = conf.getInt("spark.master.rest.port", 6066) - Some(new StandaloneRestServer(host, port, self, masterUrl, conf)) + Some(new StandaloneRestServer(host, port, conf, self, masterUrl)) } else { None } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperLeaderElectionAgent.scala b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperLeaderElectionAgent.scala index 4823fd7cac0cb..52758d6a7c4be 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperLeaderElectionAgent.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperLeaderElectionAgent.scala @@ -23,6 +23,7 @@ import org.apache.spark.{Logging, SparkConf} import org.apache.spark.deploy.master.MasterMessages._ import org.apache.curator.framework.CuratorFramework import org.apache.curator.framework.recipes.leader.{LeaderLatchListener, LeaderLatch} +import org.apache.spark.deploy.SparkCuratorUtil private[master] class ZooKeeperLeaderElectionAgent(val masterActor: LeaderElectable, conf: SparkConf) extends LeaderLatchListener with LeaderElectionAgent with Logging { diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala index a285783f72000..80db6d474b5c1 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala @@ -26,6 +26,7 @@ import org.apache.curator.framework.CuratorFramework import org.apache.zookeeper.CreateMode import org.apache.spark.{Logging, SparkConf} +import org.apache.spark.deploy.SparkCuratorUtil private[master] class ZooKeeperPersistenceEngine(conf: SparkConf, val serialization: Serialization) diff --git a/core/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcher.scala b/core/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcher.scala new file mode 100644 index 0000000000000..5d4e5b899dfdc --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcher.scala @@ -0,0 +1,116 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.deploy.mesos + +import java.util.concurrent.CountDownLatch + +import org.apache.spark.deploy.mesos.ui.MesosClusterUI +import org.apache.spark.deploy.rest.mesos.MesosRestServer +import org.apache.spark.scheduler.cluster.mesos._ +import org.apache.spark.util.SignalLogger +import org.apache.spark.{Logging, SecurityManager, SparkConf} + +/* + * A dispatcher that is responsible for managing and launching drivers, and is intended to be + * used for Mesos cluster mode. The dispatcher is a long-running process started by the user in + * the cluster independently of Spark applications. + * It contains a [[MesosRestServer]] that listens for requests to submit drivers and a + * [[MesosClusterScheduler]] that processes these requests by negotiating with the Mesos master + * for resources. + * + * A typical new driver lifecycle is the following: + * - Driver submitted via spark-submit talking to the [[MesosRestServer]] + * - [[MesosRestServer]] queues the driver request to [[MesosClusterScheduler]] + * - [[MesosClusterScheduler]] gets resource offers and launches the drivers that are in queue + * + * This dispatcher supports both Mesos fine-grain or coarse-grain mode as the mode is configurable + * per driver launched. + * This class is needed since Mesos doesn't manage frameworks, so the dispatcher acts as + * a daemon to launch drivers as Mesos frameworks upon request. The dispatcher is also started and + * stopped by sbin/start-mesos-dispatcher and sbin/stop-mesos-dispatcher respectively. + */ +private[mesos] class MesosClusterDispatcher( + args: MesosClusterDispatcherArguments, + conf: SparkConf) + extends Logging { + + private val publicAddress = Option(conf.getenv("SPARK_PUBLIC_DNS")).getOrElse(args.host) + private val recoveryMode = conf.get("spark.mesos.deploy.recoveryMode", "NONE").toUpperCase() + logInfo("Recovery mode in Mesos dispatcher set to: " + recoveryMode) + + private val engineFactory = recoveryMode match { + case "NONE" => new BlackHoleMesosClusterPersistenceEngineFactory + case "ZOOKEEPER" => new ZookeeperMesosClusterPersistenceEngineFactory(conf) + case _ => throw new IllegalArgumentException("Unsupported recovery mode: " + recoveryMode) + } + + private val scheduler = new MesosClusterScheduler(engineFactory, conf) + + private val server = new MesosRestServer(args.host, args.port, conf, scheduler) + private val webUi = new MesosClusterUI( + new SecurityManager(conf), + args.webUiPort, + conf, + publicAddress, + scheduler) + + private val shutdownLatch = new CountDownLatch(1) + + def start(): Unit = { + webUi.bind() + scheduler.frameworkUrl = webUi.activeWebUiUrl + scheduler.start() + server.start() + } + + def awaitShutdown(): Unit = { + shutdownLatch.await() + } + + def stop(): Unit = { + webUi.stop() + server.stop() + scheduler.stop() + shutdownLatch.countDown() + } +} + +private[mesos] object MesosClusterDispatcher extends Logging { + def main(args: Array[String]) { + SignalLogger.register(log) + val conf = new SparkConf + val dispatcherArgs = new MesosClusterDispatcherArguments(args, conf) + conf.setMaster(dispatcherArgs.masterUrl) + conf.setAppName(dispatcherArgs.name) + dispatcherArgs.zookeeperUrl.foreach { z => + conf.set("spark.mesos.deploy.recoveryMode", "ZOOKEEPER") + conf.set("spark.mesos.deploy.zookeeper.url", z) + } + val dispatcher = new MesosClusterDispatcher(dispatcherArgs, conf) + dispatcher.start() + val shutdownHook = new Thread() { + override def run() { + logInfo("Shutdown hook is shutting down dispatcher") + dispatcher.stop() 
+ dispatcher.awaitShutdown() + } + } + Runtime.getRuntime.addShutdownHook(shutdownHook) + dispatcher.awaitShutdown() + } +} diff --git a/core/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcherArguments.scala b/core/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcherArguments.scala new file mode 100644 index 0000000000000..894cb78d8591a --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcherArguments.scala @@ -0,0 +1,101 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.mesos + +import org.apache.spark.SparkConf +import org.apache.spark.util.{IntParam, Utils} + + +private[mesos] class MesosClusterDispatcherArguments(args: Array[String], conf: SparkConf) { + var host = Utils.localHostName() + var port = 7077 + var name = "Spark Cluster" + var webUiPort = 8081 + var masterUrl: String = _ + var zookeeperUrl: Option[String] = None + var propertiesFile: String = _ + + parse(args.toList) + + propertiesFile = Utils.loadDefaultSparkProperties(conf, propertiesFile) + + private def parse(args: List[String]): Unit = args match { + case ("--host" | "-h") :: value :: tail => + Utils.checkHost(value, "Please use hostname " + value) + host = value + parse(tail) + + case ("--port" | "-p") :: IntParam(value) :: tail => + port = value + parse(tail) + + case ("--webui-port" | "-p") :: IntParam(value) :: tail => + webUiPort = value + parse(tail) + + case ("--zk" | "-z") :: value :: tail => + zookeeperUrl = Some(value) + parse(tail) + + case ("--master" | "-m") :: value :: tail => + if (!value.startsWith("mesos://")) { + System.err.println("Cluster dispatcher only supports mesos (uri begins with mesos://)") + System.exit(1) + } + masterUrl = value.stripPrefix("mesos://") + parse(tail) + + case ("--name") :: value :: tail => + name = value + parse(tail) + + case ("--properties-file") :: value :: tail => + propertiesFile = value + parse(tail) + + case ("--help") :: tail => + printUsageAndExit(0) + + case Nil => { + if (masterUrl == null) { + System.err.println("--master is required") + printUsageAndExit(1) + } + } + + case _ => + printUsageAndExit(1) + } + + private def printUsageAndExit(exitCode: Int): Unit = { + System.err.println( + "Usage: MesosClusterDispatcher [options]\n" + + "\n" + + "Options:\n" + + " -h HOST, --host HOST Hostname to listen on\n" + + " -p PORT, --port PORT Port to listen on (default: 7077)\n" + + " --webui-port WEBUI_PORT WebUI Port to listen on (default: 8081)\n" + + " --name NAME Framework name to show in Mesos UI\n" + + " -m --master MASTER URI for connecting to Mesos master\n" + + " -z --zk ZOOKEEPER Comma delimited URLs for connecting to \n" + + " Zookeeper for persistence\n" + + " --properties-file FILE Path to a custom Spark 
properties file.\n" + + " Default is conf/spark-defaults.conf.") + System.exit(exitCode) + } +} diff --git a/core/src/main/scala/org/apache/spark/deploy/mesos/MesosDriverDescription.scala b/core/src/main/scala/org/apache/spark/deploy/mesos/MesosDriverDescription.scala new file mode 100644 index 0000000000000..1948226800afe --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/mesos/MesosDriverDescription.scala @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.mesos + +import java.util.Date + +import org.apache.spark.deploy.Command +import org.apache.spark.scheduler.cluster.mesos.MesosClusterRetryState + +/** + * Describes a Spark driver that is submitted from the + * [[org.apache.spark.deploy.rest.mesos.MesosRestServer]], to be launched by + * [[org.apache.spark.scheduler.cluster.mesos.MesosClusterScheduler]]. + * @param jarUrl URL to the application jar + * @param mem Amount of memory for the driver + * @param cores Number of cores for the driver + * @param supervise Supervise the driver for long running app + * @param command The command to launch the driver. + * @param schedulerProperties Extra properties to pass the Mesos scheduler + */ +private[spark] class MesosDriverDescription( + val name: String, + val jarUrl: String, + val mem: Int, + val cores: Double, + val supervise: Boolean, + val command: Command, + val schedulerProperties: Map[String, String], + val submissionId: String, + val submissionDate: Date, + val retryState: Option[MesosClusterRetryState] = None) + extends Serializable { + + def copy( + name: String = name, + jarUrl: String = jarUrl, + mem: Int = mem, + cores: Double = cores, + supervise: Boolean = supervise, + command: Command = command, + schedulerProperties: Map[String, String] = schedulerProperties, + submissionId: String = submissionId, + submissionDate: Date = submissionDate, + retryState: Option[MesosClusterRetryState] = retryState): MesosDriverDescription = { + new MesosDriverDescription(name, jarUrl, mem, cores, supervise, command, schedulerProperties, + submissionId, submissionDate, retryState) + } + + override def toString: String = s"MesosDriverDescription (${command.mainClass})" +} diff --git a/core/src/main/scala/org/apache/spark/deploy/mesos/ui/MesosClusterPage.scala b/core/src/main/scala/org/apache/spark/deploy/mesos/ui/MesosClusterPage.scala new file mode 100644 index 0000000000000..7b2005e0f1237 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/mesos/ui/MesosClusterPage.scala @@ -0,0 +1,114 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.mesos.ui + +import javax.servlet.http.HttpServletRequest + +import scala.xml.Node + +import org.apache.mesos.Protos.TaskStatus +import org.apache.spark.deploy.mesos.MesosDriverDescription +import org.apache.spark.scheduler.cluster.mesos.MesosClusterSubmissionState +import org.apache.spark.ui.{UIUtils, WebUIPage} + +private[mesos] class MesosClusterPage(parent: MesosClusterUI) extends WebUIPage("") { + def render(request: HttpServletRequest): Seq[Node] = { + val state = parent.scheduler.getSchedulerState() + val queuedHeaders = Seq("Driver ID", "Submit Date", "Main Class", "Driver Resources") + val driverHeaders = queuedHeaders ++ + Seq("Start Date", "Mesos Slave ID", "State") + val retryHeaders = Seq("Driver ID", "Submit Date", "Description") ++ + Seq("Last Failed Status", "Next Retry Time", "Attempt Count") + val queuedTable = UIUtils.listingTable(queuedHeaders, queuedRow, state.queuedDrivers) + val launchedTable = UIUtils.listingTable(driverHeaders, driverRow, state.launchedDrivers) + val finishedTable = UIUtils.listingTable(driverHeaders, driverRow, state.finishedDrivers) + val retryTable = UIUtils.listingTable(retryHeaders, retryRow, state.pendingRetryDrivers) + val content = +

Mesos Framework ID: {state.frameworkId}

+
+
+

Queued Drivers:

+ {queuedTable} +

Launched Drivers:

+ {launchedTable} +

Finished Drivers:

+ {finishedTable} +

Supervise drivers waiting for retry:

+ {retryTable} +
+
; + UIUtils.basicSparkPage(content, "Spark Drivers for Mesos cluster") + } + + private def queuedRow(submission: MesosDriverDescription): Seq[Node] = { +
+ + + + + + } + + private def driverRow(state: MesosClusterSubmissionState): Seq[Node] = { + + + + + + + + + + } + + private def retryRow(submission: MesosDriverDescription): Seq[Node] = { + + + + + + + + + } + + private def stateString(status: Option[TaskStatus]): String = { + if (status.isEmpty) { + return "" + } + val sb = new StringBuilder + val s = status.get + sb.append(s"State: ${s.getState}") + if (status.get.hasMessage) { + sb.append(s", Message: ${s.getMessage}") + } + if (status.get.hasHealthy) { + sb.append(s", Healthy: ${s.getHealthy}") + } + if (status.get.hasSource) { + sb.append(s", Source: ${s.getSource}") + } + if (status.get.hasReason) { + sb.append(s", Reason: ${s.getReason}") + } + if (status.get.hasTimestamp) { + sb.append(s", Time: ${s.getTimestamp}") + } + sb.toString() + } +} diff --git a/core/src/main/scala/org/apache/spark/deploy/mesos/ui/MesosClusterUI.scala b/core/src/main/scala/org/apache/spark/deploy/mesos/ui/MesosClusterUI.scala new file mode 100644 index 0000000000000..4865d46dbc4ab --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/mesos/ui/MesosClusterUI.scala @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.deploy.mesos.ui + +import org.apache.spark.scheduler.cluster.mesos.MesosClusterScheduler +import org.apache.spark.{SecurityManager, SparkConf} +import org.apache.spark.ui.JettyUtils._ +import org.apache.spark.ui.{SparkUI, WebUI} + +/** + * UI that displays driver results from the [[org.apache.spark.deploy.mesos.MesosClusterDispatcher]] + */ +private[spark] class MesosClusterUI( + securityManager: SecurityManager, + port: Int, + conf: SparkConf, + dispatcherPublicAddress: String, + val scheduler: MesosClusterScheduler) + extends WebUI(securityManager, port, conf) { + + initialize() + + def activeWebUiUrl: String = "http://" + dispatcherPublicAddress + ":" + boundPort + + override def initialize() { + attachPage(new MesosClusterPage(this)) + attachHandler(createStaticHandler(MesosClusterUI.STATIC_RESOURCE_DIR, "/static")) + } +} + +private object MesosClusterUI { + val STATIC_RESOURCE_DIR = SparkUI.STATIC_RESOURCE_DIR +} diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestClient.scala b/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionClient.scala similarity index 93% rename from core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestClient.scala rename to core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionClient.scala index b8fd406fb6f9a..307cebfb4bd09 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestClient.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionClient.scala @@ -30,9 +30,7 @@ import org.apache.spark.{Logging, SparkConf, SPARK_VERSION => sparkVersion} import org.apache.spark.util.Utils /** - * A client that submits applications to the standalone Master using a REST protocol. - * This client is intended to communicate with the [[StandaloneRestServer]] and is - * currently used for cluster mode only. + * A client that submits applications to a [[RestSubmissionServer]]. * * In protocol version v1, the REST URL takes the form http://[host:port]/v1/submissions/[action], * where [action] can be one of create, kill, or status. Each type of request is represented in @@ -53,8 +51,10 @@ import org.apache.spark.util.Utils * implementation of this client can use that information to retry using the version specified * by the server. */ -private[deploy] class StandaloneRestClient extends Logging { - import StandaloneRestClient._ +private[spark] class RestSubmissionClient extends Logging { + import RestSubmissionClient._ + + private val supportedMasterPrefixes = Seq("spark://", "mesos://") /** * Submit an application specified by the parameters in the provided request. @@ -62,7 +62,7 @@ private[deploy] class StandaloneRestClient extends Logging { * If the submission was successful, poll the status of the submission and report * it to the user. Otherwise, report the error message provided by the server. */ - private[rest] def createSubmission( + def createSubmission( master: String, request: CreateSubmissionRequest): SubmitRestProtocolResponse = { logInfo(s"Submitting a request to launch an application in $master.") @@ -107,7 +107,7 @@ private[deploy] class StandaloneRestClient extends Logging { } /** Construct a message that captures the specified parameters for submitting an application. 
*/ - private[rest] def constructSubmitRequest( + def constructSubmitRequest( appResource: String, mainClass: String, appArgs: Array[String], @@ -219,14 +219,23 @@ private[deploy] class StandaloneRestClient extends Logging { /** Return the base URL for communicating with the server, including the protocol version. */ private def getBaseUrl(master: String): String = { - val masterUrl = master.stripPrefix("spark://").stripSuffix("/") + var masterUrl = master + supportedMasterPrefixes.foreach { prefix => + if (master.startsWith(prefix)) { + masterUrl = master.stripPrefix(prefix) + } + } + masterUrl = masterUrl.stripSuffix("/") s"http://$masterUrl/$PROTOCOL_VERSION/submissions" } /** Throw an exception if this is not standalone mode. */ private def validateMaster(master: String): Unit = { - if (!master.startsWith("spark://")) { - throw new IllegalArgumentException("This REST client is only supported in standalone mode.") + val valid = supportedMasterPrefixes.exists { prefix => master.startsWith(prefix) } + if (!valid) { + throw new IllegalArgumentException( + "This REST client only supports master URLs that start with " + + "one of the following: " + supportedMasterPrefixes.mkString(",")) } } @@ -295,7 +304,7 @@ private[deploy] class StandaloneRestClient extends Logging { } } -private[rest] object StandaloneRestClient { +private[spark] object RestSubmissionClient { private val REPORT_DRIVER_STATUS_INTERVAL = 1000 private val REPORT_DRIVER_STATUS_MAX_TRIES = 10 val PROTOCOL_VERSION = "v1" @@ -315,7 +324,7 @@ private[rest] object StandaloneRestClient { } val sparkProperties = conf.getAll.toMap val environmentVariables = env.filter { case (k, _) => k.startsWith("SPARK_") } - val client = new StandaloneRestClient + val client = new RestSubmissionClient val submitRequest = client.constructSubmitRequest( appResource, mainClass, appArgs, sparkProperties, environmentVariables) client.createSubmission(master, submitRequest) @@ -323,7 +332,7 @@ private[rest] object StandaloneRestClient { def main(args: Array[String]): Unit = { if (args.size < 2) { - sys.error("Usage: StandaloneRestClient [app resource] [main class] [app args*]") + sys.error("Usage: RestSubmissionClient [app resource] [main class] [app args*]") sys.exit(1) } val appResource = args(0) diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionServer.scala b/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionServer.scala new file mode 100644 index 0000000000000..2e78d03e5c0cc --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionServer.scala @@ -0,0 +1,318 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.deploy.rest + +import java.net.InetSocketAddress +import javax.servlet.http.{HttpServlet, HttpServletRequest, HttpServletResponse} + +import scala.io.Source +import com.fasterxml.jackson.core.JsonProcessingException +import org.eclipse.jetty.server.Server +import org.eclipse.jetty.servlet.{ServletHolder, ServletContextHandler} +import org.eclipse.jetty.util.thread.QueuedThreadPool +import org.json4s._ +import org.json4s.jackson.JsonMethods._ + +import org.apache.spark.{Logging, SparkConf, SPARK_VERSION => sparkVersion} +import org.apache.spark.util.Utils + +/** + * A server that responds to requests submitted by the [[RestSubmissionClient]]. + * + * This server responds with different HTTP codes depending on the situation: + * 200 OK - Request was processed successfully + * 400 BAD REQUEST - Request was malformed, not successfully validated, or of unexpected type + * 468 UNKNOWN PROTOCOL VERSION - Request specified a protocol this server does not understand + * 500 INTERNAL SERVER ERROR - Server throws an exception internally while processing the request + * + * The server always includes a JSON representation of the relevant [[SubmitRestProtocolResponse]] + * in the HTTP body. If an error occurs, however, the server will include an [[ErrorResponse]] + * instead of the one expected by the client. If the construction of this error response itself + * fails, the response will consist of an empty body with a response code that indicates internal + * server error. + */ +private[spark] abstract class RestSubmissionServer( + val host: String, + val requestedPort: Int, + val masterConf: SparkConf) extends Logging { + protected val submitRequestServlet: SubmitRequestServlet + protected val killRequestServlet: KillRequestServlet + protected val statusRequestServlet: StatusRequestServlet + + private var _server: Option[Server] = None + + // A mapping from URL prefixes to servlets that serve them. Exposed for testing. + protected val baseContext = s"/${RestSubmissionServer.PROTOCOL_VERSION}/submissions" + protected lazy val contextToServlet = Map[String, RestServlet]( + s"$baseContext/create/*" -> submitRequestServlet, + s"$baseContext/kill/*" -> killRequestServlet, + s"$baseContext/status/*" -> statusRequestServlet, + "/*" -> new ErrorServlet // default handler + ) + + /** Start the server and return the bound port. */ + def start(): Int = { + val (server, boundPort) = Utils.startServiceOnPort[Server](requestedPort, doStart, masterConf) + _server = Some(server) + logInfo(s"Started REST server for submitting applications on port $boundPort") + boundPort + } + + /** + * Map the servlets to their corresponding contexts and attach them to a server. + * Return a 2-tuple of the started server and the bound port. 
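As a quick reference for the mapping established by contextToServlet above, the sketch below (an editor's illustration; the right-hand strings merely name the handlers) lays out the URL space a RestSubmissionServer serves for protocol version v1.

object RestEndpointsSketch {
  val baseContext = "/v1/submissions"

  // Which handler answers which path, mirroring contextToServlet in the patch.
  val routes: Map[String, String] = Map(
    s"$baseContext/create/*" -> "submitRequestServlet (POST, JSON request body)",
    s"$baseContext/kill/*"   -> "killRequestServlet (POST, submission id in the path)",
    s"$baseContext/status/*" -> "statusRequestServlet (GET, submission id in the path)",
    "/*"                     -> "ErrorServlet (400, or 468 on a protocol-version mismatch)")
}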
+ */ + private def doStart(startPort: Int): (Server, Int) = { + val server = new Server(new InetSocketAddress(host, startPort)) + val threadPool = new QueuedThreadPool + threadPool.setDaemon(true) + server.setThreadPool(threadPool) + val mainHandler = new ServletContextHandler + mainHandler.setContextPath("/") + contextToServlet.foreach { case (prefix, servlet) => + mainHandler.addServlet(new ServletHolder(servlet), prefix) + } + server.setHandler(mainHandler) + server.start() + val boundPort = server.getConnectors()(0).getLocalPort + (server, boundPort) + } + + def stop(): Unit = { + _server.foreach(_.stop()) + } +} + +private[rest] object RestSubmissionServer { + val PROTOCOL_VERSION = RestSubmissionClient.PROTOCOL_VERSION + val SC_UNKNOWN_PROTOCOL_VERSION = 468 +} + +/** + * An abstract servlet for handling requests passed to the [[RestSubmissionServer]]. + */ +private[rest] abstract class RestServlet extends HttpServlet with Logging { + + /** + * Serialize the given response message to JSON and send it through the response servlet. + * This validates the response before sending it to ensure it is properly constructed. + */ + protected def sendResponse( + responseMessage: SubmitRestProtocolResponse, + responseServlet: HttpServletResponse): Unit = { + val message = validateResponse(responseMessage, responseServlet) + responseServlet.setContentType("application/json") + responseServlet.setCharacterEncoding("utf-8") + responseServlet.getWriter.write(message.toJson) + } + + /** + * Return any fields in the client request message that the server does not know about. + * + * The mechanism for this is to reconstruct the JSON on the server side and compare the + * diff between this JSON and the one generated on the client side. Any fields that are + * only in the client JSON are treated as unexpected. + */ + protected def findUnknownFields( + requestJson: String, + requestMessage: SubmitRestProtocolMessage): Array[String] = { + val clientSideJson = parse(requestJson) + val serverSideJson = parse(requestMessage.toJson) + val Diff(_, _, unknown) = clientSideJson.diff(serverSideJson) + unknown match { + case j: JObject => j.obj.map { case (k, _) => k }.toArray + case _ => Array.empty[String] // No difference + } + } + + /** Return a human readable String representation of the exception. */ + protected def formatException(e: Throwable): String = { + val stackTraceString = e.getStackTrace.map { "\t" + _ }.mkString("\n") + s"$e\n$stackTraceString" + } + + /** Construct an error message to signal the fact that an exception has been thrown. */ + protected def handleError(message: String): ErrorResponse = { + val e = new ErrorResponse + e.serverSparkVersion = sparkVersion + e.message = message + e + } + + /** + * Parse a submission ID from the relative path, assuming it is the first part of the path. + * For instance, we expect the path to take the form /[submission ID]/maybe/something/else. + * The returned submission ID cannot be empty. If the path is unexpected, return None. + */ + protected def parseSubmissionId(path: String): Option[String] = { + if (path == null || path.isEmpty) { + None + } else { + path.stripPrefix("/").split("/").headOption.filter(_.nonEmpty) + } + } + + /** + * Validate the response to ensure that it is correctly constructed. + * + * If it is, simply return the message as is. Otherwise, return an error response instead + * to propagate the exception back to the client and set the appropriate error code. 
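The findUnknownFields mechanism above is easy to miss: the server re-serializes the parsed message and diffs it against the raw client JSON, and whatever only the client sent ends up in the third component of the json4s Diff. A self-contained sketch follows (editor's illustration; the field values are made up and not part of the protocol):

import org.json4s._
import org.json4s.jackson.JsonMethods._

object UnknownFieldsSketch {
  def main(args: Array[String]): Unit = {
    // What the client sent (contains a field the server does not know about).
    val clientJson = """{"action": "CreateSubmissionRequest", "appResource": "app.jar", "bogusField": 1}"""
    // What the server reconstructs from the fields it understood.
    val serverJson = """{"action": "CreateSubmissionRequest", "appResource": "app.jar"}"""
    // The third component of the diff holds fields present on the client side only.
    val Diff(_, _, unknown) = parse(clientJson).diff(parse(serverJson))
    val unknownFields = unknown match {
      case j: JObject => j.obj.map { case (k, _) => k }
      case _ => Nil
    }
    println(unknownFields) // List(bogusField)
  }
}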
+ */ + private def validateResponse( + responseMessage: SubmitRestProtocolResponse, + responseServlet: HttpServletResponse): SubmitRestProtocolResponse = { + try { + responseMessage.validate() + responseMessage + } catch { + case e: Exception => + responseServlet.setStatus(HttpServletResponse.SC_INTERNAL_SERVER_ERROR) + handleError("Internal server error: " + formatException(e)) + } + } +} + +/** + * A servlet for handling kill requests passed to the [[RestSubmissionServer]]. + */ +private[rest] abstract class KillRequestServlet extends RestServlet { + + /** + * If a submission ID is specified in the URL, have the Master kill the corresponding + * driver and return an appropriate response to the client. Otherwise, return error. + */ + protected override def doPost( + request: HttpServletRequest, + response: HttpServletResponse): Unit = { + val submissionId = parseSubmissionId(request.getPathInfo) + val responseMessage = submissionId.map(handleKill).getOrElse { + response.setStatus(HttpServletResponse.SC_BAD_REQUEST) + handleError("Submission ID is missing in kill request.") + } + sendResponse(responseMessage, response) + } + + protected def handleKill(submissionId: String): KillSubmissionResponse +} + +/** + * A servlet for handling status requests passed to the [[RestSubmissionServer]]. + */ +private[rest] abstract class StatusRequestServlet extends RestServlet { + + /** + * If a submission ID is specified in the URL, request the status of the corresponding + * driver from the Master and include it in the response. Otherwise, return error. + */ + protected override def doGet( + request: HttpServletRequest, + response: HttpServletResponse): Unit = { + val submissionId = parseSubmissionId(request.getPathInfo) + val responseMessage = submissionId.map(handleStatus).getOrElse { + response.setStatus(HttpServletResponse.SC_BAD_REQUEST) + handleError("Submission ID is missing in status request.") + } + sendResponse(responseMessage, response) + } + + protected def handleStatus(submissionId: String): SubmissionStatusResponse +} + +/** + * A servlet for handling submit requests passed to the [[RestSubmissionServer]]. + */ +private[rest] abstract class SubmitRequestServlet extends RestServlet { + + /** + * Submit an application to the Master with parameters specified in the request. + * + * The request is assumed to be a [[SubmitRestProtocolRequest]] in the form of JSON. + * If the request is successfully processed, return an appropriate response to the + * client indicating so. Otherwise, return error instead. + */ + protected override def doPost( + requestServlet: HttpServletRequest, + responseServlet: HttpServletResponse): Unit = { + val responseMessage = + try { + val requestMessageJson = Source.fromInputStream(requestServlet.getInputStream).mkString + val requestMessage = SubmitRestProtocolMessage.fromJson(requestMessageJson) + // The response should have already been validated on the client. + // In case this is not true, validate it ourselves to avoid potential NPEs. 
+ requestMessage.validate() + handleSubmit(requestMessageJson, requestMessage, responseServlet) + } catch { + // The client failed to provide a valid JSON, so this is not our fault + case e @ (_: JsonProcessingException | _: SubmitRestProtocolException) => + responseServlet.setStatus(HttpServletResponse.SC_BAD_REQUEST) + handleError("Malformed request: " + formatException(e)) + } + sendResponse(responseMessage, responseServlet) + } + + protected def handleSubmit( + requestMessageJson: String, + requestMessage: SubmitRestProtocolMessage, + responseServlet: HttpServletResponse): SubmitRestProtocolResponse +} + +/** + * A default servlet that handles error cases that are not captured by other servlets. + */ +private class ErrorServlet extends RestServlet { + private val serverVersion = RestSubmissionServer.PROTOCOL_VERSION + + /** Service a faulty request by returning an appropriate error message to the client. */ + protected override def service( + request: HttpServletRequest, + response: HttpServletResponse): Unit = { + val path = request.getPathInfo + val parts = path.stripPrefix("/").split("/").filter(_.nonEmpty).toList + var versionMismatch = false + var msg = + parts match { + case Nil => + // http://host:port/ + "Missing protocol version." + case `serverVersion` :: Nil => + // http://host:port/correct-version + "Missing the /submissions prefix." + case `serverVersion` :: "submissions" :: tail => + // http://host:port/correct-version/submissions/* + "Missing an action: please specify one of /create, /kill, or /status." + case unknownVersion :: tail => + // http://host:port/unknown-version/* + versionMismatch = true + s"Unknown protocol version '$unknownVersion'." + case _ => + // never reached + s"Malformed path $path." + } + msg += s" Please submit requests through http://[host]:[port]/$serverVersion/submissions/..." 
+ val error = handleError(msg) + // If there is a version mismatch, include the highest protocol version that + // this server supports in case the client wants to retry with our version + if (versionMismatch) { + error.highestProtocolVersion = serverVersion + response.setStatus(RestSubmissionServer.SC_UNKNOWN_PROTOCOL_VERSION) + } else { + response.setStatus(HttpServletResponse.SC_BAD_REQUEST) + } + sendResponse(error, response) + } +} diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala index 2d6b8d4204795..502b9bb701ccf 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala @@ -18,26 +18,16 @@ package org.apache.spark.deploy.rest import java.io.File -import java.net.InetSocketAddress -import javax.servlet.http.{HttpServlet, HttpServletRequest, HttpServletResponse} - -import scala.io.Source +import javax.servlet.http.HttpServletResponse import akka.actor.ActorRef -import com.fasterxml.jackson.core.JsonProcessingException -import org.eclipse.jetty.server.Server -import org.eclipse.jetty.servlet.{ServletHolder, ServletContextHandler} -import org.eclipse.jetty.util.thread.QueuedThreadPool -import org.json4s._ -import org.json4s.jackson.JsonMethods._ - -import org.apache.spark.{Logging, SparkConf, SPARK_VERSION => sparkVersion} -import org.apache.spark.util.{AkkaUtils, RpcUtils, Utils} -import org.apache.spark.deploy.{Command, DeployMessages, DriverDescription} import org.apache.spark.deploy.ClientArguments._ +import org.apache.spark.deploy.{Command, DeployMessages, DriverDescription} +import org.apache.spark.util.{AkkaUtils, RpcUtils, Utils} +import org.apache.spark.{SPARK_VERSION => sparkVersion, SparkConf} /** - * A server that responds to requests submitted by the [[StandaloneRestClient]]. + * A server that responds to requests submitted by the [[RestSubmissionClient]]. * This is intended to be embedded in the standalone Master and used in cluster mode only. * * This server responds with different HTTP codes depending on the situation: @@ -54,173 +44,31 @@ import org.apache.spark.deploy.ClientArguments._ * * @param host the address this server should bind to * @param requestedPort the port this server will attempt to bind to + * @param masterConf the conf used by the Master * @param masterActor reference to the Master actor to which requests can be sent * @param masterUrl the URL of the Master new drivers will attempt to connect to - * @param masterConf the conf used by the Master */ private[deploy] class StandaloneRestServer( host: String, requestedPort: Int, + masterConf: SparkConf, masterActor: ActorRef, - masterUrl: String, - masterConf: SparkConf) - extends Logging { - - import StandaloneRestServer._ - - private var _server: Option[Server] = None - - // A mapping from URL prefixes to servlets that serve them. Exposed for testing. - protected val baseContext = s"/$PROTOCOL_VERSION/submissions" - protected val contextToServlet = Map[String, StandaloneRestServlet]( - s"$baseContext/create/*" -> new SubmitRequestServlet(masterActor, masterUrl, masterConf), - s"$baseContext/kill/*" -> new KillRequestServlet(masterActor, masterConf), - s"$baseContext/status/*" -> new StatusRequestServlet(masterActor, masterConf), - "/*" -> new ErrorServlet // default handler - ) - - /** Start the server and return the bound port. 
*/ - def start(): Int = { - val (server, boundPort) = Utils.startServiceOnPort[Server](requestedPort, doStart, masterConf) - _server = Some(server) - logInfo(s"Started REST server for submitting applications on port $boundPort") - boundPort - } - - /** - * Map the servlets to their corresponding contexts and attach them to a server. - * Return a 2-tuple of the started server and the bound port. - */ - private def doStart(startPort: Int): (Server, Int) = { - val server = new Server(new InetSocketAddress(host, startPort)) - val threadPool = new QueuedThreadPool - threadPool.setDaemon(true) - server.setThreadPool(threadPool) - val mainHandler = new ServletContextHandler - mainHandler.setContextPath("/") - contextToServlet.foreach { case (prefix, servlet) => - mainHandler.addServlet(new ServletHolder(servlet), prefix) - } - server.setHandler(mainHandler) - server.start() - val boundPort = server.getConnectors()(0).getLocalPort - (server, boundPort) - } - - def stop(): Unit = { - _server.foreach(_.stop()) - } -} - -private[rest] object StandaloneRestServer { - val PROTOCOL_VERSION = StandaloneRestClient.PROTOCOL_VERSION - val SC_UNKNOWN_PROTOCOL_VERSION = 468 -} - -/** - * An abstract servlet for handling requests passed to the [[StandaloneRestServer]]. - */ -private[rest] abstract class StandaloneRestServlet extends HttpServlet with Logging { - - /** - * Serialize the given response message to JSON and send it through the response servlet. - * This validates the response before sending it to ensure it is properly constructed. - */ - protected def sendResponse( - responseMessage: SubmitRestProtocolResponse, - responseServlet: HttpServletResponse): Unit = { - val message = validateResponse(responseMessage, responseServlet) - responseServlet.setContentType("application/json") - responseServlet.setCharacterEncoding("utf-8") - responseServlet.getWriter.write(message.toJson) - } - - /** - * Return any fields in the client request message that the server does not know about. - * - * The mechanism for this is to reconstruct the JSON on the server side and compare the - * diff between this JSON and the one generated on the client side. Any fields that are - * only in the client JSON are treated as unexpected. - */ - protected def findUnknownFields( - requestJson: String, - requestMessage: SubmitRestProtocolMessage): Array[String] = { - val clientSideJson = parse(requestJson) - val serverSideJson = parse(requestMessage.toJson) - val Diff(_, _, unknown) = clientSideJson.diff(serverSideJson) - unknown match { - case j: JObject => j.obj.map { case (k, _) => k }.toArray - case _ => Array.empty[String] // No difference - } - } - - /** Return a human readable String representation of the exception. */ - protected def formatException(e: Throwable): String = { - val stackTraceString = e.getStackTrace.map { "\t" + _ }.mkString("\n") - s"$e\n$stackTraceString" - } - - /** Construct an error message to signal the fact that an exception has been thrown. */ - protected def handleError(message: String): ErrorResponse = { - val e = new ErrorResponse - e.serverSparkVersion = sparkVersion - e.message = message - e - } - - /** - * Parse a submission ID from the relative path, assuming it is the first part of the path. - * For instance, we expect the path to take the form /[submission ID]/maybe/something/else. - * The returned submission ID cannot be empty. If the path is unexpected, return None. 
- */ - protected def parseSubmissionId(path: String): Option[String] = { - if (path == null || path.isEmpty) { - None - } else { - path.stripPrefix("/").split("/").headOption.filter(_.nonEmpty) - } - } - - /** - * Validate the response to ensure that it is correctly constructed. - * - * If it is, simply return the message as is. Otherwise, return an error response instead - * to propagate the exception back to the client and set the appropriate error code. - */ - private def validateResponse( - responseMessage: SubmitRestProtocolResponse, - responseServlet: HttpServletResponse): SubmitRestProtocolResponse = { - try { - responseMessage.validate() - responseMessage - } catch { - case e: Exception => - responseServlet.setStatus(HttpServletResponse.SC_INTERNAL_SERVER_ERROR) - handleError("Internal server error: " + formatException(e)) - } - } + masterUrl: String) + extends RestSubmissionServer(host, requestedPort, masterConf) { + + protected override val submitRequestServlet = + new StandaloneSubmitRequestServlet(masterActor, masterUrl, masterConf) + protected override val killRequestServlet = + new StandaloneKillRequestServlet(masterActor, masterConf) + protected override val statusRequestServlet = + new StandaloneStatusRequestServlet(masterActor, masterConf) } /** * A servlet for handling kill requests passed to the [[StandaloneRestServer]]. */ -private[rest] class KillRequestServlet(masterActor: ActorRef, conf: SparkConf) - extends StandaloneRestServlet { - - /** - * If a submission ID is specified in the URL, have the Master kill the corresponding - * driver and return an appropriate response to the client. Otherwise, return error. - */ - protected override def doPost( - request: HttpServletRequest, - response: HttpServletResponse): Unit = { - val submissionId = parseSubmissionId(request.getPathInfo) - val responseMessage = submissionId.map(handleKill).getOrElse { - response.setStatus(HttpServletResponse.SC_BAD_REQUEST) - handleError("Submission ID is missing in kill request.") - } - sendResponse(responseMessage, response) - } +private[rest] class StandaloneKillRequestServlet(masterActor: ActorRef, conf: SparkConf) + extends KillRequestServlet { protected def handleKill(submissionId: String): KillSubmissionResponse = { val askTimeout = RpcUtils.askTimeout(conf) @@ -238,23 +86,8 @@ private[rest] class KillRequestServlet(masterActor: ActorRef, conf: SparkConf) /** * A servlet for handling status requests passed to the [[StandaloneRestServer]]. */ -private[rest] class StatusRequestServlet(masterActor: ActorRef, conf: SparkConf) - extends StandaloneRestServlet { - - /** - * If a submission ID is specified in the URL, request the status of the corresponding - * driver from the Master and include it in the response. Otherwise, return error. 
- */ - protected override def doGet( - request: HttpServletRequest, - response: HttpServletResponse): Unit = { - val submissionId = parseSubmissionId(request.getPathInfo) - val responseMessage = submissionId.map(handleStatus).getOrElse { - response.setStatus(HttpServletResponse.SC_BAD_REQUEST) - handleError("Submission ID is missing in status request.") - } - sendResponse(responseMessage, response) - } +private[rest] class StandaloneStatusRequestServlet(masterActor: ActorRef, conf: SparkConf) + extends StatusRequestServlet { protected def handleStatus(submissionId: String): SubmissionStatusResponse = { val askTimeout = RpcUtils.askTimeout(conf) @@ -276,71 +109,11 @@ private[rest] class StatusRequestServlet(masterActor: ActorRef, conf: SparkConf) /** * A servlet for handling submit requests passed to the [[StandaloneRestServer]]. */ -private[rest] class SubmitRequestServlet( +private[rest] class StandaloneSubmitRequestServlet( masterActor: ActorRef, masterUrl: String, conf: SparkConf) - extends StandaloneRestServlet { - - /** - * Submit an application to the Master with parameters specified in the request. - * - * The request is assumed to be a [[SubmitRestProtocolRequest]] in the form of JSON. - * If the request is successfully processed, return an appropriate response to the - * client indicating so. Otherwise, return error instead. - */ - protected override def doPost( - requestServlet: HttpServletRequest, - responseServlet: HttpServletResponse): Unit = { - val responseMessage = - try { - val requestMessageJson = Source.fromInputStream(requestServlet.getInputStream).mkString - val requestMessage = SubmitRestProtocolMessage.fromJson(requestMessageJson) - // The response should have already been validated on the client. - // In case this is not true, validate it ourselves to avoid potential NPEs. - requestMessage.validate() - handleSubmit(requestMessageJson, requestMessage, responseServlet) - } catch { - // The client failed to provide a valid JSON, so this is not our fault - case e @ (_: JsonProcessingException | _: SubmitRestProtocolException) => - responseServlet.setStatus(HttpServletResponse.SC_BAD_REQUEST) - handleError("Malformed request: " + formatException(e)) - } - sendResponse(responseMessage, responseServlet) - } - - /** - * Handle the submit request and construct an appropriate response to return to the client. - * - * This assumes that the request message is already successfully validated. - * If the request message is not of the expected type, return error to the client. 
- */ - private def handleSubmit( - requestMessageJson: String, - requestMessage: SubmitRestProtocolMessage, - responseServlet: HttpServletResponse): SubmitRestProtocolResponse = { - requestMessage match { - case submitRequest: CreateSubmissionRequest => - val askTimeout = RpcUtils.askTimeout(conf) - val driverDescription = buildDriverDescription(submitRequest) - val response = AkkaUtils.askWithReply[DeployMessages.SubmitDriverResponse]( - DeployMessages.RequestSubmitDriver(driverDescription), masterActor, askTimeout) - val submitResponse = new CreateSubmissionResponse - submitResponse.serverSparkVersion = sparkVersion - submitResponse.message = response.message - submitResponse.success = response.success - submitResponse.submissionId = response.driverId.orNull - val unknownFields = findUnknownFields(requestMessageJson, requestMessage) - if (unknownFields.nonEmpty) { - // If there are fields that the server does not know about, warn the client - submitResponse.unknownFields = unknownFields - } - submitResponse - case unexpected => - responseServlet.setStatus(HttpServletResponse.SC_BAD_REQUEST) - handleError(s"Received message of unexpected type ${unexpected.messageType}.") - } - } + extends SubmitRequestServlet { /** * Build a driver description from the fields specified in the submit request. @@ -389,50 +162,37 @@ private[rest] class SubmitRequestServlet( new DriverDescription( appResource, actualDriverMemory, actualDriverCores, actualSuperviseDriver, command) } -} -/** - * A default servlet that handles error cases that are not captured by other servlets. - */ -private class ErrorServlet extends StandaloneRestServlet { - private val serverVersion = StandaloneRestServer.PROTOCOL_VERSION - - /** Service a faulty request by returning an appropriate error message to the client. */ - protected override def service( - request: HttpServletRequest, - response: HttpServletResponse): Unit = { - val path = request.getPathInfo - val parts = path.stripPrefix("/").split("/").filter(_.nonEmpty).toList - var versionMismatch = false - var msg = - parts match { - case Nil => - // http://host:port/ - "Missing protocol version." - case `serverVersion` :: Nil => - // http://host:port/correct-version - "Missing the /submissions prefix." - case `serverVersion` :: "submissions" :: tail => - // http://host:port/correct-version/submissions/* - "Missing an action: please specify one of /create, /kill, or /status." - case unknownVersion :: tail => - // http://host:port/unknown-version/* - versionMismatch = true - s"Unknown protocol version '$unknownVersion'." - case _ => - // never reached - s"Malformed path $path." - } - msg += s" Please submit requests through http://[host]:[port]/$serverVersion/submissions/..." - val error = handleError(msg) - // If there is a version mismatch, include the highest protocol version that - // this server supports in case the client wants to retry with our version - if (versionMismatch) { - error.highestProtocolVersion = serverVersion - response.setStatus(StandaloneRestServer.SC_UNKNOWN_PROTOCOL_VERSION) - } else { - response.setStatus(HttpServletResponse.SC_BAD_REQUEST) + /** + * Handle the submit request and construct an appropriate response to return to the client. + * + * This assumes that the request message is already successfully validated. + * If the request message is not of the expected type, return error to the client. 
+ */ + protected override def handleSubmit( + requestMessageJson: String, + requestMessage: SubmitRestProtocolMessage, + responseServlet: HttpServletResponse): SubmitRestProtocolResponse = { + requestMessage match { + case submitRequest: CreateSubmissionRequest => + val askTimeout = RpcUtils.askTimeout(conf) + val driverDescription = buildDriverDescription(submitRequest) + val response = AkkaUtils.askWithReply[DeployMessages.SubmitDriverResponse]( + DeployMessages.RequestSubmitDriver(driverDescription), masterActor, askTimeout) + val submitResponse = new CreateSubmissionResponse + submitResponse.serverSparkVersion = sparkVersion + submitResponse.message = response.message + submitResponse.success = response.success + submitResponse.submissionId = response.driverId.orNull + val unknownFields = findUnknownFields(requestMessageJson, requestMessage) + if (unknownFields.nonEmpty) { + // If there are fields that the server does not know about, warn the client + submitResponse.unknownFields = unknownFields + } + submitResponse + case unexpected => + responseServlet.setStatus(HttpServletResponse.SC_BAD_REQUEST) + handleError(s"Received message of unexpected type ${unexpected.messageType}.") } - sendResponse(error, response) } } diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolRequest.scala b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolRequest.scala index d80abdf15fb34..0d50a768942ed 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolRequest.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolRequest.scala @@ -61,7 +61,7 @@ private[rest] class CreateSubmissionRequest extends SubmitRestProtocolRequest { assertProperty[Boolean](key, "boolean", _.toBoolean) private def assertPropertyIsNumeric(key: String): Unit = - assertProperty[Int](key, "numeric", _.toInt) + assertProperty[Double](key, "numeric", _.toDouble) private def assertPropertyIsMemory(key: String): Unit = assertProperty[Int](key, "memory", Utils.memoryStringToMb) diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolResponse.scala b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolResponse.scala index 8fde8c142a4c1..0e226ee294cab 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolResponse.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolResponse.scala @@ -35,7 +35,7 @@ private[rest] abstract class SubmitRestProtocolResponse extends SubmitRestProtoc /** * A response to a [[CreateSubmissionRequest]] in the REST application submission protocol. */ -private[rest] class CreateSubmissionResponse extends SubmitRestProtocolResponse { +private[spark] class CreateSubmissionResponse extends SubmitRestProtocolResponse { var submissionId: String = null protected override def doValidate(): Unit = { super.doValidate() @@ -46,7 +46,7 @@ private[rest] class CreateSubmissionResponse extends SubmitRestProtocolResponse /** * A response to a kill request in the REST application submission protocol. */ -private[rest] class KillSubmissionResponse extends SubmitRestProtocolResponse { +private[spark] class KillSubmissionResponse extends SubmitRestProtocolResponse { var submissionId: String = null protected override def doValidate(): Unit = { super.doValidate() @@ -58,7 +58,7 @@ private[rest] class KillSubmissionResponse extends SubmitRestProtocolResponse { /** * A response to a status request in the REST application submission protocol. 
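The relaxation from _.toInt to _.toDouble in assertPropertyIsNumeric above is what allows a "numeric" property value to be fractional; the Mesos path treats driver cores as a Double (see DEFAULT_CORES = 1.0 further down). A minimal sketch of the looser check, using an invented helper name:

object NumericPropertySketch {
  // A value passes the "numeric" assertion as long as it parses as a Double.
  def isNumeric(value: String): Boolean = scala.util.Try(value.toDouble).isSuccess

  def main(args: Array[String]): Unit = {
    println(isNumeric("1"))    // true
    println(isNumeric("1.5"))  // true, would have failed the previous Int-based check
    println(isNumeric("abc"))  // false
  }
}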
*/ -private[rest] class SubmissionStatusResponse extends SubmitRestProtocolResponse { +private[spark] class SubmissionStatusResponse extends SubmitRestProtocolResponse { var submissionId: String = null var driverState: String = null var workerId: String = null diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/mesos/MesosRestServer.scala b/core/src/main/scala/org/apache/spark/deploy/rest/mesos/MesosRestServer.scala new file mode 100644 index 0000000000000..fd17a980c9319 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/rest/mesos/MesosRestServer.scala @@ -0,0 +1,158 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.rest.mesos + +import java.io.File +import java.text.SimpleDateFormat +import java.util.Date +import java.util.concurrent.atomic.AtomicLong +import javax.servlet.http.HttpServletResponse + +import org.apache.spark.deploy.Command +import org.apache.spark.deploy.mesos.MesosDriverDescription +import org.apache.spark.deploy.rest._ +import org.apache.spark.scheduler.cluster.mesos.MesosClusterScheduler +import org.apache.spark.util.Utils +import org.apache.spark.{SPARK_VERSION => sparkVersion, SparkConf} + + +/** + * A server that responds to requests submitted by the [[RestSubmissionClient]]. + * All requests are forwarded to + * [[org.apache.spark.scheduler.cluster.mesos.MesosClusterScheduler]]. + * This is intended to be used in Mesos cluster mode only. + * For more details about the REST submission please refer to [[RestSubmissionServer]] javadocs. + */ +private[spark] class MesosRestServer( + host: String, + requestedPort: Int, + masterConf: SparkConf, + scheduler: MesosClusterScheduler) + extends RestSubmissionServer(host, requestedPort, masterConf) { + + protected override val submitRequestServlet = + new MesosSubmitRequestServlet(scheduler, masterConf) + protected override val killRequestServlet = + new MesosKillRequestServlet(scheduler, masterConf) + protected override val statusRequestServlet = + new MesosStatusRequestServlet(scheduler, masterConf) +} + +private[deploy] class MesosSubmitRequestServlet( + scheduler: MesosClusterScheduler, + conf: SparkConf) + extends SubmitRequestServlet { + + private val DEFAULT_SUPERVISE = false + private val DEFAULT_MEMORY = 512 // mb + private val DEFAULT_CORES = 1.0 + + private val nextDriverNumber = new AtomicLong(0) + private def createDateFormat = new SimpleDateFormat("yyyyMMddHHmmss") // For application IDs + private def newDriverId(submitDate: Date): String = { + "driver-%s-%04d".format( + createDateFormat.format(submitDate), nextDriverNumber.incrementAndGet()) + } + + /** + * Build a driver description from the fields specified in the submit request. 
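For reference, the submission IDs handed back by this servlet follow the "driver-<yyyyMMddHHmmss>-<4-digit counter>" pattern built by newDriverId above. A stripped-down, self-contained sketch (editor's illustration, not the patch's class):

import java.text.SimpleDateFormat
import java.util.Date
import java.util.concurrent.atomic.AtomicLong

object DriverIdSketch {
  private val nextDriverNumber = new AtomicLong(0)
  // A fresh formatter per call, as in the patch (SimpleDateFormat is not thread-safe).
  private def createDateFormat = new SimpleDateFormat("yyyyMMddHHmmss")

  def newDriverId(submitDate: Date = new Date()): String =
    "driver-%s-%04d".format(createDateFormat.format(submitDate), nextDriverNumber.incrementAndGet())

  def main(args: Array[String]): Unit = {
    println(newDriverId()) // e.g. driver-20300101120000-0001 (timestamp taken from the submit date)
  }
}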
+ * + * This involves constructing a command that launches a mesos framework for the job. + * This does not currently consider fields used by python applications since python + * is not supported in mesos cluster mode yet. + */ + private def buildDriverDescription(request: CreateSubmissionRequest): MesosDriverDescription = { + // Required fields, including the main class because python is not yet supported + val appResource = Option(request.appResource).getOrElse { + throw new SubmitRestMissingFieldException("Application jar is missing.") + } + val mainClass = Option(request.mainClass).getOrElse { + throw new SubmitRestMissingFieldException("Main class is missing.") + } + + // Optional fields + val sparkProperties = request.sparkProperties + val driverExtraJavaOptions = sparkProperties.get("spark.driver.extraJavaOptions") + val driverExtraClassPath = sparkProperties.get("spark.driver.extraClassPath") + val driverExtraLibraryPath = sparkProperties.get("spark.driver.extraLibraryPath") + val superviseDriver = sparkProperties.get("spark.driver.supervise") + val driverMemory = sparkProperties.get("spark.driver.memory") + val driverCores = sparkProperties.get("spark.driver.cores") + val appArgs = request.appArgs + val environmentVariables = request.environmentVariables + val name = request.sparkProperties.get("spark.app.name").getOrElse(mainClass) + + // Construct driver description + val conf = new SparkConf(false).setAll(sparkProperties) + val extraClassPath = driverExtraClassPath.toSeq.flatMap(_.split(File.pathSeparator)) + val extraLibraryPath = driverExtraLibraryPath.toSeq.flatMap(_.split(File.pathSeparator)) + val extraJavaOpts = driverExtraJavaOptions.map(Utils.splitCommandString).getOrElse(Seq.empty) + val sparkJavaOpts = Utils.sparkJavaOpts(conf) + val javaOpts = sparkJavaOpts ++ extraJavaOpts + val command = new Command( + mainClass, appArgs, environmentVariables, extraClassPath, extraLibraryPath, javaOpts) + val actualSuperviseDriver = superviseDriver.map(_.toBoolean).getOrElse(DEFAULT_SUPERVISE) + val actualDriverMemory = driverMemory.map(Utils.memoryStringToMb).getOrElse(DEFAULT_MEMORY) + val actualDriverCores = driverCores.map(_.toDouble).getOrElse(DEFAULT_CORES) + val submitDate = new Date() + val submissionId = newDriverId(submitDate) + + new MesosDriverDescription( + name, appResource, actualDriverMemory, actualDriverCores, actualSuperviseDriver, + command, request.sparkProperties, submissionId, submitDate) + } + + protected override def handleSubmit( + requestMessageJson: String, + requestMessage: SubmitRestProtocolMessage, + responseServlet: HttpServletResponse): SubmitRestProtocolResponse = { + requestMessage match { + case submitRequest: CreateSubmissionRequest => + val driverDescription = buildDriverDescription(submitRequest) + val s = scheduler.submitDriver(driverDescription) + s.serverSparkVersion = sparkVersion + val unknownFields = findUnknownFields(requestMessageJson, requestMessage) + if (unknownFields.nonEmpty) { + // If there are fields that the server does not know about, warn the client + s.unknownFields = unknownFields + } + s + case unexpected => + responseServlet.setStatus(HttpServletResponse.SC_BAD_REQUEST) + handleError(s"Received message of unexpected type ${unexpected.messageType}.") + } + } +} + +private[deploy] class MesosKillRequestServlet(scheduler: MesosClusterScheduler, conf: SparkConf) + extends KillRequestServlet { + protected override def handleKill(submissionId: String): KillSubmissionResponse = { + val k = scheduler.killDriver(submissionId) + 
k.serverSparkVersion = sparkVersion + k + } +} + +private[deploy] class MesosStatusRequestServlet(scheduler: MesosClusterScheduler, conf: SparkConf) + extends StatusRequestServlet { + protected override def handleStatus(submissionId: String): SubmissionStatusResponse = { + val d = scheduler.getDriverStatus(submissionId) + d.serverSparkVersion = sparkVersion + d + } +} diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala index 82f652dae0378..3412301e64fd7 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala @@ -18,20 +18,17 @@ package org.apache.spark.scheduler.cluster.mesos import java.io.File -import java.util.{List => JList} -import java.util.Collections +import java.util.{Collections, List => JList} import scala.collection.JavaConversions._ import scala.collection.mutable.{HashMap, HashSet} -import org.apache.mesos.{Scheduler => MScheduler} -import org.apache.mesos._ -import org.apache.mesos.Protos.{TaskInfo => MesosTaskInfo, TaskState => MesosTaskState, _} - -import org.apache.spark.{Logging, SparkContext, SparkEnv, SparkException, TaskState} +import org.apache.mesos.Protos.{TaskInfo => MesosTaskInfo, _} +import org.apache.mesos.{Scheduler => MScheduler, _} import org.apache.spark.scheduler.TaskSchedulerImpl import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend -import org.apache.spark.util.{Utils, AkkaUtils} +import org.apache.spark.util.{AkkaUtils, Utils} +import org.apache.spark.{SparkContext, SparkEnv, SparkException, TaskState} /** * A SchedulerBackend that runs tasks on Mesos, but uses "coarse-grained" tasks, where it holds @@ -49,17 +46,10 @@ private[spark] class CoarseMesosSchedulerBackend( master: String) extends CoarseGrainedSchedulerBackend(scheduler, sc.env.rpcEnv) with MScheduler - with Logging { + with MesosSchedulerUtils { val MAX_SLAVE_FAILURES = 2 // Blacklist a slave after this many failures - // Lock used to wait for scheduler to be registered - var isRegistered = false - val registeredLock = new Object() - - // Driver for talking to Mesos - var driver: SchedulerDriver = null - // Maximum number of cores to acquire (TODO: we'll need more flexible controls here) val maxCores = conf.get("spark.cores.max", Int.MaxValue.toString).toInt @@ -87,26 +77,8 @@ private[spark] class CoarseMesosSchedulerBackend( override def start() { super.start() - - synchronized { - new Thread("CoarseMesosSchedulerBackend driver") { - setDaemon(true) - override def run() { - val scheduler = CoarseMesosSchedulerBackend.this - val fwInfo = FrameworkInfo.newBuilder().setUser(sc.sparkUser).setName(sc.appName).build() - driver = new MesosSchedulerDriver(scheduler, fwInfo, master) - try { { - val ret = driver.run() - logInfo("driver.run() returned with code " + ret) - } - } catch { - case e: Exception => logError("driver.run() failed", e) - } - } - }.start() - - waitForRegister() - } + val fwInfo = FrameworkInfo.newBuilder().setUser(sc.sparkUser).setName(sc.appName).build() + startScheduler(master, CoarseMesosSchedulerBackend.this, fwInfo) } def createCommand(offer: Offer, numCores: Int): CommandInfo = { @@ -150,8 +122,10 @@ private[spark] class CoarseMesosSchedulerBackend( conf.get("spark.driver.port"), CoarseGrainedSchedulerBackend.ENDPOINT_NAME) - val uri = 
conf.get("spark.executor.uri", null) - if (uri == null) { + val uri = conf.getOption("spark.executor.uri") + .orElse(Option(System.getenv("SPARK_EXECUTOR_URI"))) + + if (uri.isEmpty) { val runScript = new File(executorSparkHome, "./bin/spark-class").getCanonicalPath command.setValue( "%s \"%s\" org.apache.spark.executor.CoarseGrainedExecutorBackend" @@ -164,7 +138,7 @@ private[spark] class CoarseMesosSchedulerBackend( } else { // Grab everything to the first '.'. We'll use that and '*' to // glob the directory "correctly". - val basename = uri.split('/').last.split('.').head + val basename = uri.get.split('/').last.split('.').head command.setValue( s"cd $basename*; $prefixEnv " + "./bin/spark-class org.apache.spark.executor.CoarseGrainedExecutorBackend" + @@ -173,7 +147,7 @@ private[spark] class CoarseMesosSchedulerBackend( s" --hostname ${offer.getHostname}" + s" --cores $numCores" + s" --app-id $appId") - command.addUris(CommandInfo.URI.newBuilder().setValue(uri)) + command.addUris(CommandInfo.URI.newBuilder().setValue(uri.get)) } command.build() } @@ -183,18 +157,7 @@ private[spark] class CoarseMesosSchedulerBackend( override def registered(d: SchedulerDriver, frameworkId: FrameworkID, masterInfo: MasterInfo) { appId = frameworkId.getValue logInfo("Registered as framework ID " + appId) - registeredLock.synchronized { - isRegistered = true - registeredLock.notifyAll() - } - } - - def waitForRegister() { - registeredLock.synchronized { - while (!isRegistered) { - registeredLock.wait() - } - } + markRegistered() } override def disconnected(d: SchedulerDriver) {} @@ -245,14 +208,6 @@ private[spark] class CoarseMesosSchedulerBackend( } } - /** Helper function to pull out a resource from a Mesos Resources protobuf */ - private def getResource(res: JList[Resource], name: String): Double = { - for (r <- res if r.getName == name) { - return r.getScalar.getValue - } - 0 - } - /** Build a Mesos resource protobuf object */ private def createResource(resourceName: String, quantity: Double): Protos.Resource = { Resource.newBuilder() @@ -284,7 +239,8 @@ private[spark] class CoarseMesosSchedulerBackend( "is Spark installed on it?") } } - driver.reviveOffers() // In case we'd rejected everything before but have now lost a node + // In case we'd rejected everything before but have now lost a node + mesosDriver.reviveOffers() } } } @@ -296,8 +252,8 @@ private[spark] class CoarseMesosSchedulerBackend( override def stop() { super.stop() - if (driver != null) { - driver.stop() + if (mesosDriver != null) { + mesosDriver.stop() } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterPersistenceEngine.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterPersistenceEngine.scala new file mode 100644 index 0000000000000..3efc536f1456c --- /dev/null +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterPersistenceEngine.scala @@ -0,0 +1,134 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler.cluster.mesos + +import scala.collection.JavaConversions._ + +import org.apache.curator.framework.CuratorFramework +import org.apache.zookeeper.CreateMode +import org.apache.zookeeper.KeeperException.NoNodeException + +import org.apache.spark.{Logging, SparkConf} +import org.apache.spark.deploy.SparkCuratorUtil +import org.apache.spark.util.Utils + +/** + * Persistence engine factory that is responsible for creating new persistence engines + * to store Mesos cluster mode state. + */ +private[spark] abstract class MesosClusterPersistenceEngineFactory(conf: SparkConf) { + def createEngine(path: String): MesosClusterPersistenceEngine +} + +/** + * Mesos cluster persistence engine is responsible for persisting Mesos cluster mode + * specific state, so that on failover all the state can be recovered and the scheduler + * can resume managing the drivers. + */ +private[spark] trait MesosClusterPersistenceEngine { + def persist(name: String, obj: Object): Unit + def expunge(name: String): Unit + def fetch[T](name: String): Option[T] + def fetchAll[T](): Iterable[T] +} + +/** + * Zookeeper backed persistence engine factory. + * All Zk engines created from this factory shares the same Zookeeper client, so + * all of them reuses the same connection pool. + */ +private[spark] class ZookeeperMesosClusterPersistenceEngineFactory(conf: SparkConf) + extends MesosClusterPersistenceEngineFactory(conf) { + + lazy val zk = SparkCuratorUtil.newClient(conf, "spark.mesos.deploy.zookeeper.url") + + def createEngine(path: String): MesosClusterPersistenceEngine = { + new ZookeeperMesosClusterPersistenceEngine(path, zk, conf) + } +} + +/** + * Black hole persistence engine factory that creates black hole + * persistence engines, which stores nothing. + */ +private[spark] class BlackHoleMesosClusterPersistenceEngineFactory + extends MesosClusterPersistenceEngineFactory(null) { + def createEngine(path: String): MesosClusterPersistenceEngine = { + new BlackHoleMesosClusterPersistenceEngine + } +} + +/** + * Black hole persistence engine that stores nothing. + */ +private[spark] class BlackHoleMesosClusterPersistenceEngine extends MesosClusterPersistenceEngine { + override def persist(name: String, obj: Object): Unit = {} + override def fetch[T](name: String): Option[T] = None + override def expunge(name: String): Unit = {} + override def fetchAll[T](): Iterable[T] = Iterable.empty[T] +} + +/** + * Zookeeper based Mesos cluster persistence engine, that stores cluster mode state + * into Zookeeper. Each engine object is operating under one folder in Zookeeper, but + * reuses a shared Zookeeper client. 
+ */ +private[spark] class ZookeeperMesosClusterPersistenceEngine( + baseDir: String, + zk: CuratorFramework, + conf: SparkConf) + extends MesosClusterPersistenceEngine with Logging { + private val WORKING_DIR = + conf.get("spark.deploy.zookeeper.dir", "/spark_mesos_dispatcher") + "/" + baseDir + + SparkCuratorUtil.mkdir(zk, WORKING_DIR) + + def path(name: String): String = { + WORKING_DIR + "/" + name + } + + override def expunge(name: String): Unit = { + zk.delete().forPath(path(name)) + } + + override def persist(name: String, obj: Object): Unit = { + val serialized = Utils.serialize(obj) + val zkPath = path(name) + zk.create().withMode(CreateMode.PERSISTENT).forPath(zkPath, serialized) + } + + override def fetch[T](name: String): Option[T] = { + val zkPath = path(name) + + try { + val fileData = zk.getData().forPath(zkPath) + Some(Utils.deserialize[T](fileData)) + } catch { + case e: NoNodeException => None + case e: Exception => { + logWarning("Exception while reading persisted file, deleting", e) + zk.delete().forPath(zkPath) + None + } + } + } + + override def fetchAll[T](): Iterable[T] = { + zk.getChildren.forPath(WORKING_DIR).map(fetch[T]).flatten + } +} diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala new file mode 100644 index 0000000000000..0396e62be5309 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala @@ -0,0 +1,608 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler.cluster.mesos + +import java.io.File +import java.util.concurrent.locks.ReentrantLock +import java.util.{Collections, Date, List => JList} + +import scala.collection.JavaConversions._ +import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer + +import org.apache.mesos.Protos.Environment.Variable +import org.apache.mesos.Protos.TaskStatus.Reason +import org.apache.mesos.Protos.{TaskState => MesosTaskState, _} +import org.apache.mesos.{Scheduler, SchedulerDriver} +import org.apache.spark.deploy.mesos.MesosDriverDescription +import org.apache.spark.deploy.rest.{CreateSubmissionResponse, KillSubmissionResponse, SubmissionStatusResponse} +import org.apache.spark.metrics.MetricsSystem +import org.apache.spark.util.Utils +import org.apache.spark.{SecurityManager, SparkConf, SparkException, TaskState} + + +/** + * Tracks the current state of a Mesos Task that runs a Spark driver. 
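To see what the ZooKeeper-backed engine just shown actually does on the wire, here is a self-contained Curator round trip using the same create/getData/delete calls. The connection string and znode path are hypothetical, and creatingParentsIfNeeded stands in for the SparkCuratorUtil.mkdir call the engine performs up front.

import org.apache.curator.framework.CuratorFrameworkFactory
import org.apache.curator.retry.ExponentialBackoffRetry
import org.apache.zookeeper.CreateMode

object CuratorRoundTripSketch {
  def main(args: Array[String]): Unit = {
    // Hypothetical local ZooKeeper; the real engine obtains its client from
    // SparkCuratorUtil.newClient(conf, "spark.mesos.deploy.zookeeper.url").
    val zk = CuratorFrameworkFactory.newClient("127.0.0.1:2181", new ExponentialBackoffRetry(1000, 3))
    zk.start()
    val path = "/spark_mesos_dispatcher/driverQueue/driver-example" // workingDir + "/" + name
    val payload = Array[Byte](1, 2, 3)                              // the engine stores Utils.serialize(obj) here
    zk.create().creatingParentsIfNeeded().withMode(CreateMode.PERSISTENT).forPath(path, payload) // persist
    val read = zk.getData().forPath(path)                           // fetch
    println(read.length)
    zk.delete().forPath(path)                                       // expunge
    zk.close()
  }
}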
+ * @param driverDescription Submitted driver description from + * [[org.apache.spark.deploy.rest.mesos.MesosRestServer]] + * @param taskId Mesos TaskID generated for the task + * @param slaveId Slave ID that the task is assigned to + * @param mesosTaskStatus The last known task status update. + * @param startDate The date the task was launched + */ +private[spark] class MesosClusterSubmissionState( + val driverDescription: MesosDriverDescription, + val taskId: TaskID, + val slaveId: SlaveID, + var mesosTaskStatus: Option[TaskStatus], + var startDate: Date) + extends Serializable { + + def copy(): MesosClusterSubmissionState = { + new MesosClusterSubmissionState( + driverDescription, taskId, slaveId, mesosTaskStatus, startDate) + } +} + +/** + * Tracks the retry state of a driver, which includes the next time it should be scheduled + * and necessary information to do exponential backoff. + * This class is not thread-safe, and we expect the caller to handle synchronizing state. + * @param lastFailureStatus Last Task status when it failed. + * @param retries Number of times it has been retried. + * @param nextRetry Time at which it should be retried next + * @param waitTime The amount of time driver is scheduled to wait until next retry. + */ +private[spark] class MesosClusterRetryState( + val lastFailureStatus: TaskStatus, + val retries: Int, + val nextRetry: Date, + val waitTime: Int) extends Serializable { + def copy(): MesosClusterRetryState = + new MesosClusterRetryState(lastFailureStatus, retries, nextRetry, waitTime) +} + +/** + * The full state of the cluster scheduler, currently being used for displaying + * information on the UI. + * @param frameworkId Mesos Framework id for the cluster scheduler. + * @param masterUrl The Mesos master url + * @param queuedDrivers All drivers queued to be launched + * @param launchedDrivers All launched or running drivers + * @param finishedDrivers All terminated drivers + * @param pendingRetryDrivers All drivers pending to be retried + */ +private[spark] class MesosClusterSchedulerState( + val frameworkId: String, + val masterUrl: Option[String], + val queuedDrivers: Iterable[MesosDriverDescription], + val launchedDrivers: Iterable[MesosClusterSubmissionState], + val finishedDrivers: Iterable[MesosClusterSubmissionState], + val pendingRetryDrivers: Iterable[MesosDriverDescription]) + +/** + * A Mesos scheduler that is responsible for launching submitted Spark drivers in cluster mode + * as Mesos tasks in a Mesos cluster. + * All drivers are launched asynchronously by the framework, which will eventually be launched + * by one of the slaves in the cluster. The results of the driver will be stored in slave's task + * sandbox which is accessible by visiting the Mesos UI. + * This scheduler supports recovery by persisting all its state and performs task reconciliation + * on recover, which gets all the latest state for all the drivers from Mesos master. 
+ */ +private[spark] class MesosClusterScheduler( + engineFactory: MesosClusterPersistenceEngineFactory, + conf: SparkConf) + extends Scheduler with MesosSchedulerUtils { + var frameworkUrl: String = _ + private val metricsSystem = + MetricsSystem.createMetricsSystem("mesos_cluster", conf, new SecurityManager(conf)) + private val master = conf.get("spark.master") + private val appName = conf.get("spark.app.name") + private val queuedCapacity = conf.getInt("spark.mesos.maxDrivers", 200) + private val retainedDrivers = conf.getInt("spark.mesos.retainedDrivers", 200) + private val maxRetryWaitTime = conf.getInt("spark.mesos.cluster.retry.wait.max", 60) // 1 minute + private val schedulerState = engineFactory.createEngine("scheduler") + private val stateLock = new ReentrantLock() + private val finishedDrivers = + new mutable.ArrayBuffer[MesosClusterSubmissionState](retainedDrivers) + private var frameworkId: String = null + // Holds all the launched drivers and current launch state, keyed by driver id. + private val launchedDrivers = new mutable.HashMap[String, MesosClusterSubmissionState]() + // Holds a map of driver id to expected slave id that is passed to Mesos for reconciliation. + // All drivers that are loaded after failover are added here, as we need get the latest + // state of the tasks from Mesos. + private val pendingRecover = new mutable.HashMap[String, SlaveID]() + // Stores all the submitted drivers that hasn't been launched. + private val queuedDrivers = new ArrayBuffer[MesosDriverDescription]() + // All supervised drivers that are waiting to retry after termination. + private val pendingRetryDrivers = new ArrayBuffer[MesosDriverDescription]() + private val queuedDriversState = engineFactory.createEngine("driverQueue") + private val launchedDriversState = engineFactory.createEngine("launchedDrivers") + private val pendingRetryDriversState = engineFactory.createEngine("retryList") + // Flag to mark if the scheduler is ready to be called, which is until the scheduler + // is registered with Mesos master. + @volatile protected var ready = false + private var masterInfo: Option[MasterInfo] = None + + def submitDriver(desc: MesosDriverDescription): CreateSubmissionResponse = { + val c = new CreateSubmissionResponse + if (!ready) { + c.success = false + c.message = "Scheduler is not ready to take requests" + return c + } + + stateLock.synchronized { + if (isQueueFull()) { + c.success = false + c.message = "Already reached maximum submission size" + return c + } + c.submissionId = desc.submissionId + queuedDriversState.persist(desc.submissionId, desc) + queuedDrivers += desc + c.success = true + } + c + } + + def killDriver(submissionId: String): KillSubmissionResponse = { + val k = new KillSubmissionResponse + if (!ready) { + k.success = false + k.message = "Scheduler is not ready to take requests" + return k + } + k.submissionId = submissionId + stateLock.synchronized { + // We look for the requested driver in the following places: + // 1. Check if submission is running or launched. + // 2. Check if it's still queued. + // 3. Check if it's in the retry list. + // 4. Check if it has already completed. 
+ if (launchedDrivers.contains(submissionId)) { + val task = launchedDrivers(submissionId) + mesosDriver.killTask(task.taskId) + k.success = true + k.message = "Killing running driver" + } else if (removeFromQueuedDrivers(submissionId)) { + k.success = true + k.message = "Removed driver while it's still pending" + } else if (removeFromPendingRetryDrivers(submissionId)) { + k.success = true + k.message = "Removed driver while it's being retried" + } else if (finishedDrivers.exists(_.driverDescription.submissionId.equals(submissionId))) { + k.success = false + k.message = "Driver already terminated" + } else { + k.success = false + k.message = "Cannot find driver" + } + } + k + } + + def getDriverStatus(submissionId: String): SubmissionStatusResponse = { + val s = new SubmissionStatusResponse + if (!ready) { + s.success = false + s.message = "Scheduler is not ready to take requests" + return s + } + s.submissionId = submissionId + stateLock.synchronized { + if (queuedDrivers.exists(_.submissionId.equals(submissionId))) { + s.success = true + s.driverState = "QUEUED" + } else if (launchedDrivers.contains(submissionId)) { + s.success = true + s.driverState = "RUNNING" + launchedDrivers(submissionId).mesosTaskStatus.foreach(state => s.message = state.toString) + } else if (finishedDrivers.exists(_.driverDescription.submissionId.equals(submissionId))) { + s.success = true + s.driverState = "FINISHED" + finishedDrivers + .find(d => d.driverDescription.submissionId.equals(submissionId)).get.mesosTaskStatus + .foreach(state => s.message = state.toString) + } else if (pendingRetryDrivers.exists(_.submissionId.equals(submissionId))) { + val status = pendingRetryDrivers.find(_.submissionId.equals(submissionId)) + .get.retryState.get.lastFailureStatus + s.success = true + s.driverState = "RETRYING" + s.message = status.toString + } else { + s.success = false + s.driverState = "NOT_FOUND" + } + } + s + } + + private def isQueueFull(): Boolean = launchedDrivers.size >= queuedCapacity + + /** + * Recover scheduler state that is persisted. + * We still need to do task reconciliation to be up to date of the latest task states + * as it might have changed while the scheduler is failing over. + */ + private def recoverState(): Unit = { + stateLock.synchronized { + launchedDriversState.fetchAll[MesosClusterSubmissionState]().foreach { state => + launchedDrivers(state.taskId.getValue) = state + pendingRecover(state.taskId.getValue) = state.slaveId + } + queuedDriversState.fetchAll[MesosDriverDescription]().foreach(d => queuedDrivers += d) + // There is potential timing issue where a queued driver might have been launched + // but the scheduler shuts down before the queued driver was able to be removed + // from the queue. We try to mitigate this issue by walking through all queued drivers + // and remove if they're already launched. + queuedDrivers + .filter(d => launchedDrivers.contains(d.submissionId)) + .foreach(d => removeFromQueuedDrivers(d.submissionId)) + pendingRetryDriversState.fetchAll[MesosDriverDescription]() + .foreach(s => pendingRetryDrivers += s) + // TODO: Consider storing finished drivers so we can show them on the UI after + // failover. For now we clear the history on each recovery. + finishedDrivers.clear() + } + } + + /** + * Starts the cluster scheduler and wait until the scheduler is registered. + * This also marks the scheduler to be ready for requests. + */ + def start(): Unit = { + // TODO: Implement leader election to make sure only one framework running in the cluster. 
+ val fwId = schedulerState.fetch[String]("frameworkId") + val builder = FrameworkInfo.newBuilder() + .setUser(Utils.getCurrentUserName()) + .setName(appName) + .setWebuiUrl(frameworkUrl) + .setCheckpoint(true) + .setFailoverTimeout(Integer.MAX_VALUE) // Setting to max so tasks keep running on crash + fwId.foreach { id => + builder.setId(FrameworkID.newBuilder().setValue(id).build()) + frameworkId = id + } + recoverState() + metricsSystem.registerSource(new MesosClusterSchedulerSource(this)) + metricsSystem.start() + startScheduler(master, MesosClusterScheduler.this, builder.build()) + ready = true + } + + def stop(): Unit = { + ready = false + metricsSystem.report() + metricsSystem.stop() + mesosDriver.stop(true) + } + + override def registered( + driver: SchedulerDriver, + newFrameworkId: FrameworkID, + masterInfo: MasterInfo): Unit = { + logInfo("Registered as framework ID " + newFrameworkId.getValue) + if (newFrameworkId.getValue != frameworkId) { + frameworkId = newFrameworkId.getValue + schedulerState.persist("frameworkId", frameworkId) + } + markRegistered() + + stateLock.synchronized { + this.masterInfo = Some(masterInfo) + if (!pendingRecover.isEmpty) { + // Start task reconciliation if we need to recover. + val statuses = pendingRecover.collect { + case (taskId, slaveId) => + val newStatus = TaskStatus.newBuilder() + .setTaskId(TaskID.newBuilder().setValue(taskId).build()) + .setSlaveId(slaveId) + .setState(MesosTaskState.TASK_STAGING) + .build() + launchedDrivers.get(taskId).map(_.mesosTaskStatus.getOrElse(newStatus)) + .getOrElse(newStatus) + } + // TODO: Page the status updates to avoid trying to reconcile + // a large amount of tasks at once. + driver.reconcileTasks(statuses) + } + } + } + + private def buildDriverCommand(desc: MesosDriverDescription): CommandInfo = { + val appJar = CommandInfo.URI.newBuilder() + .setValue(desc.jarUrl.stripPrefix("file:").stripPrefix("local:")).build() + val builder = CommandInfo.newBuilder().addUris(appJar) + val entries = + (conf.getOption("spark.executor.extraLibraryPath").toList ++ + desc.command.libraryPathEntries) + val prefixEnv = if (!entries.isEmpty) { + Utils.libraryPathEnvPrefix(entries) + } else { + "" + } + val envBuilder = Environment.newBuilder() + desc.command.environment.foreach { case (k, v) => + envBuilder.addVariables(Variable.newBuilder().setName(k).setValue(v).build()) + } + // Pass all spark properties to executor. 
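+    // (Each entry in schedulerProperties below becomes a -Dkey=value flag collected into the
+    // SPARK_EXECUTOR_OPTS environment variable set on the driver task's CommandInfo.)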
+ val executorOpts = desc.schedulerProperties.map { case (k, v) => s"-D$k=$v" }.mkString(" ") + envBuilder.addVariables( + Variable.newBuilder().setName("SPARK_EXECUTOR_OPTS").setValue(executorOpts)) + val cmdOptions = generateCmdOption(desc) + val executorUri = desc.schedulerProperties.get("spark.executor.uri") + .orElse(desc.command.environment.get("SPARK_EXECUTOR_URI")) + val appArguments = desc.command.arguments.mkString(" ") + val cmd = if (executorUri.isDefined) { + builder.addUris(CommandInfo.URI.newBuilder().setValue(executorUri.get).build()) + val folderBasename = executorUri.get.split('/').last.split('.').head + val cmdExecutable = s"cd $folderBasename*; $prefixEnv bin/spark-submit" + val cmdJar = s"../${desc.jarUrl.split("/").last}" + s"$cmdExecutable ${cmdOptions.mkString(" ")} $cmdJar $appArguments" + } else { + val executorSparkHome = desc.schedulerProperties.get("spark.mesos.executor.home") + .orElse(conf.getOption("spark.home")) + .orElse(Option(System.getenv("SPARK_HOME"))) + .getOrElse { + throw new SparkException("Executor Spark home `spark.mesos.executor.home` is not set!") + } + val cmdExecutable = new File(executorSparkHome, "./bin/spark-submit").getCanonicalPath + val cmdJar = desc.jarUrl.split("/").last + s"$cmdExecutable ${cmdOptions.mkString(" ")} $cmdJar $appArguments" + } + builder.setValue(cmd) + builder.setEnvironment(envBuilder.build()) + builder.build() + } + + private def generateCmdOption(desc: MesosDriverDescription): Seq[String] = { + var options = Seq( + "--name", desc.schedulerProperties("spark.app.name"), + "--class", desc.command.mainClass, + "--master", s"mesos://${conf.get("spark.master")}", + "--driver-cores", desc.cores.toString, + "--driver-memory", s"${desc.mem}M") + desc.schedulerProperties.get("spark.executor.memory").map { v => + options ++= Seq("--executor-memory", v) + } + desc.schedulerProperties.get("spark.cores.max").map { v => + options ++= Seq("--total-executor-cores", v) + } + options + } + + private class ResourceOffer(val offer: Offer, var cpu: Double, var mem: Double) { + override def toString(): String = { + s"Offer id: ${offer.getId.getValue}, cpu: $cpu, mem: $mem" + } + } + + /** + * This method takes all the possible candidates and attempt to schedule them with Mesos offers. + * Every time a new task is scheduled, the afterLaunchCallback is called to perform post scheduled + * logic on each task. 
+ */ + private def scheduleTasks( + candidates: Seq[MesosDriverDescription], + afterLaunchCallback: (String) => Boolean, + currentOffers: List[ResourceOffer], + tasks: mutable.HashMap[OfferID, ArrayBuffer[TaskInfo]]): Unit = { + for (submission <- candidates) { + val driverCpu = submission.cores + val driverMem = submission.mem + logTrace(s"Finding offer to launch driver with cpu: $driverCpu, mem: $driverMem") + val offerOption = currentOffers.find { o => + o.cpu >= driverCpu && o.mem >= driverMem + } + if (offerOption.isEmpty) { + logDebug(s"Unable to find offer to launch driver id: ${submission.submissionId}, " + + s"cpu: $driverCpu, mem: $driverMem") + } else { + val offer = offerOption.get + offer.cpu -= driverCpu + offer.mem -= driverMem + val taskId = TaskID.newBuilder().setValue(submission.submissionId).build() + val cpuResource = Resource.newBuilder() + .setName("cpus").setType(Value.Type.SCALAR) + .setScalar(Value.Scalar.newBuilder().setValue(driverCpu)).build() + val memResource = Resource.newBuilder() + .setName("mem").setType(Value.Type.SCALAR) + .setScalar(Value.Scalar.newBuilder().setValue(driverMem)).build() + val commandInfo = buildDriverCommand(submission) + val appName = submission.schedulerProperties("spark.app.name") + val taskInfo = TaskInfo.newBuilder() + .setTaskId(taskId) + .setName(s"Driver for $appName") + .setSlaveId(offer.offer.getSlaveId) + .setCommand(commandInfo) + .addResources(cpuResource) + .addResources(memResource) + .build() + val queuedTasks = tasks.getOrElseUpdate(offer.offer.getId, new ArrayBuffer[TaskInfo]) + queuedTasks += taskInfo + logTrace(s"Using offer ${offer.offer.getId.getValue} to launch driver " + + submission.submissionId) + val newState = new MesosClusterSubmissionState(submission, taskId, offer.offer.getSlaveId, + None, new Date()) + launchedDrivers(submission.submissionId) = newState + launchedDriversState.persist(submission.submissionId, newState) + afterLaunchCallback(submission.submissionId) + } + } + } + + override def resourceOffers(driver: SchedulerDriver, offers: JList[Offer]): Unit = { + val currentOffers = offers.map { o => + new ResourceOffer( + o, getResource(o.getResourcesList, "cpus"), getResource(o.getResourcesList, "mem")) + }.toList + logTrace(s"Received offers from Mesos: \n${currentOffers.mkString("\n")}") + val tasks = new mutable.HashMap[OfferID, ArrayBuffer[TaskInfo]]() + val currentTime = new Date() + + stateLock.synchronized { + // We first schedule all the supervised drivers that are ready to retry. + // This list will be empty if none of the drivers are marked as supervise. + val driversToRetry = pendingRetryDrivers.filter { d => + d.retryState.get.nextRetry.before(currentTime) + } + scheduleTasks( + driversToRetry, + removeFromPendingRetryDrivers, + currentOffers, + tasks) + // Then we walk through the queued drivers and try to schedule them. 
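+      // (The supervised retries above were matched against the same offers first, so queued
+      // drivers can only claim whatever cpu and memory is still left in each ResourceOffer.)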
+ scheduleTasks( + queuedDrivers, + removeFromQueuedDrivers, + currentOffers, + tasks) + } + tasks.foreach { case (offerId, tasks) => + driver.launchTasks(Collections.singleton(offerId), tasks) + } + offers + .filter(o => !tasks.keySet.contains(o.getId)) + .foreach(o => driver.declineOffer(o.getId)) + } + + def getSchedulerState(): MesosClusterSchedulerState = { + def copyBuffer( + buffer: ArrayBuffer[MesosDriverDescription]): ArrayBuffer[MesosDriverDescription] = { + val newBuffer = new ArrayBuffer[MesosDriverDescription](buffer.size) + buffer.copyToBuffer(newBuffer) + newBuffer + } + stateLock.synchronized { + new MesosClusterSchedulerState( + frameworkId, + masterInfo.map(m => s"http://${m.getIp}:${m.getPort}"), + copyBuffer(queuedDrivers), + launchedDrivers.values.map(_.copy()).toList, + finishedDrivers.map(_.copy()).toList, + copyBuffer(pendingRetryDrivers)) + } + } + + override def offerRescinded(driver: SchedulerDriver, offerId: OfferID): Unit = {} + override def disconnected(driver: SchedulerDriver): Unit = {} + override def reregistered(driver: SchedulerDriver, masterInfo: MasterInfo): Unit = { + logInfo(s"Framework re-registered with master ${masterInfo.getId}") + } + override def slaveLost(driver: SchedulerDriver, slaveId: SlaveID): Unit = {} + override def error(driver: SchedulerDriver, error: String): Unit = { + logError("Error received: " + error) + } + + /** + * Check if the task state is a recoverable state that we can relaunch the task. + * Task state like TASK_ERROR are not relaunchable state since it wasn't able + * to be validated by Mesos. + */ + private def shouldRelaunch(state: MesosTaskState): Boolean = { + state == MesosTaskState.TASK_FAILED || + state == MesosTaskState.TASK_KILLED || + state == MesosTaskState.TASK_LOST + } + + override def statusUpdate(driver: SchedulerDriver, status: TaskStatus): Unit = { + val taskId = status.getTaskId.getValue + stateLock.synchronized { + if (launchedDrivers.contains(taskId)) { + if (status.getReason == Reason.REASON_RECONCILIATION && + !pendingRecover.contains(taskId)) { + // Task has already received update and no longer requires reconciliation. + return + } + val state = launchedDrivers(taskId) + // Check if the driver is supervise enabled and can be relaunched. 
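+        // (The retry wait time starts at 1 second, doubles after every failure, and is capped by
+        // spark.mesos.cluster.retry.wait.max; nextRetry marks when the driver becomes eligible
+        // for rescheduling again in resourceOffers().)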
+ if (state.driverDescription.supervise && shouldRelaunch(status.getState)) { + removeFromLaunchedDrivers(taskId) + val retryState: Option[MesosClusterRetryState] = state.driverDescription.retryState + val (retries, waitTimeSec) = retryState + .map { rs => (rs.retries + 1, Math.min(maxRetryWaitTime, rs.waitTime * 2)) } + .getOrElse{ (1, 1) } + val nextRetry = new Date(new Date().getTime + waitTimeSec * 1000L) + + val newDriverDescription = state.driverDescription.copy( + retryState = Some(new MesosClusterRetryState(status, retries, nextRetry, waitTimeSec))) + pendingRetryDrivers += newDriverDescription + pendingRetryDriversState.persist(taskId, newDriverDescription) + } else if (TaskState.isFinished(TaskState.fromMesos(status.getState))) { + removeFromLaunchedDrivers(taskId) + if (finishedDrivers.size >= retainedDrivers) { + val toRemove = math.max(retainedDrivers / 10, 1) + finishedDrivers.trimStart(toRemove) + } + finishedDrivers += state + } + state.mesosTaskStatus = Option(status) + } else { + logError(s"Unable to find driver $taskId in status update") + } + } + } + + override def frameworkMessage( + driver: SchedulerDriver, + executorId: ExecutorID, + slaveId: SlaveID, + message: Array[Byte]): Unit = {} + + override def executorLost( + driver: SchedulerDriver, + executorId: ExecutorID, + slaveId: SlaveID, + status: Int): Unit = {} + + private def removeFromQueuedDrivers(id: String): Boolean = { + val index = queuedDrivers.indexWhere(_.submissionId.equals(id)) + if (index != -1) { + queuedDrivers.remove(index) + queuedDriversState.expunge(id) + true + } else { + false + } + } + + private def removeFromLaunchedDrivers(id: String): Boolean = { + if (launchedDrivers.remove(id).isDefined) { + launchedDriversState.expunge(id) + true + } else { + false + } + } + + private def removeFromPendingRetryDrivers(id: String): Boolean = { + val index = pendingRetryDrivers.indexWhere(_.submissionId.equals(id)) + if (index != -1) { + pendingRetryDrivers.remove(index) + pendingRetryDriversState.expunge(id) + true + } else { + false + } + } + + def getQueuedDriversSize: Int = queuedDrivers.size + def getLaunchedDriversSize: Int = launchedDrivers.size + def getPendingRetryDriversSize: Int = pendingRetryDrivers.size +} diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterSchedulerSource.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterSchedulerSource.scala new file mode 100644 index 0000000000000..1fe94974c8e36 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterSchedulerSource.scala @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.scheduler.cluster.mesos + +import com.codahale.metrics.{Gauge, MetricRegistry} + +import org.apache.spark.metrics.source.Source + +private[mesos] class MesosClusterSchedulerSource(scheduler: MesosClusterScheduler) + extends Source { + override def sourceName: String = "mesos_cluster" + override def metricRegistry: MetricRegistry = new MetricRegistry() + + metricRegistry.register(MetricRegistry.name("waitingDrivers"), new Gauge[Int] { + override def getValue: Int = scheduler.getQueuedDriversSize + }) + + metricRegistry.register(MetricRegistry.name("launchedDrivers"), new Gauge[Int] { + override def getValue: Int = scheduler.getLaunchedDriversSize + }) + + metricRegistry.register(MetricRegistry.name("retryDrivers"), new Gauge[Int] { + override def getValue: Int = scheduler.getPendingRetryDriversSize + }) +} diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala index d9d62b0e287ed..8346a2407489f 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala @@ -18,23 +18,19 @@ package org.apache.spark.scheduler.cluster.mesos import java.io.File -import java.util.{ArrayList => JArrayList, List => JList} -import java.util.Collections +import java.util.{ArrayList => JArrayList, Collections, List => JList} import scala.collection.JavaConversions._ import scala.collection.mutable.{HashMap, HashSet} +import org.apache.mesos.Protos.{ExecutorInfo => MesosExecutorInfo, TaskInfo => MesosTaskInfo, TaskState => MesosTaskState, _} import org.apache.mesos.protobuf.ByteString -import org.apache.mesos.{Scheduler => MScheduler} -import org.apache.mesos._ -import org.apache.mesos.Protos.{TaskInfo => MesosTaskInfo, TaskState => MesosTaskState, - ExecutorInfo => MesosExecutorInfo, _} - +import org.apache.mesos.{Scheduler => MScheduler, _} import org.apache.spark.executor.MesosExecutorBackend -import org.apache.spark.{Logging, SparkContext, SparkException, TaskState} -import org.apache.spark.scheduler.cluster.ExecutorInfo import org.apache.spark.scheduler._ +import org.apache.spark.scheduler.cluster.ExecutorInfo import org.apache.spark.util.Utils +import org.apache.spark.{SparkContext, SparkException, TaskState} /** * A SchedulerBackend for running fine-grained tasks on Mesos. 
Each Spark task is mapped to a @@ -47,14 +43,7 @@ private[spark] class MesosSchedulerBackend( master: String) extends SchedulerBackend with MScheduler - with Logging { - - // Lock used to wait for scheduler to be registered - var isRegistered = false - val registeredLock = new Object() - - // Driver for talking to Mesos - var driver: SchedulerDriver = null + with MesosSchedulerUtils { // Which slave IDs we have executors on val slaveIdsWithExecutors = new HashSet[String] @@ -73,26 +62,9 @@ private[spark] class MesosSchedulerBackend( @volatile var appId: String = _ override def start() { - synchronized { - classLoader = Thread.currentThread.getContextClassLoader - - new Thread("MesosSchedulerBackend driver") { - setDaemon(true) - override def run() { - val scheduler = MesosSchedulerBackend.this - val fwInfo = FrameworkInfo.newBuilder().setUser(sc.sparkUser).setName(sc.appName).build() - driver = new MesosSchedulerDriver(scheduler, fwInfo, master) - try { - val ret = driver.run() - logInfo("driver.run() returned with code " + ret) - } catch { - case e: Exception => logError("driver.run() failed", e) - } - } - }.start() - - waitForRegister() - } + val fwInfo = FrameworkInfo.newBuilder().setUser(sc.sparkUser).setName(sc.appName).build() + classLoader = Thread.currentThread.getContextClassLoader + startScheduler(master, MesosSchedulerBackend.this, fwInfo) } def createExecutorInfo(execId: String): MesosExecutorInfo = { @@ -125,17 +97,19 @@ private[spark] class MesosSchedulerBackend( } val command = CommandInfo.newBuilder() .setEnvironment(environment) - val uri = sc.conf.get("spark.executor.uri", null) + val uri = sc.conf.getOption("spark.executor.uri") + .orElse(Option(System.getenv("SPARK_EXECUTOR_URI"))) + val executorBackendName = classOf[MesosExecutorBackend].getName - if (uri == null) { + if (uri.isEmpty) { val executorPath = new File(executorSparkHome, "/bin/spark-class").getCanonicalPath command.setValue(s"$prefixEnv $executorPath $executorBackendName") } else { // Grab everything to the first '.'. We'll use that and '*' to // glob the directory "correctly". 
- val basename = uri.split('/').last.split('.').head + val basename = uri.get.split('/').last.split('.').head command.setValue(s"cd ${basename}*; $prefixEnv ./bin/spark-class $executorBackendName") - command.addUris(CommandInfo.URI.newBuilder().setValue(uri)) + command.addUris(CommandInfo.URI.newBuilder().setValue(uri.get)) } val cpus = Resource.newBuilder() .setName("cpus") @@ -181,18 +155,7 @@ private[spark] class MesosSchedulerBackend( inClassLoader() { appId = frameworkId.getValue logInfo("Registered as framework ID " + appId) - registeredLock.synchronized { - isRegistered = true - registeredLock.notifyAll() - } - } - } - - def waitForRegister() { - registeredLock.synchronized { - while (!isRegistered) { - registeredLock.wait() - } + markRegistered() } } @@ -287,14 +250,6 @@ private[spark] class MesosSchedulerBackend( } } - /** Helper function to pull out a resource from a Mesos Resources protobuf */ - def getResource(res: JList[Resource], name: String): Double = { - for (r <- res if r.getName == name) { - return r.getScalar.getValue - } - 0 - } - /** Turn a Spark TaskDescription into a Mesos task */ def createMesosTask(task: TaskDescription, slaveId: String): MesosTaskInfo = { val taskId = TaskID.newBuilder().setValue(task.taskId.toString).build() @@ -339,13 +294,13 @@ private[spark] class MesosSchedulerBackend( } override def stop() { - if (driver != null) { - driver.stop() + if (mesosDriver != null) { + mesosDriver.stop() } } override def reviveOffers() { - driver.reviveOffers() + mesosDriver.reviveOffers() } override def frameworkMessage(d: SchedulerDriver, e: ExecutorID, s: SlaveID, b: Array[Byte]) {} @@ -380,7 +335,7 @@ private[spark] class MesosSchedulerBackend( } override def killTask(taskId: Long, executorId: String, interruptThread: Boolean): Unit = { - driver.killTask( + mesosDriver.killTask( TaskID.newBuilder() .setValue(taskId.toString).build() ) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala new file mode 100644 index 0000000000000..d11228f3d016a --- /dev/null +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala @@ -0,0 +1,95 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler.cluster.mesos + +import java.util.List +import java.util.concurrent.CountDownLatch + +import scala.collection.JavaConversions._ + +import org.apache.mesos.Protos.{FrameworkInfo, Resource, Status} +import org.apache.mesos.{MesosSchedulerDriver, Scheduler} +import org.apache.spark.Logging +import org.apache.spark.util.Utils + +/** + * Shared trait for implementing a Mesos Scheduler. 
This holds common state and helper + * methods and Mesos scheduler will use. + */ +private[mesos] trait MesosSchedulerUtils extends Logging { + // Lock used to wait for scheduler to be registered + private final val registerLatch = new CountDownLatch(1) + + // Driver for talking to Mesos + protected var mesosDriver: MesosSchedulerDriver = null + + /** + * Starts the MesosSchedulerDriver with the provided information. This method returns + * only after the scheduler has registered with Mesos. + * @param masterUrl Mesos master connection URL + * @param scheduler Scheduler object + * @param fwInfo FrameworkInfo to pass to the Mesos master + */ + def startScheduler(masterUrl: String, scheduler: Scheduler, fwInfo: FrameworkInfo): Unit = { + synchronized { + if (mesosDriver != null) { + registerLatch.await() + return + } + + new Thread(Utils.getFormattedClassName(this) + "-mesos-driver") { + setDaemon(true) + + override def run() { + mesosDriver = new MesosSchedulerDriver(scheduler, fwInfo, masterUrl) + try { + val ret = mesosDriver.run() + logInfo("driver.run() returned with code " + ret) + if (ret.equals(Status.DRIVER_ABORTED)) { + System.exit(1) + } + } catch { + case e: Exception => { + logError("driver.run() failed", e) + System.exit(1) + } + } + } + }.start() + + registerLatch.await() + } + } + + /** + * Signal that the scheduler has registered with Mesos. + */ + protected def markRegistered(): Unit = { + registerLatch.countDown() + } + + /** + * Get the amount of resources for the specified type from the resource list + */ + protected def getResource(res: List[Resource], name: String): Double = { + for (r <- res if r.getName == name) { + return r.getScalar.getValue + } + 0.0 + } +} diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala index 4561e5b8e9663..c4e6f06146b0a 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala @@ -231,7 +231,7 @@ class SparkSubmitSuite extends FunSuite with Matchers with ResetSystemProperties val childArgsStr = childArgs.mkString(" ") if (useRest) { childArgsStr should endWith ("thejar.jar org.SomeClass arg1 arg2") - mainClass should be ("org.apache.spark.deploy.rest.StandaloneRestClient") + mainClass should be ("org.apache.spark.deploy.rest.RestSubmissionClient") } else { childArgsStr should startWith ("--supervise --memory 4g --cores 5") childArgsStr should include regex "launch spark://h:p .*thejar.jar org.SomeClass arg1 arg2" diff --git a/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala index 8e09976636386..0a318a27ac212 100644 --- a/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala @@ -39,9 +39,9 @@ import org.apache.spark.deploy.master.DriverState._ * Tests for the REST application submission protocol used in standalone cluster mode. 
*/ class StandaloneRestSubmitSuite extends FunSuite with BeforeAndAfterEach { - private val client = new StandaloneRestClient + private val client = new RestSubmissionClient private var actorSystem: Option[ActorSystem] = None - private var server: Option[StandaloneRestServer] = None + private var server: Option[RestSubmissionServer] = None override def afterEach() { actorSystem.foreach(_.shutdown()) @@ -89,7 +89,7 @@ class StandaloneRestSubmitSuite extends FunSuite with BeforeAndAfterEach { conf.set("spark.app.name", "dreamer") val appArgs = Array("one", "two", "six") // main method calls this - val response = StandaloneRestClient.run("app-resource", "main-class", appArgs, conf) + val response = RestSubmissionClient.run("app-resource", "main-class", appArgs, conf) val submitResponse = getSubmitResponse(response) assert(submitResponse.action === Utils.getFormattedClassName(submitResponse)) assert(submitResponse.serverSparkVersion === SPARK_VERSION) @@ -208,7 +208,7 @@ class StandaloneRestSubmitSuite extends FunSuite with BeforeAndAfterEach { test("good request paths") { val masterUrl = startSmartServer() val httpUrl = masterUrl.replace("spark://", "http://") - val v = StandaloneRestServer.PROTOCOL_VERSION + val v = RestSubmissionServer.PROTOCOL_VERSION val json = constructSubmitRequest(masterUrl).toJson val submitRequestPath = s"$httpUrl/$v/submissions/create" val killRequestPath = s"$httpUrl/$v/submissions/kill" @@ -238,7 +238,7 @@ class StandaloneRestSubmitSuite extends FunSuite with BeforeAndAfterEach { test("good request paths, bad requests") { val masterUrl = startSmartServer() val httpUrl = masterUrl.replace("spark://", "http://") - val v = StandaloneRestServer.PROTOCOL_VERSION + val v = RestSubmissionServer.PROTOCOL_VERSION val submitRequestPath = s"$httpUrl/$v/submissions/create" val killRequestPath = s"$httpUrl/$v/submissions/kill" val statusRequestPath = s"$httpUrl/$v/submissions/status" @@ -276,7 +276,7 @@ class StandaloneRestSubmitSuite extends FunSuite with BeforeAndAfterEach { test("bad request paths") { val masterUrl = startSmartServer() val httpUrl = masterUrl.replace("spark://", "http://") - val v = StandaloneRestServer.PROTOCOL_VERSION + val v = RestSubmissionServer.PROTOCOL_VERSION val (response1, code1) = sendHttpRequestWithResponse(httpUrl, "GET") val (response2, code2) = sendHttpRequestWithResponse(s"$httpUrl/", "GET") val (response3, code3) = sendHttpRequestWithResponse(s"$httpUrl/$v", "GET") @@ -292,7 +292,7 @@ class StandaloneRestSubmitSuite extends FunSuite with BeforeAndAfterEach { assert(code5 === HttpServletResponse.SC_BAD_REQUEST) assert(code6 === HttpServletResponse.SC_BAD_REQUEST) assert(code7 === HttpServletResponse.SC_BAD_REQUEST) - assert(code8 === StandaloneRestServer.SC_UNKNOWN_PROTOCOL_VERSION) + assert(code8 === RestSubmissionServer.SC_UNKNOWN_PROTOCOL_VERSION) // all responses should be error responses val errorResponse1 = getErrorResponse(response1) val errorResponse2 = getErrorResponse(response2) @@ -310,13 +310,13 @@ class StandaloneRestSubmitSuite extends FunSuite with BeforeAndAfterEach { assert(errorResponse5.highestProtocolVersion === null) assert(errorResponse6.highestProtocolVersion === null) assert(errorResponse7.highestProtocolVersion === null) - assert(errorResponse8.highestProtocolVersion === StandaloneRestServer.PROTOCOL_VERSION) + assert(errorResponse8.highestProtocolVersion === RestSubmissionServer.PROTOCOL_VERSION) } test("server returns unknown fields") { val masterUrl = startSmartServer() val httpUrl = masterUrl.replace("spark://", 
"http://") - val v = StandaloneRestServer.PROTOCOL_VERSION + val v = RestSubmissionServer.PROTOCOL_VERSION val submitRequestPath = s"$httpUrl/$v/submissions/create" val oldJson = constructSubmitRequest(masterUrl).toJson val oldFields = parse(oldJson).asInstanceOf[JObject].obj @@ -340,7 +340,7 @@ class StandaloneRestSubmitSuite extends FunSuite with BeforeAndAfterEach { test("client handles faulty server") { val masterUrl = startFaultyServer() val httpUrl = masterUrl.replace("spark://", "http://") - val v = StandaloneRestServer.PROTOCOL_VERSION + val v = RestSubmissionServer.PROTOCOL_VERSION val submitRequestPath = s"$httpUrl/$v/submissions/create" val killRequestPath = s"$httpUrl/$v/submissions/kill/anything" val statusRequestPath = s"$httpUrl/$v/submissions/status/anything" @@ -400,9 +400,9 @@ class StandaloneRestSubmitSuite extends FunSuite with BeforeAndAfterEach { val fakeMasterRef = _actorSystem.actorOf(Props(makeFakeMaster)) val _server = if (faulty) { - new FaultyStandaloneRestServer(localhost, 0, fakeMasterRef, "spark://fake:7077", conf) + new FaultyStandaloneRestServer(localhost, 0, conf, fakeMasterRef, "spark://fake:7077") } else { - new StandaloneRestServer(localhost, 0, fakeMasterRef, "spark://fake:7077", conf) + new StandaloneRestServer(localhost, 0, conf, fakeMasterRef, "spark://fake:7077") } val port = _server.start() // set these to clean them up after every test @@ -563,20 +563,18 @@ private class SmarterMaster extends Actor { private class FaultyStandaloneRestServer( host: String, requestedPort: Int, + masterConf: SparkConf, masterActor: ActorRef, - masterUrl: String, - masterConf: SparkConf) - extends StandaloneRestServer(host, requestedPort, masterActor, masterUrl, masterConf) { + masterUrl: String) + extends RestSubmissionServer(host, requestedPort, masterConf) { - protected override val contextToServlet = Map[String, StandaloneRestServlet]( - s"$baseContext/create/*" -> new MalformedSubmitServlet, - s"$baseContext/kill/*" -> new InvalidKillServlet, - s"$baseContext/status/*" -> new ExplodingStatusServlet, - "/*" -> new ErrorServlet - ) + protected override val submitRequestServlet = new MalformedSubmitServlet + protected override val killRequestServlet = new InvalidKillServlet + protected override val statusRequestServlet = new ExplodingStatusServlet /** A faulty servlet that produces malformed responses. */ - class MalformedSubmitServlet extends SubmitRequestServlet(masterActor, masterUrl, masterConf) { + class MalformedSubmitServlet + extends StandaloneSubmitRequestServlet(masterActor, masterUrl, masterConf) { protected override def sendResponse( responseMessage: SubmitRestProtocolResponse, responseServlet: HttpServletResponse): Unit = { @@ -586,7 +584,7 @@ private class FaultyStandaloneRestServer( } /** A faulty servlet that produces invalid responses. */ - class InvalidKillServlet extends KillRequestServlet(masterActor, masterConf) { + class InvalidKillServlet extends StandaloneKillRequestServlet(masterActor, masterConf) { protected override def handleKill(submissionId: String): KillSubmissionResponse = { val k = super.handleKill(submissionId) k.submissionId = null @@ -595,7 +593,7 @@ private class FaultyStandaloneRestServer( } /** A faulty status servlet that explodes. 
*/ - class ExplodingStatusServlet extends StatusRequestServlet(masterActor, masterConf) { + class ExplodingStatusServlet extends StandaloneStatusRequestServlet(masterActor, masterConf) { private def explode: Int = 1 / 0 protected override def handleStatus(submissionId: String): SubmissionStatusResponse = { val s = super.handleStatus(submissionId) diff --git a/core/src/test/scala/org/apache/spark/scheduler/mesos/MesosClusterSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/mesos/MesosClusterSchedulerSuite.scala new file mode 100644 index 0000000000000..f28e29e9b8d8e --- /dev/null +++ b/core/src/test/scala/org/apache/spark/scheduler/mesos/MesosClusterSchedulerSuite.scala @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler.mesos + +import java.util.Date + +import org.scalatest.FunSuite +import org.scalatest.mock.MockitoSugar + +import org.apache.spark.deploy.Command +import org.apache.spark.deploy.mesos.MesosDriverDescription +import org.apache.spark.scheduler.cluster.mesos._ +import org.apache.spark.{LocalSparkContext, SparkConf} + + +class MesosClusterSchedulerSuite extends FunSuite with LocalSparkContext with MockitoSugar { + + private val command = new Command("mainClass", Seq("arg"), null, null, null, null) + + test("can queue drivers") { + val conf = new SparkConf() + conf.setMaster("mesos://localhost:5050") + conf.setAppName("spark mesos") + val scheduler = new MesosClusterScheduler( + new BlackHoleMesosClusterPersistenceEngineFactory, conf) { + override def start(): Unit = { ready = true } + } + scheduler.start() + val response = scheduler.submitDriver( + new MesosDriverDescription("d1", "jar", 1000, 1, true, + command, Map[String, String](), "s1", new Date())) + assert(response.success) + val response2 = + scheduler.submitDriver(new MesosDriverDescription( + "d1", "jar", 1000, 1, true, command, Map[String, String](), "s2", new Date())) + assert(response2.success) + val state = scheduler.getSchedulerState() + val queuedDrivers = state.queuedDrivers.toList + assert(queuedDrivers(0).submissionId == response.submissionId) + assert(queuedDrivers(1).submissionId == response2.submissionId) + } + + test("can kill queued drivers") { + val conf = new SparkConf() + conf.setMaster("mesos://localhost:5050") + conf.setAppName("spark mesos") + val scheduler = new MesosClusterScheduler( + new BlackHoleMesosClusterPersistenceEngineFactory, conf) { + override def start(): Unit = { ready = true } + } + scheduler.start() + val response = scheduler.submitDriver( + new MesosDriverDescription("d1", "jar", 1000, 1, true, + command, Map[String, String](), "s1", new Date())) + assert(response.success) + val killResponse = scheduler.killDriver(response.submissionId) + 
assert(killResponse.success) + val state = scheduler.getSchedulerState() + assert(state.queuedDrivers.isEmpty) + } +} diff --git a/docs/running-on-mesos.md b/docs/running-on-mesos.md index 594bf78b67713..8f53d8201a089 100644 --- a/docs/running-on-mesos.md +++ b/docs/running-on-mesos.md @@ -78,6 +78,9 @@ To verify that the Mesos cluster is ready for Spark, navigate to the Mesos maste To use Mesos from Spark, you need a Spark binary package available in a place accessible by Mesos, and a Spark driver program configured to connect to Mesos. +Alternatively, you can install Spark in the same location on all the Mesos slaves, and configure +`spark.mesos.executor.home` (defaults to SPARK_HOME) to point to that location. + ## Uploading Spark Package When Mesos runs a task on a Mesos slave for the first time, that slave must have a Spark binary @@ -107,7 +110,11 @@ the `make-distribution.sh` script included in a Spark source tarball/checkout. The Master URLs for Mesos are in the form `mesos://host:5050` for a single-master Mesos cluster, or `mesos://zk://host:2181` for a multi-master Mesos cluster using ZooKeeper. -The driver also needs some configuration in `spark-env.sh` to interact properly with Mesos: +## Client Mode + +In client mode, a Spark Mesos framework is launched directly on the client machine and waits for the driver output. + +The driver needs some configuration in `spark-env.sh` to interact properly with Mesos: 1. In `spark-env.sh` set some environment variables: * `export MESOS_NATIVE_JAVA_LIBRARY=`. This path is typically @@ -129,8 +136,7 @@ val sc = new SparkContext(conf) {% endhighlight %} (You can also use [`spark-submit`](submitting-applications.html) and configure `spark.executor.uri` -in the [conf/spark-defaults.conf](configuration.html#loading-default-configurations) file. Note -that `spark-submit` currently only supports deploying the Spark driver in `client` mode for Mesos.) +in the [conf/spark-defaults.conf](configuration.html#loading-default-configurations) file.) When running a shell, the `spark.executor.uri` parameter is inherited from `SPARK_EXECUTOR_URI`, so it does not need to be redundantly passed in as a system property. @@ -139,6 +145,17 @@ it does not need to be redundantly passed in as a system property. ./bin/spark-shell --master mesos://host:5050 {% endhighlight %} +## Cluster Mode + +Spark on Mesos also supports cluster mode, where the driver is launched in the cluster and the client +can find the results of the driver from the Mesos Web UI. + +To use cluster mode, you must start the MesosClusterDispatcher in your cluster via the `sbin/start-mesos-dispatcher.sh` script, +passing in the Mesos master URL (e.g. mesos://host:5050). + +From the client, you can submit a job to the Mesos cluster by running `spark-submit` and setting the master URL +to the URL of the MesosClusterDispatcher (e.g. mesos://dispatcher:7077). You can view driver statuses on the +Spark cluster Web UI. # Mesos Run Modes diff --git a/sbin/start-mesos-dispatcher.sh b/sbin/start-mesos-dispatcher.sh new file mode 100755 index 0000000000000..ef1fc573d5c65 --- /dev/null +++ b/sbin/start-mesos-dispatcher.sh @@ -0,0 +1,40 @@ +#!/usr/bin/env bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements.  See the NOTICE file distributed with +# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Starts the Mesos Cluster Dispatcher on the machine this script is executed on. +# The Mesos Cluster Dispatcher is responsible for launching the Mesos framework and +# Rest server to handle driver requests for Mesos cluster mode. +# Only one cluster dispatcher is needed per Mesos cluster. + +sbin="`dirname "$0"`" +sbin="`cd "$sbin"; pwd`" + +. "$sbin/spark-config.sh" + +. "$SPARK_PREFIX/bin/load-spark-env.sh" + +if [ "$SPARK_MESOS_DISPATCHER_PORT" = "" ]; then + SPARK_MESOS_DISPATCHER_PORT=7077 +fi + +if [ "$SPARK_MESOS_DISPATCHER_HOST" = "" ]; then + SPARK_MESOS_DISPATCHER_HOST=`hostname` +fi + + +"$sbin"/spark-daemon.sh start org.apache.spark.deploy.mesos.MesosClusterDispatcher 1 --host $SPARK_MESOS_DISPATCHER_HOST --port $SPARK_MESOS_DISPATCHER_PORT "$@" diff --git a/sbin/stop-mesos-dispatcher.sh b/sbin/stop-mesos-dispatcher.sh new file mode 100755 index 0000000000000..cb65d95b5e524 --- /dev/null +++ b/sbin/stop-mesos-dispatcher.sh @@ -0,0 +1,27 @@ +#!/usr/bin/env bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Stop the Mesos Cluster dispatcher on the machine this script is executed on. + +sbin=`dirname "$0"` +sbin=`cd "$sbin"; pwd` + +. "$sbin/spark-config.sh" + +"$sbin"/spark-daemon.sh stop org.apache.spark.deploy.mesos.MesosClusterDispatcher 1 + From 28b1af7420e0b5e7e2dfc09eafc45fe2ffcde5ec Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Tue, 28 Apr 2015 13:49:29 -0700 Subject: [PATCH 102/110] [MINOR] [CORE] Warn users who try to cache RDDs with dynamic allocation on. Author: Marcelo Vanzin Closes #5751 from vanzin/cached-rdd-warning and squashes the following commits: 554cc07 [Marcelo Vanzin] Change message. 9efb9da [Marcelo Vanzin] [minor] [core] Warn users who try to cache RDDs with dynamic allocation on. 
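The warning introduced here fires from `persistRDD` whenever an `ExecutorAllocationManager` is active. A minimal sketch of an application that would trigger it; the app name is illustrative, the YARN master is only an example (dynamic allocation required YARN at the time), and the configuration keys are the standard dynamic-allocation settings:

```scala
import org.apache.spark.{SparkConf, SparkContext}

object CacheWithDynamicAllocation {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf()
      .setAppName("cache-warning-demo")                 // illustrative name
      .setMaster("yarn-client")                         // dynamic allocation required YARN at the time
      .set("spark.dynamicAllocation.enabled", "true")   // creates an ExecutorAllocationManager
      .set("spark.shuffle.service.enabled", "true")     // required when dynamic allocation is on
    val sc = new SparkContext(conf)
    // cache() registers the RDD through persistRDD(), which now logs that cached data for
    // this RDD may be lost when executors are removed.
    val data = sc.parallelize(1 to 1000).cache()
    data.count()
    sc.stop()
  }
}
```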
--- core/src/main/scala/org/apache/spark/SparkContext.scala | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 65b903a55d5bd..d0cf2a8dd01cd 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -1396,6 +1396,11 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * Register an RDD to be persisted in memory and/or disk storage */ private[spark] def persistRDD(rdd: RDD[_]) { + _executorAllocationManager.foreach { _ => + logWarning( + s"Dynamic allocation currently does not support cached RDDs. Cached data for RDD " + + s"${rdd.id} will be lost when executors are removed.") + } persistentRdds(rdd.id) = rdd } From f0a1f90f53b447c61f405cdb2c553f3dc067bf8d Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Tue, 28 Apr 2015 14:07:26 -0700 Subject: [PATCH 103/110] [SPARK-7201] [MLLIB] move Identifiable to ml.util It shouldn't live directly under `spark.ml`. Author: Xiangrui Meng Closes #5749 from mengxr/SPARK-7201 and squashes the following commits: 53847f9 [Xiangrui Meng] move Identifiable to ml.util --- mllib/src/main/scala/org/apache/spark/ml/Evaluator.scala | 1 + mllib/src/main/scala/org/apache/spark/ml/param/params.scala | 2 +- .../scala/org/apache/spark/ml/{ => util}/Identifiable.scala | 2 +- 3 files changed, 3 insertions(+), 2 deletions(-) rename mllib/src/main/scala/org/apache/spark/ml/{ => util}/Identifiable.scala (97%) diff --git a/mllib/src/main/scala/org/apache/spark/ml/Evaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/Evaluator.scala index d2ca2e6871e6b..8b4b5fd8af986 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Evaluator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Evaluator.scala @@ -19,6 +19,7 @@ package org.apache.spark.ml import org.apache.spark.annotation.AlphaComponent import org.apache.spark.ml.param.ParamMap +import org.apache.spark.ml.util.Identifiable import org.apache.spark.sql.DataFrame /** diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala index ddc5907e7facd..014e124e440a3 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala @@ -24,7 +24,7 @@ import scala.annotation.varargs import scala.collection.mutable import org.apache.spark.annotation.{AlphaComponent, DeveloperApi} -import org.apache.spark.ml.Identifiable +import org.apache.spark.ml.util.Identifiable /** * :: AlphaComponent :: diff --git a/mllib/src/main/scala/org/apache/spark/ml/Identifiable.scala b/mllib/src/main/scala/org/apache/spark/ml/util/Identifiable.scala similarity index 97% rename from mllib/src/main/scala/org/apache/spark/ml/Identifiable.scala rename to mllib/src/main/scala/org/apache/spark/ml/util/Identifiable.scala index a1d49095c24ac..8a56748ab0a02 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Identifiable.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/Identifiable.scala @@ -15,7 +15,7 @@ * limitations under the License. 
*/ -package org.apache.spark.ml +package org.apache.spark.ml.util import java.util.UUID From 555213ebbf2be2ee523be8665bd5b9a47ae4bec8 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Tue, 28 Apr 2015 14:21:25 -0700 Subject: [PATCH 104/110] Closes #4807 Closes #5055 Closes #3583 From d36e67350c516a96d58abd50a0d5d22b3b22f291 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Tue, 28 Apr 2015 17:41:09 -0700 Subject: [PATCH 105/110] [SPARK-6965] [MLLIB] StringIndexer handles numeric input. Cast numeric types to String for indexing. Boolean type is not handled in this PR. jkbradley Author: Xiangrui Meng Closes #5753 from mengxr/SPARK-6965 and squashes the following commits: 2e34f3c [Xiangrui Meng] add actual type in the error message ad938bf [Xiangrui Meng] StringIndexer handles numeric input. --- .../spark/ml/feature/StringIndexer.scala | 17 ++++++++++++----- .../spark/ml/feature/StringIndexerSuite.scala | 19 +++++++++++++++++++ 2 files changed, 31 insertions(+), 5 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala index 23956c512c8a6..9db3b29e10d69 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala @@ -23,10 +23,9 @@ import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.attribute.NominalAttribute import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ -import org.apache.spark.ml.util.SchemaUtils import org.apache.spark.sql.DataFrame import org.apache.spark.sql.functions._ -import org.apache.spark.sql.types.{StringType, StructType} +import org.apache.spark.sql.types.{NumericType, StringType, StructType} import org.apache.spark.util.collection.OpenHashMap /** @@ -37,7 +36,11 @@ private[feature] trait StringIndexerBase extends Params with HasInputCol with Ha /** Validates and transforms the input schema. */ protected def validateAndTransformSchema(schema: StructType, paramMap: ParamMap): StructType = { val map = extractParamMap(paramMap) - SchemaUtils.checkColumnType(schema, map(inputCol), StringType) + val inputColName = map(inputCol) + val inputDataType = schema(inputColName).dataType + require(inputDataType == StringType || inputDataType.isInstanceOf[NumericType], + s"The input column $inputColName must be either string type or numeric type, " + + s"but got $inputDataType.") val inputFields = schema.fields val outputColName = map(outputCol) require(inputFields.forall(_.name != outputColName), @@ -51,6 +54,7 @@ private[feature] trait StringIndexerBase extends Params with HasInputCol with Ha /** * :: AlphaComponent :: * A label indexer that maps a string column of labels to an ML column of label indices. + * If the input column is numeric, we cast it to string and index the string values. * The indices are in [0, numLabels), ordered by label frequencies. * So the most frequent label gets index 0. 
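+ * For example, with numeric input values 100, 200, 300, 100, 100, 300 the fitted labels are the
+ * strings "100", "300", "200", and they map to indices 0.0, 1.0 and 2.0 respectively, because
+ * "100" occurs three times, "300" twice and "200" once.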
*/ @@ -67,7 +71,9 @@ class StringIndexer extends Estimator[StringIndexerModel] with StringIndexerBase override def fit(dataset: DataFrame, paramMap: ParamMap): StringIndexerModel = { val map = extractParamMap(paramMap) - val counts = dataset.select(map(inputCol)).map(_.getString(0)).countByValue() + val counts = dataset.select(col(map(inputCol)).cast(StringType)) + .map(_.getString(0)) + .countByValue() val labels = counts.toSeq.sortBy(-_._2).map(_._1).toArray val model = new StringIndexerModel(this, map, labels) Params.inheritValues(map, this, model) @@ -119,7 +125,8 @@ class StringIndexerModel private[ml] ( val outputColName = map(outputCol) val metadata = NominalAttribute.defaultAttr .withName(outputColName).withValues(labels).toMetadata() - dataset.select(col("*"), indexer(dataset(map(inputCol))).as(outputColName, metadata)) + dataset.select(col("*"), + indexer(dataset(map(inputCol)).cast(StringType)).as(outputColName, metadata)) } override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala index 00b5d094d82f1..b6939e5870410 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala @@ -49,4 +49,23 @@ class StringIndexerSuite extends FunSuite with MLlibTestSparkContext { val expected = Set((0, 0.0), (1, 2.0), (2, 1.0), (3, 0.0), (4, 0.0), (5, 1.0)) assert(output === expected) } + + test("StringIndexer with a numeric input column") { + val data = sc.parallelize(Seq((0, 100), (1, 200), (2, 300), (3, 100), (4, 100), (5, 300)), 2) + val df = sqlContext.createDataFrame(data).toDF("id", "label") + val indexer = new StringIndexer() + .setInputCol("label") + .setOutputCol("labelIndex") + .fit(df) + val transformed = indexer.transform(df) + val attr = Attribute.fromStructField(transformed.schema("labelIndex")) + .asInstanceOf[NominalAttribute] + assert(attr.values.get === Array("100", "300", "200")) + val output = transformed.select("id", "labelIndex").map { r => + (r.getInt(0), r.getDouble(1)) + }.collect().toSet + // 100 -> 0, 200 -> 2, 300 -> 1 + val expected = Set((0, 0.0), (1, 2.0), (2, 1.0), (3, 0.0), (4, 0.0), (5, 1.0)) + assert(output === expected) + } } From 5c8f4bd5fae539ab5fb992573d5357ed34e2f4d0 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Tue, 28 Apr 2015 19:31:57 -0700 Subject: [PATCH 106/110] [SPARK-7138] [STREAMING] Add method to BlockGenerator to add multiple records to BlockGenerator with single callback This is to ensure that receivers that receive data in small batches (like Kinesis) and want to add them but want the callback function to be called only once. This is for internal use only for improvement to Kinesis Receiver that we are planning to do. Author: Tathagata Das Closes #5695 from tdas/SPARK-7138 and squashes the following commits: a35cf7d [Tathagata Das] Fixed style. a7a4cb9 [Tathagata Das] Added extra method to BlockGenerator. 
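As a rough illustration of the intended use, the sketch below shows how an internal receiver that fetches records in batches might call the new method so that each fetched batch lands in a single block and triggers exactly one `onAddData` callback. Everything here other than `addMultipleDataWithCallback` itself (the class, method and parameter names) is hypothetical:

```scala
package org.apache.spark.streaming.receiver

// BlockGenerator is private[streaming], so a sketch like this only compiles inside the
// org.apache.spark.streaming package tree; it is not a public API example.
private[streaming] class BatchStoringExample(blockGenerator: BlockGenerator) {

  /**
   * Stores one fetched batch (for instance the result of a single Kinesis GetRecords call)
   * together with the metadata needed to checkpoint it, e.g. the batch's last sequence number.
   */
  def storeBatch(records: Seq[Array[Byte]], lastSeqNumber: String): Unit = {
    // All records are buffered under the generator's lock, so they are guaranteed to end up
    // in the same block, and the listener's onAddData callback is invoked once with
    // `lastSeqNumber` rather than once per record.
    blockGenerator.addMultipleDataWithCallback(records.iterator, lastSeqNumber)
  }
}
```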
--- .../spark/streaming/receiver/BlockGenerator.scala | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/BlockGenerator.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/BlockGenerator.scala index f4963a78e1d18..4bebcc5aa7ca0 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/BlockGenerator.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/BlockGenerator.scala @@ -126,6 +126,20 @@ private[streaming] class BlockGenerator( listener.onAddData(data, metadata) } + + /** + * Push multiple data items into the buffer. After buffering the data, the + * `BlockGeneratorListener.onAddData` callback will be called. All received data items + * will be periodically pushed into BlockManager. Note that all the data items are guaranteed + * to be present in a single block. + */ + def addMultipleDataWithCallback(dataIterator: Iterator[Any], metadata: Any): Unit = synchronized { + dataIterator.foreach { data => + waitToPush() + currentBuffer += data + } + listener.onAddData(dataIterator, metadata) + } + /** Change the buffer to which single records are added to. */ private def updateCurrentBuffer(time: Long): Unit = synchronized { try { From a8aeadb7d4a2dc308a75a50fdd8065f9a32ef336 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Tue, 28 Apr 2015 21:15:47 -0700 Subject: [PATCH 107/110] [SPARK-7208] [ML] [PYTHON] Added Matrix, SparseMatrix to __all__ list in linalg.py Added Matrix, SparseMatrix to __all__ list in linalg.py CC: mengxr Author: Joseph K. Bradley Closes #5759 from jkbradley/SPARK-7208 and squashes the following commits: deb51a2 [Joseph K. Bradley] Added Matrix, SparseMatrix to __all__ list in linalg.py --- python/pyspark/mllib/linalg.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/pyspark/mllib/linalg.py b/python/pyspark/mllib/linalg.py index cc9a4cf8ba170..a57c0b3ae0d00 100644 --- a/python/pyspark/mllib/linalg.py +++ b/python/pyspark/mllib/linalg.py @@ -39,7 +39,8 @@ IntegerType, ByteType -__all__ = ['Vector', 'DenseVector', 'SparseVector', 'Vectors', 'DenseMatrix', 'Matrices'] +__all__ = ['Vector', 'DenseVector', 'SparseVector', 'Vectors', + 'Matrix', 'DenseMatrix', 'SparseMatrix', 'Matrices'] if sys.version_info[:2] == (2, 7): From 5ef006fc4d010905e02cb905c9115b95ba55282b Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Tue, 28 Apr 2015 21:49:53 -0700 Subject: [PATCH 108/110] [SPARK-6756] [MLLIB] add toSparse, toDense, numActives, numNonzeros, and compressed to Vector Add `compressed` to `Vector` with some other methods: `numActives`, `numNonzeros`, `toSparse`, and `toDense`.
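A short sketch of how the new API behaves, based on the size heuristic and the tests in the patch below: a dense vector costs roughly 8 * size + 8 bytes while a sparse one costs about 12 * nnz + 20 bytes, so `compressed` returns a sparse copy only when 1.5 * (nnz + 1) < size.

```scala
import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vectors}

object CompressedVectorSketch {
  def main(args: Array[String]): Unit = {
    // Mostly zeros: sparse storage is smaller, so compressed yields a SparseVector.
    val values = Array.fill(100)(0.0)
    values(3) = 7.0
    val mostlyZeros = Vectors.dense(values)
    assert(mostlyZeros.compressed.isInstanceOf[SparseVector])

    // Only one implicit zero out of four entries: dense storage wins, even though the
    // input was constructed as a sparse vector (same case as sv1 in the test suite).
    val mostlyNonzero = Vectors.sparse(4, Array(0, 1, 2), Array(1.0, 2.0, 3.0))
    assert(mostlyNonzero.compressed.isInstanceOf[DenseVector])

    // numActives counts stored entries, numNonzeros skips explicit zeros, and toSparse drops them.
    val v = Vectors.sparse(4, Array(0, 1, 2), Array(0.0, 2.0, 3.0))
    assert(v.numActives == 3 && v.numNonzeros == 2)
    assert(v.toSparse.numActives == 2)
  }
}
```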
jkbradley Author: Xiangrui Meng Closes #5756 from mengxr/SPARK-6756 and squashes the following commits: 8d4ecbd [Xiangrui Meng] address comment and add mima excludes da54179 [Xiangrui Meng] add toSparse, toDense, numActives, numNonzeros, and compressed to Vector --- .../apache/spark/mllib/linalg/Vectors.scala | 93 +++++++++++++++++++ .../spark/mllib/linalg/VectorsSuite.scala | 44 +++++++++ project/MimaExcludes.scala | 12 +++ 3 files changed, 149 insertions(+) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala index 34833e90d4af0..188d1e542b5b5 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala @@ -116,6 +116,40 @@ sealed trait Vector extends Serializable { * with type `Double`. */ private[spark] def foreachActive(f: (Int, Double) => Unit) + + /** + * Number of active entries. An "active entry" is an element which is explicitly stored, + * regardless of its value. Note that inactive entries have value 0. + */ + def numActives: Int + + /** + * Number of nonzero elements. This scans all active values and count nonzeros. + */ + def numNonzeros: Int + + /** + * Converts this vector to a sparse vector with all explicit zeros removed. + */ + def toSparse: SparseVector + + /** + * Converts this vector to a dense vector. + */ + def toDense: DenseVector = new DenseVector(this.toArray) + + /** + * Returns a vector in either dense or sparse format, whichever uses less storage. + */ + def compressed: Vector = { + val nnz = numNonzeros + // A dense vector needs 8 * size + 8 bytes, while a sparse vector needs 12 * nnz + 20 bytes. + if (1.5 * (nnz + 1.0) < size) { + toSparse + } else { + toDense + } + } } /** @@ -525,6 +559,34 @@ class DenseVector(val values: Array[Double]) extends Vector { } result } + + override def numActives: Int = size + + override def numNonzeros: Int = { + // same as values.count(_ != 0.0) but faster + var nnz = 0 + values.foreach { v => + if (v != 0.0) { + nnz += 1 + } + } + nnz + } + + override def toSparse: SparseVector = { + val nnz = numNonzeros + val ii = new Array[Int](nnz) + val vv = new Array[Double](nnz) + var k = 0 + foreachActive { (i, v) => + if (v != 0) { + ii(k) = i + vv(k) = v + k += 1 + } + } + new SparseVector(size, ii, vv) + } } object DenseVector { @@ -602,6 +664,37 @@ class SparseVector( } result } + + override def numActives: Int = values.length + + override def numNonzeros: Int = { + var nnz = 0 + values.foreach { v => + if (v != 0.0) { + nnz += 1 + } + } + nnz + } + + override def toSparse: SparseVector = { + val nnz = numNonzeros + if (nnz == numActives) { + this + } else { + val ii = new Array[Int](nnz) + val vv = new Array[Double](nnz) + var k = 0 + foreachActive { (i, v) => + if (v != 0.0) { + ii(k) = i + vv(k) = v + k += 1 + } + } + new SparseVector(size, ii, vv) + } + } } object SparseVector { diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala index 2839c4c289b2d..24755e9ff46fc 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala @@ -270,4 +270,48 @@ class VectorsSuite extends FunSuite { assert(Vectors.norm(sv, 3.7) ~== math.pow(sv.toArray.foldLeft(0.0)((a, v) => a + math.pow(math.abs(v), 3.7)), 1.0 / 3.7) relTol 1E-8) } + + test("Vector numActive and 
numNonzeros") { + val dv = Vectors.dense(0.0, 2.0, 3.0, 0.0) + assert(dv.numActives === 4) + assert(dv.numNonzeros === 2) + + val sv = Vectors.sparse(4, Array(0, 1, 2), Array(0.0, 2.0, 3.0)) + assert(sv.numActives === 3) + assert(sv.numNonzeros === 2) + } + + test("Vector toSparse and toDense") { + val dv0 = Vectors.dense(0.0, 2.0, 3.0, 0.0) + assert(dv0.toDense === dv0) + val dv0s = dv0.toSparse + assert(dv0s.numActives === 2) + assert(dv0s === dv0) + + val sv0 = Vectors.sparse(4, Array(0, 1, 2), Array(0.0, 2.0, 3.0)) + assert(sv0.toDense === sv0) + val sv0s = sv0.toSparse + assert(sv0s.numActives === 2) + assert(sv0s === sv0) + } + + test("Vector.compressed") { + val dv0 = Vectors.dense(1.0, 2.0, 3.0, 0.0) + val dv0c = dv0.compressed.asInstanceOf[DenseVector] + assert(dv0c === dv0) + + val dv1 = Vectors.dense(0.0, 2.0, 0.0, 0.0) + val dv1c = dv1.compressed.asInstanceOf[SparseVector] + assert(dv1 === dv1c) + assert(dv1c.numActives === 1) + + val sv0 = Vectors.sparse(4, Array(1, 2), Array(2.0, 0.0)) + val sv0c = sv0.compressed.asInstanceOf[SparseVector] + assert(sv0 === sv0c) + assert(sv0c.numActives === 1) + + val sv1 = Vectors.sparse(4, Array(0, 1, 2), Array(1.0, 2.0, 3.0)) + val sv1c = sv1.compressed.asInstanceOf[DenseVector] + assert(sv1 === sv1c) + } } diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 967961c2bf5c3..3beafa158eb97 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -76,6 +76,18 @@ object MimaExcludes { // SPARK-7090 Introduce LDAOptimizer to LDA to further improve extensibility ProblemFilters.exclude[MissingClassProblem]( "org.apache.spark.mllib.clustering.LDA$EMOptimizer") + ) ++ Seq( + // SPARK-6756 add toSparse, toDense, numActives, numNonzeros, and compressed to Vector + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.linalg.Vector.compressed"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.linalg.Vector.toDense"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.linalg.Vector.numNonzeros"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.linalg.Vector.toSparse"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.linalg.Vector.numActives") ) case v if v.startsWith("1.3") => From 271c4c621d91d3f610ae89e5d2e5dab1a2009ca6 Mon Sep 17 00:00:00 2001 From: Burak Yavuz Date: Tue, 28 Apr 2015 22:48:04 -0700 Subject: [PATCH 109/110] [SPARK-7215] made coalesce and repartition a part of the query plan Coalesce and repartition now show up as part of the query plan, rather than resulting in a new `DataFrame`. 
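As a rough sketch of what this change means for users (illustrative, not part of the patch; assumes an existing DataFrame `df`):

val df2 = df.repartition(8)  // recorded as Repartition(8, shuffle = true, ...) in the logical plan
val df3 = df.coalesce(2)     // recorded as Repartition(2, shuffle = false, ...) in the logical plan
df2.explain()                // the physical plan should now show a Repartition operator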
cc rxin Author: Burak Yavuz Closes #5762 from brkyvz/df-repartition and squashes the following commits: b1e76dd [Burak Yavuz] added documentation on repartitions 5807e35 [Burak Yavuz] renamed coalescepartitions fa4509f [Burak Yavuz] rename coalesce 2c349b5 [Burak Yavuz] address comments f2e6af1 [Burak Yavuz] add ticks 686c90b [Burak Yavuz] made coalesce and repartition a part of the query plan --- .../catalyst/plans/logical/basicOperators.scala | 11 +++++++++++ .../sql/catalyst/plans/logical/partitioning.scala | 8 +++++++- .../scala/org/apache/spark/sql/DataFrame.scala | 9 ++------- .../spark/sql/execution/SparkStrategies.scala | 5 +++-- .../spark/sql/execution/basicOperators.scala | 14 ++++++++++++++ .../scala/org/apache/spark/sql/hive/HiveQl.scala | 6 +++--- 6 files changed, 40 insertions(+), 13 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index bbc94a7ab3398..608e272da7784 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -310,6 +310,17 @@ case class Distinct(child: LogicalPlan) extends UnaryNode { override def output: Seq[Attribute] = child.output } +/** + * Return a new RDD that has exactly `numPartitions` partitions. Differs from + * [[RepartitionByExpression]] as this method is called directly by DataFrame's, because the user + * asked for `coalesce` or `repartition`. [[RepartitionByExpression]] is used when the consumer + * of the output requires some specific ordering or distribution of the data. + */ +case class Repartition(numPartitions: Int, shuffle: Boolean, child: LogicalPlan) + extends UnaryNode { + override def output: Seq[Attribute] = child.output +} + /** * A relation with one row. This is used in "SELECT ..." without a from clause. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/partitioning.scala index e737418d9c3bc..63df2c1ee72ff 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/partitioning.scala @@ -32,5 +32,11 @@ abstract class RedistributeData extends UnaryNode { case class SortPartitions(sortExpressions: Seq[SortOrder], child: LogicalPlan) extends RedistributeData -case class Repartition(partitionExpressions: Seq[Expression], child: LogicalPlan) +/** + * This method repartitions data using [[Expression]]s, and receives information about the + * number of partitions during execution. Used when a specific ordering or distribution is + * expected by the consumer of the query result. Use [[Repartition]] for RDD-like + * `coalesce` and `repartition`. 
+ */ +case class RepartitionByExpression(partitionExpressions: Seq[Expression], child: LogicalPlan) extends RedistributeData diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index ca6ae482eb2ab..2affba7d42cc8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -961,9 +961,7 @@ class DataFrame private[sql]( * @group rdd */ override def repartition(numPartitions: Int): DataFrame = { - sqlContext.createDataFrame( - queryExecution.toRdd.map(_.copy()).repartition(numPartitions), - schema, needsConversion = false) + Repartition(numPartitions, shuffle = true, logicalPlan) } /** @@ -974,10 +972,7 @@ class DataFrame private[sql]( * @group rdd */ override def coalesce(numPartitions: Int): DataFrame = { - sqlContext.createDataFrame( - queryExecution.toRdd.coalesce(numPartitions), - schema, - needsConversion = false) + Repartition(numPartitions, shuffle = false, logicalPlan) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 030ef118f75d4..3a0a6c86700a8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -283,7 +283,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case logical.Distinct(child) => execution.Distinct(partial = false, execution.Distinct(partial = true, planLater(child))) :: Nil - + case logical.Repartition(numPartitions, shuffle, child) => + execution.Repartition(numPartitions, shuffle, planLater(child)) :: Nil case logical.SortPartitions(sortExprs, child) => // This sort only sorts tuples within a partition. Its requiredDistribution will be // an UnspecifiedDistribution. @@ -317,7 +318,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { generator, join = join, outer = outer, g.output, planLater(child)) :: Nil case logical.OneRowRelation => execution.PhysicalRDD(Nil, singleRowRdd) :: Nil - case logical.Repartition(expressions, child) => + case logical.RepartitionByExpression(expressions, child) => execution.Exchange( HashPartitioning(expressions, numPartitions), Nil, planLater(child)) :: Nil case e @ EvaluatePython(udf, child, _) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index d286fe81bee5f..1afdb409417ce 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -245,6 +245,20 @@ case class Distinct(partial: Boolean, child: SparkPlan) extends UnaryNode { } } +/** + * :: DeveloperApi :: + * Return a new RDD that has exactly `numPartitions` partitions. 
+ */ +@DeveloperApi +case class Repartition(numPartitions: Int, shuffle: Boolean, child: SparkPlan) + extends UnaryNode { + override def output: Seq[Attribute] = child.output + + override def execute(): RDD[Row] = { + child.execute().map(_.copy()).coalesce(numPartitions, shuffle) + } +} + /** * :: DeveloperApi :: diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index 0ea6d57b816c6..2dc6463abafa7 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -783,13 +783,13 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C case (None, Some(perPartitionOrdering), None, None) => Sort(perPartitionOrdering.getChildren.map(nodeToSortOrder), false, withHaving) case (None, None, Some(partitionExprs), None) => - Repartition(partitionExprs.getChildren.map(nodeToExpr), withHaving) + RepartitionByExpression(partitionExprs.getChildren.map(nodeToExpr), withHaving) case (None, Some(perPartitionOrdering), Some(partitionExprs), None) => Sort(perPartitionOrdering.getChildren.map(nodeToSortOrder), false, - Repartition(partitionExprs.getChildren.map(nodeToExpr), withHaving)) + RepartitionByExpression(partitionExprs.getChildren.map(nodeToExpr), withHaving)) case (None, None, None, Some(clusterExprs)) => Sort(clusterExprs.getChildren.map(nodeToExpr).map(SortOrder(_, Ascending)), false, - Repartition(clusterExprs.getChildren.map(nodeToExpr), withHaving)) + RepartitionByExpression(clusterExprs.getChildren.map(nodeToExpr), withHaving)) case (None, None, None, None) => withHaving case _ => sys.error("Unsupported set of ordering / distribution clauses.") } From f98773a90ded0e408af6bbd85fafbaffbc5b825f Mon Sep 17 00:00:00 2001 From: Burak Yavuz Date: Tue, 28 Apr 2015 23:05:02 -0700 Subject: [PATCH 110/110] [SPARK-7205] Support `.ivy2/local` and `.m2/repositories/` in --packages In addition, I made a small change that will allow users to import 2 different artifacts with the same name. That change is made in `[organization]_[artifact]-[revision].[ext]`. This used to be only `[artifact].[ext]` which might have caused collisions between artifacts with the same artifactId, but different groupId's. cc pwendell Author: Burak Yavuz Closes #5755 from brkyvz/local-caches and squashes the following commits: c47c9c5 [Burak Yavuz] Small fixes to --packages --- .../org/apache/spark/deploy/SparkSubmit.scala | 34 ++++++++++++++----- .../spark/deploy/SparkSubmitUtilsSuite.scala | 27 ++++++++------- 2 files changed, 40 insertions(+), 21 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index f4f572e1e256e..b8ae4af18d1d1 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -734,13 +734,31 @@ private[deploy] object SparkSubmitUtils { /** * Extracts maven coordinates from a comma-delimited string * @param remoteRepos Comma-delimited string of remote repositories + * @param ivySettings The Ivy settings for this session * @return A ChainResolver used by Ivy to search for and resolve dependencies. 
*/ - def createRepoResolvers(remoteRepos: Option[String]): ChainResolver = { + def createRepoResolvers(remoteRepos: Option[String], ivySettings: IvySettings): ChainResolver = { // We need a chain resolver if we want to check multiple repositories val cr = new ChainResolver cr.setName("list") + val localM2 = new IBiblioResolver + localM2.setM2compatible(true) + val m2Path = ".m2" + File.separator + "repository" + File.separator + localM2.setRoot(new File(System.getProperty("user.home"), m2Path).toURI.toString) + localM2.setUsepoms(true) + localM2.setName("local-m2-cache") + cr.add(localM2) + + val localIvy = new IBiblioResolver + localIvy.setRoot(new File(ivySettings.getDefaultIvyUserDir, + "local" + File.separator).toURI.toString) + val ivyPattern = Seq("[organisation]", "[module]", "[revision]", "[type]s", + "[artifact](-[classifier]).[ext]").mkString(File.separator) + localIvy.setPattern(ivyPattern) + localIvy.setName("local-ivy-cache") + cr.add(localIvy) + // the biblio resolver resolves POM declared dependencies val br: IBiblioResolver = new IBiblioResolver br.setM2compatible(true) @@ -773,8 +791,7 @@ private[deploy] object SparkSubmitUtils { /** * Output a comma-delimited list of paths for the downloaded jars to be added to the classpath - * (will append to jars in SparkSubmit). The name of the jar is given - * after a '!' by Ivy. It also sometimes contains '(bundle)' after '.jar'. Remove that as well. + * (will append to jars in SparkSubmit). * @param artifacts Sequence of dependencies that were resolved and retrieved * @param cacheDirectory directory where jars are cached * @return a comma-delimited list of paths for the dependencies @@ -783,10 +800,9 @@ private[deploy] object SparkSubmitUtils { artifacts: Array[AnyRef], cacheDirectory: File): String = { artifacts.map { artifactInfo => - val artifactString = artifactInfo.toString - val jarName = artifactString.drop(artifactString.lastIndexOf("!") + 1) + val artifact = artifactInfo.asInstanceOf[Artifact].getModuleRevisionId cacheDirectory.getAbsolutePath + File.separator + - jarName.substring(0, jarName.lastIndexOf(".jar") + 4) + s"${artifact.getOrganisation}_${artifact.getName}-${artifact.getRevision}.jar" }.mkString(",") } @@ -868,6 +884,7 @@ private[deploy] object SparkSubmitUtils { if (alternateIvyCache.trim.isEmpty) { new File(ivySettings.getDefaultIvyUserDir, "jars") } else { + ivySettings.setDefaultIvyUserDir(new File(alternateIvyCache)) ivySettings.setDefaultCache(new File(alternateIvyCache, "cache")) new File(alternateIvyCache, "jars") } @@ -877,7 +894,7 @@ private[deploy] object SparkSubmitUtils { // create a pattern matcher ivySettings.addMatcher(new GlobPatternMatcher) // create the dependency resolvers - val repoResolver = createRepoResolvers(remoteRepos) + val repoResolver = createRepoResolvers(remoteRepos, ivySettings) ivySettings.addResolver(repoResolver) ivySettings.setDefaultResolver(repoResolver.getName) @@ -911,7 +928,8 @@ private[deploy] object SparkSubmitUtils { } // retrieve all resolved dependencies ivy.retrieve(rr.getModuleDescriptor.getModuleRevisionId, - packagesDirectory.getAbsolutePath + File.separator + "[artifact](-[classifier]).[ext]", + packagesDirectory.getAbsolutePath + File.separator + + "[organization]_[artifact]-[revision].[ext]", retrieveOptions.setConfs(Array(ivyConfName))) System.setOut(sysOut) resolveDependencyPaths(rr.getArtifacts.toArray, packagesDirectory) diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitUtilsSuite.scala 
b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitUtilsSuite.scala index 8bcca926097a1..1b2b699cb11e6 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitUtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitUtilsSuite.scala @@ -19,6 +19,8 @@ package org.apache.spark.deploy import java.io.{PrintStream, OutputStream, File} +import org.apache.ivy.core.settings.IvySettings + import scala.collection.mutable.ArrayBuffer import org.scalatest.{BeforeAndAfterAll, FunSuite} @@ -56,24 +58,23 @@ class SparkSubmitUtilsSuite extends FunSuite with BeforeAndAfterAll { } test("create repo resolvers") { - val resolver1 = SparkSubmitUtils.createRepoResolvers(None) + val settings = new IvySettings + val res1 = SparkSubmitUtils.createRepoResolvers(None, settings) // should have central and spark-packages by default - assert(resolver1.getResolvers.size() === 2) - assert(resolver1.getResolvers.get(0).asInstanceOf[IBiblioResolver].getName === "central") - assert(resolver1.getResolvers.get(1).asInstanceOf[IBiblioResolver].getName === "spark-packages") + assert(res1.getResolvers.size() === 4) + assert(res1.getResolvers.get(0).asInstanceOf[IBiblioResolver].getName === "local-m2-cache") + assert(res1.getResolvers.get(1).asInstanceOf[IBiblioResolver].getName === "local-ivy-cache") + assert(res1.getResolvers.get(2).asInstanceOf[IBiblioResolver].getName === "central") + assert(res1.getResolvers.get(3).asInstanceOf[IBiblioResolver].getName === "spark-packages") val repos = "a/1,b/2,c/3" - val resolver2 = SparkSubmitUtils.createRepoResolvers(Option(repos)) - assert(resolver2.getResolvers.size() === 5) + val resolver2 = SparkSubmitUtils.createRepoResolvers(Option(repos), settings) + assert(resolver2.getResolvers.size() === 7) val expected = repos.split(",").map(r => s"$r/") resolver2.getResolvers.toArray.zipWithIndex.foreach { case (resolver: IBiblioResolver, i) => - if (i == 0) { - assert(resolver.getName === "central") - } else if (i == 1) { - assert(resolver.getName === "spark-packages") - } else { - assert(resolver.getName === s"repo-${i - 1}") - assert(resolver.getRoot === expected(i - 2)) + if (i > 3) { + assert(resolver.getName === s"repo-${i - 3}") + assert(resolver.getRoot === expected(i - 4)) } } }
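A hypothetical sketch of the effect, mirroring the updated test above (not part of the patch; `SparkSubmitUtils` is private[deploy], so this would only compile inside the org.apache.spark.deploy package):

import org.apache.ivy.core.settings.IvySettings

val cr = SparkSubmitUtils.createRepoResolvers(None, new IvySettings)
// Default chain order after this change: local-m2-cache, local-ivy-cache, central,
// spark-packages, so ~/.m2/repository and ~/.ivy2/local are consulted before any
// remote repository. Retrieved jars are now cached as
// [organization]_[artifact]-[revision].jar, avoiding collisions between artifacts
// that share an artifactId but have different groupIds.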
 <tr><th>Property Name</th><th>Default</th><th>Meaning</th></tr>
 <tr>
-  <td>spark.reducer.maxMbInFlight</td>
-  <td>48</td>
+  <td>spark.reducer.maxSizeInFlight</td>
+  <td>48m</td>
   <td>
-    Maximum size (in megabytes) of map outputs to fetch simultaneously from each reduce task. Since
+    Maximum size of map outputs to fetch simultaneously from each reduce task. Since
     each output requires us to create a buffer to receive it, this represents a fixed memory
     overhead per reduce task, so keep it small unless you have a large amount of memory.
   </td>
 </tr>
 <tr>
-  <td>spark.shuffle.file.buffer.kb</td>
-  <td>32</td>
+  <td>spark.shuffle.file.buffer</td>
+  <td>32k</td>
   <td>
-    Size of the in-memory buffer for each shuffle file output stream, in kilobytes. These buffers
+    Size of the in-memory buffer for each shuffle file output stream. These buffers
     reduce the number of disk seeks and system calls made in creating intermediate shuffle files.
   </td>
 </tr>
 <tr>
-  <td>spark.io.compression.lz4.block.size</td>
-  <td>32768</td>
+  <td>spark.io.compression.lz4.blockSize</td>
+  <td>32k</td>
   <td>
-    Block size (in bytes) used in LZ4 compression, in the case when LZ4 compression codec
+    Block size used in LZ4 compression, in the case when LZ4 compression codec
     is used. Lowering this block size will also lower shuffle memory usage when LZ4 is used.
   </td>
 </tr>
 <tr>
-  <td>spark.io.compression.snappy.block.size</td>
-  <td>32768</td>
+  <td>spark.io.compression.snappy.blockSize</td>
+  <td>32k</td>
   <td>
-    Block size (in bytes) used in Snappy compression, in the case when Snappy compression codec
+    Block size used in Snappy compression, in the case when Snappy compression codec
     is used. Lowering this block size will also lower shuffle memory usage when Snappy is used.
   </td>
 </tr>
 <tr>
-  <td>spark.kryoserializer.buffer.max.mb</td>
-  <td>64</td>
+  <td>spark.kryoserializer.buffer.max</td>
+  <td>64m</td>
   <td>
-    Maximum allowable size of Kryo serialization buffer, in megabytes. This must be larger than any
+    Maximum allowable size of Kryo serialization buffer. This must be larger than any
     object you attempt to serialize. Increase this if you get a "buffer limit exceeded"
     exception inside Kryo.
   </td>
 </tr>
 <tr>
-  <td>spark.kryoserializer.buffer.mb</td>
-  <td>0.064</td>
+  <td>spark.kryoserializer.buffer</td>
+  <td>64k</td>
   <td>
-    Initial size of Kryo's serialization buffer, in megabytes. Note that there will be one buffer
+    Initial size of Kryo's serialization buffer. Note that there will be one buffer
     per core on each worker. This buffer will grow up to spark.kryoserializer.buffer.max.mb
     if needed.
   </td>
 </tr>
 <tr><th>Property Name</th><th>Default</th><th>Meaning</th></tr>
 <tr>
   <td>spark.broadcast.blockSize</td>
-  <td>4096</td>
+  <td>4m</td>
   <td>
-    Size of each piece of a block in kilobytes for TorrentBroadcastFactory.
+    Size of each piece of a block for TorrentBroadcastFactory.
     Too large a value decreases parallelism during broadcast (makes it slower); however,
     if it is too small, BlockManager might take a performance hit.
   </td>
 </tr>
 <tr>
   <td>spark.storage.memoryMapThreshold</td>
-  <td>2097152</td>
+  <td>2m</td>
   <td>
-    Size of a block, in bytes, above which Spark memory maps when reading a block from disk.
+    Size of a block above which Spark memory maps when reading a block from disk.
     This prevents Spark from memory mapping very small blocks. In general, memory
     mapping has high overhead for blocks close to or below the page size of the operating system.
   </td>
 </tr>
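For reference, a small example of setting the renamed properties (illustrative, not part of the diff); the new values are size strings with explicit units rather than bare numbers in implicit units:

import org.apache.spark.SparkConf

val conf = new SparkConf()
  .set("spark.reducer.maxSizeInFlight", "48m")   // was spark.reducer.maxMbInFlight = 48
  .set("spark.shuffle.file.buffer", "32k")       // was spark.shuffle.file.buffer.kb = 32
  .set("spark.kryoserializer.buffer.max", "64m") // was spark.kryoserializer.buffer.max.mb = 64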
<tr>
  <td>{submission.submissionId}</td>
  <td>{submission.submissionDate}</td>
  <td>{submission.command.mainClass}</td>
  <td>cpus: {submission.cores}, mem: {submission.mem}</td>
</tr>
<tr>
  <td>{state.driverDescription.submissionId}</td>
  <td>{state.driverDescription.submissionDate}</td>
  <td>{state.driverDescription.command.mainClass}</td>
  <td>cpus: {state.driverDescription.cores}, mem: {state.driverDescription.mem}</td>
  <td>{state.startDate}</td>
  <td>{state.slaveId.getValue}</td>
  <td>{stateString(state.mesosTaskStatus)}</td>
</tr>
<tr>
  <td>{submission.submissionId}</td>
  <td>{submission.submissionDate}</td>
  <td>{submission.command.mainClass}</td>
  <td>{submission.retryState.get.lastFailureStatus}</td>
  <td>{submission.retryState.get.nextRetry}</td>
  <td>{submission.retryState.get.retries}</td>
</tr>