diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index dd95e406f2a8e..009ed64775844 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -108,6 +108,14 @@ class SparkEnv ( pythonWorkers.get(key).foreach(_.stopWorker(worker)) } } + + private[spark] + def releasePythonWorker(pythonExec: String, envVars: Map[String, String], worker: Socket) { + synchronized { + val key = (pythonExec, envVars) + pythonWorkers.get(key).foreach(_.releaseWorker(worker)) + } + } } object SparkEnv extends Logging { diff --git a/core/src/main/scala/org/apache/spark/TaskContext.scala b/core/src/main/scala/org/apache/spark/TaskContext.scala index 2b99b8a5af250..51b3e4d5e0936 100644 --- a/core/src/main/scala/org/apache/spark/TaskContext.scala +++ b/core/src/main/scala/org/apache/spark/TaskContext.scala @@ -21,7 +21,7 @@ import scala.collection.mutable.ArrayBuffer import org.apache.spark.annotation.DeveloperApi import org.apache.spark.executor.TaskMetrics -import org.apache.spark.util.TaskCompletionListener +import org.apache.spark.util.{TaskCompletionListenerException, TaskCompletionListener} /** @@ -41,7 +41,7 @@ class TaskContext( val attemptId: Long, val runningLocally: Boolean = false, private[spark] val taskMetrics: TaskMetrics = TaskMetrics.empty) - extends Serializable { + extends Serializable with Logging { @deprecated("use partitionId", "0.8.1") def splitId = partitionId @@ -103,8 +103,20 @@ class TaskContext( /** Marks the task as completed and triggers the listeners. */ private[spark] def markTaskCompleted(): Unit = { completed = true + val errorMsgs = new ArrayBuffer[String](2) // Process complete callbacks in the reverse order of registration - onCompleteCallbacks.reverse.foreach { _.onTaskCompletion(this) } + onCompleteCallbacks.reverse.foreach { listener => + try { + listener.onTaskCompletion(this) + } catch { + case e: Throwable => + errorMsgs += e.getMessage + logError("Error in TaskCompletionListener", e) + } + } + if (errorMsgs.nonEmpty) { + throw new TaskCompletionListenerException(errorMsgs) + } } /** Marks the task for interruption, i.e. cancellation. */ diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala index 8e178bc8480f7..791d853a015a1 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala @@ -17,6 +17,7 @@ package org.apache.spark.api.java +import java.io.Closeable import java.util import java.util.{Map => JMap} @@ -40,7 +41,9 @@ import org.apache.spark.rdd.{EmptyRDD, HadoopRDD, NewHadoopRDD, RDD} * A Java-friendly version of [[org.apache.spark.SparkContext]] that returns * [[org.apache.spark.api.java.JavaRDD]]s and works with Java collections instead of Scala ones. */ -class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWorkaround { +class JavaSparkContext(val sc: SparkContext) + extends JavaSparkContextVarargsWorkaround with Closeable { + /** * Create a JavaSparkContext that loads settings from system properties (for instance, when * launching with ./bin/spark-submit). @@ -534,6 +537,8 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork sc.stop() } + override def close(): Unit = stop() + /** * Get Spark's home location from either a value set through the constructor, * or the spark.home Java property, or the SPARK_HOME environment variable diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index ae8010300a500..ca8eef5f99edf 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -23,6 +23,7 @@ import java.nio.charset.Charset import java.util.{List => JList, ArrayList => JArrayList, Map => JMap, Collections} import scala.collection.JavaConversions._ +import scala.collection.mutable import scala.language.existentials import scala.reflect.ClassTag import scala.util.{Try, Success, Failure} @@ -52,6 +53,7 @@ private[spark] class PythonRDD( extends RDD[Array[Byte]](parent) { val bufferSize = conf.getInt("spark.buffer.size", 65536) + val reuse_worker = conf.getBoolean("spark.python.worker.reuse", true) override def getPartitions = parent.partitions @@ -63,19 +65,26 @@ private[spark] class PythonRDD( val localdir = env.blockManager.diskBlockManager.localDirs.map( f => f.getPath()).mkString(",") envVars += ("SPARK_LOCAL_DIRS" -> localdir) // it's also used in monitor thread + if (reuse_worker) { + envVars += ("SPARK_REUSE_WORKER" -> "1") + } val worker: Socket = env.createPythonWorker(pythonExec, envVars.toMap) // Start a thread to feed the process input from our parent's iterator val writerThread = new WriterThread(env, worker, split, context) + var complete_cleanly = false context.addTaskCompletionListener { context => writerThread.shutdownOnTaskCompletion() - - // Cleanup the worker socket. This will also cause the Python worker to exit. - try { - worker.close() - } catch { - case e: Exception => logWarning("Failed to close worker socket", e) + if (reuse_worker && complete_cleanly) { + env.releasePythonWorker(pythonExec, envVars.toMap, worker) + } else { + try { + worker.close() + } catch { + case e: Exception => + logWarning("Failed to close worker socket", e) + } } } @@ -133,6 +142,7 @@ private[spark] class PythonRDD( stream.readFully(update) accumulator += Collections.singletonList(update) } + complete_cleanly = true null } } catch { @@ -195,11 +205,26 @@ private[spark] class PythonRDD( PythonRDD.writeUTF(include, dataOut) } // Broadcast variables - dataOut.writeInt(broadcastVars.length) + val oldBids = PythonRDD.getWorkerBroadcasts(worker) + val newBids = broadcastVars.map(_.id).toSet + // number of different broadcasts + val cnt = oldBids.diff(newBids).size + newBids.diff(oldBids).size + dataOut.writeInt(cnt) + for (bid <- oldBids) { + if (!newBids.contains(bid)) { + // remove the broadcast from worker + dataOut.writeLong(- bid - 1) // bid >= 0 + oldBids.remove(bid) + } + } for (broadcast <- broadcastVars) { - dataOut.writeLong(broadcast.id) - dataOut.writeInt(broadcast.value.length) - dataOut.write(broadcast.value) + if (!oldBids.contains(broadcast.id)) { + // send new broadcast + dataOut.writeLong(broadcast.id) + dataOut.writeInt(broadcast.value.length) + dataOut.write(broadcast.value) + oldBids.add(broadcast.id) + } } dataOut.flush() // Serialized command: @@ -207,17 +232,18 @@ private[spark] class PythonRDD( dataOut.write(command) // Data values PythonRDD.writeIteratorToStream(parent.iterator(split, context), dataOut) + dataOut.writeInt(SpecialLengths.END_OF_DATA_SECTION) dataOut.flush() } catch { case e: Exception if context.isCompleted || context.isInterrupted => logDebug("Exception thrown after task completion (likely due to cleanup)", e) + worker.shutdownOutput() case e: Exception => // We must avoid throwing exceptions here, because the thread uncaught exception handler // will kill the whole executor (see org.apache.spark.executor.Executor). _exception = e - } finally { - Try(worker.shutdownOutput()) // kill Python worker process + worker.shutdownOutput() } } } @@ -278,6 +304,14 @@ private object SpecialLengths { private[spark] object PythonRDD extends Logging { val UTF8 = Charset.forName("UTF-8") + // remember the broadcasts sent to each worker + private val workerBroadcasts = new mutable.WeakHashMap[Socket, mutable.Set[Long]]() + private def getWorkerBroadcasts(worker: Socket) = { + synchronized { + workerBroadcasts.getOrElseUpdate(worker, new mutable.HashSet[Long]()) + } + } + /** * Adapter for calling SparkContext#runJob from Python. * diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala index 4c4796f6c59ba..71bdf0fe1b917 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala @@ -40,7 +40,10 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String var daemon: Process = null val daemonHost = InetAddress.getByAddress(Array(127, 0, 0, 1)) var daemonPort: Int = 0 - var daemonWorkers = new mutable.WeakHashMap[Socket, Int]() + val daemonWorkers = new mutable.WeakHashMap[Socket, Int]() + val idleWorkers = new mutable.Queue[Socket]() + var lastActivity = 0L + new MonitorThread().start() var simpleWorkers = new mutable.WeakHashMap[Socket, Process]() @@ -51,6 +54,11 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String def create(): Socket = { if (useDaemon) { + synchronized { + if (idleWorkers.size > 0) { + return idleWorkers.dequeue() + } + } createThroughDaemon() } else { createSimpleWorker() @@ -199,9 +207,44 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String } } + /** + * Monitor all the idle workers, kill them after timeout. + */ + private class MonitorThread extends Thread(s"Idle Worker Monitor for $pythonExec") { + + setDaemon(true) + + override def run() { + while (true) { + synchronized { + if (lastActivity + IDLE_WORKER_TIMEOUT_MS < System.currentTimeMillis()) { + cleanupIdleWorkers() + lastActivity = System.currentTimeMillis() + } + } + Thread.sleep(10000) + } + } + } + + private def cleanupIdleWorkers() { + while (idleWorkers.length > 0) { + val worker = idleWorkers.dequeue() + try { + // the worker will exit after closing the socket + worker.close() + } catch { + case e: Exception => + logWarning("Failed to close worker socket", e) + } + } + } + private def stopDaemon() { synchronized { if (useDaemon) { + cleanupIdleWorkers() + // Request shutdown of existing daemon by sending SIGTERM if (daemon != null) { daemon.destroy() @@ -220,23 +263,43 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String } def stopWorker(worker: Socket) { - if (useDaemon) { - if (daemon != null) { - daemonWorkers.get(worker).foreach { pid => - // tell daemon to kill worker by pid - val output = new DataOutputStream(daemon.getOutputStream) - output.writeInt(pid) - output.flush() - daemon.getOutputStream.flush() + synchronized { + if (useDaemon) { + if (daemon != null) { + daemonWorkers.get(worker).foreach { pid => + // tell daemon to kill worker by pid + val output = new DataOutputStream(daemon.getOutputStream) + output.writeInt(pid) + output.flush() + daemon.getOutputStream.flush() + } } + } else { + simpleWorkers.get(worker).foreach(_.destroy()) } - } else { - simpleWorkers.get(worker).foreach(_.destroy()) } worker.close() } + + def releaseWorker(worker: Socket) { + if (useDaemon) { + synchronized { + lastActivity = System.currentTimeMillis() + idleWorkers.enqueue(worker) + } + } else { + // Cleanup the worker socket. This will also cause the Python worker to exit. + try { + worker.close() + } catch { + case e: Exception => + logWarning("Failed to close worker socket", e) + } + } + } } private object PythonWorkerFactory { val PROCESS_WAIT_TIMEOUT_MS = 10000 + val IDLE_WORKER_TIMEOUT_MS = 60000 // kill idle workers after 1 minute } diff --git a/core/src/main/scala/org/apache/spark/util/TaskCompletionListenerException.scala b/core/src/main/scala/org/apache/spark/util/TaskCompletionListenerException.scala new file mode 100644 index 0000000000000..f64e069cd1724 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/TaskCompletionListenerException.scala @@ -0,0 +1,34 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util + +/** + * Exception thrown when there is an exception in + * executing the callback in TaskCompletionListener. + */ +private[spark] +class TaskCompletionListenerException(errorMessages: Seq[String]) extends Exception { + + override def getMessage: String = { + if (errorMessages.size == 1) { + errorMessages.head + } else { + errorMessages.zipWithIndex.map { case (msg, i) => s"Exception $i: $msg" }.mkString("\n") + } + } +} diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala index db2ad829a48f9..faba5508c906c 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala @@ -17,16 +17,20 @@ package org.apache.spark.scheduler +import org.mockito.Mockito._ +import org.mockito.Matchers.any + import org.scalatest.FunSuite import org.scalatest.BeforeAndAfter import org.apache.spark._ import org.apache.spark.rdd.RDD -import org.apache.spark.util.Utils +import org.apache.spark.util.{TaskCompletionListenerException, TaskCompletionListener} + class TaskContextSuite extends FunSuite with BeforeAndAfter with LocalSparkContext { - test("Calls executeOnCompleteCallbacks after failure") { + test("calls TaskCompletionListener after failure") { TaskContextSuite.completed = false sc = new SparkContext("local", "test") val rdd = new RDD[String](sc, List()) { @@ -45,6 +49,20 @@ class TaskContextSuite extends FunSuite with BeforeAndAfter with LocalSparkConte } assert(TaskContextSuite.completed === true) } + + test("all TaskCompletionListeners should be called even if some fail") { + val context = new TaskContext(0, 0, 0) + val listener = mock(classOf[TaskCompletionListener]) + context.addTaskCompletionListener(_ => throw new Exception("blah")) + context.addTaskCompletionListener(listener) + context.addTaskCompletionListener(_ => throw new Exception("blah")) + + intercept[TaskCompletionListenerException] { + context.markTaskCompleted() + } + + verify(listener, times(1)).onTaskCompletion(any()) + } } private object TaskContextSuite { diff --git a/docs/configuration.md b/docs/configuration.md index 67a3108b76914..036ef86e5cda5 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -221,6 +221,15 @@ Apart from these, the following properties are also available, and may be useful The directory which is used to dump the profile result. The results will be dumped as sepereted file for each RDD. They can be loaded by ptats.Stats(). + + + spark.python.worker.reuse + true + + Reuse Python worker or not. If yes, it will use a fixed number of Python workers, + does not need to fork() a Python process for every tasks. It will be very useful + if there is large broadcast, then the broadcast will not be needed to transfered + from JVM to Python worker for every task. diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index d83efa4bab324..8d41fdec699e9 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -918,7 +918,6 @@ options. ## Migration Guide for Shark User ### Scheduling -s To set a [Fair Scheduler](job-scheduling.html#fair-scheduler-pools) pool for a JDBC client session, users can set the `spark.sql.thriftserver.scheduler.pool` variable: @@ -1110,7 +1109,7 @@ evaluated by the SQL execution engine. A full list of the functions supported c The range of numbers is from `-9223372036854775808` to `9223372036854775807`. - `FloatType`: Represents 4-byte single-precision floating point numbers. - `DoubleType`: Represents 8-byte double-precision floating point numbers. - - `DecimalType`: + - `DecimalType`: Represents arbitrary-precision signed decimal numbers. Backed internally by `java.math.BigDecimal`. A `BigDecimal` consists of an arbitrary precision integer unscaled value and a 32-bit integer scale. * String type - `StringType`: Represents character string values. * Binary type @@ -1232,7 +1231,7 @@ import org.apache.spark.sql._ scala.collection.Seq ArrayType(elementType, [containsNull])
- Note: The default value of containsNull is false. + Note: The default value of containsNull is true. @@ -1358,7 +1357,7 @@ please use factory methods provided in java.util.List DataType.createArrayType(elementType)
- Note: The value of containsNull will be false
+ Note: The value of containsNull will be true
DataType.createArrayType(elementType, containsNull). @@ -1505,7 +1504,7 @@ from pyspark.sql import * list, tuple, or array ArrayType(elementType, [containsNull])
- Note: The default value of containsNull is False. + Note: The default value of containsNull is True. diff --git a/python/pyspark/daemon.py b/python/pyspark/daemon.py index 15445abf67147..64d6202acb27d 100644 --- a/python/pyspark/daemon.py +++ b/python/pyspark/daemon.py @@ -23,6 +23,7 @@ import sys import traceback import time +import gc from errno import EINTR, ECHILD, EAGAIN from socket import AF_INET, SOCK_STREAM, SOMAXCONN from signal import SIGHUP, SIGTERM, SIGCHLD, SIG_DFL, SIG_IGN @@ -46,17 +47,6 @@ def worker(sock): signal.signal(SIGCHLD, SIG_DFL) signal.signal(SIGTERM, SIG_DFL) - # Blocks until the socket is closed by draining the input stream - # until it raises an exception or returns EOF. - def waitSocketClose(sock): - try: - while True: - # Empty string is returned upon EOF (and only then). - if sock.recv(4096) == '': - return - except: - pass - # Read the socket using fdopen instead of socket.makefile() because the latter # seems to be very slow; note that we need to dup() the file descriptor because # otherwise writes also cause a seek that makes us miss data on the read side. @@ -64,17 +54,13 @@ def waitSocketClose(sock): outfile = os.fdopen(os.dup(sock.fileno()), "a+", 65536) exit_code = 0 try: - # Acknowledge that the fork was successful - write_int(os.getpid(), outfile) - outfile.flush() worker_main(infile, outfile) except SystemExit as exc: - exit_code = exc.code + exit_code = compute_real_exit_code(exc.code) finally: outfile.flush() - # The Scala side will close the socket upon task completion. - waitSocketClose(sock) - os._exit(compute_real_exit_code(exit_code)) + if exit_code: + os._exit(exit_code) # Cleanup zombie children @@ -111,6 +97,8 @@ def handle_sigterm(*args): signal.signal(SIGTERM, handle_sigterm) # Gracefully exit on SIGTERM signal.signal(SIGHUP, SIG_IGN) # Don't die on SIGHUP + reuse = os.environ.get("SPARK_REUSE_WORKER") + # Initialization complete try: while True: @@ -163,7 +151,19 @@ def handle_sigterm(*args): # in child process listen_sock.close() try: - worker(sock) + # Acknowledge that the fork was successful + outfile = sock.makefile("w") + write_int(os.getpid(), outfile) + outfile.flush() + outfile.close() + while True: + worker(sock) + if not reuse: + # wait for closing + while sock.recv(1024): + pass + break + gc.collect() except: traceback.print_exc() os._exit(1) diff --git a/python/pyspark/mllib/_common.py b/python/pyspark/mllib/_common.py index bb60d3d0c8463..68f6033616726 100644 --- a/python/pyspark/mllib/_common.py +++ b/python/pyspark/mllib/_common.py @@ -21,7 +21,7 @@ from numpy import ndarray, float64, int64, int32, array_equal, array from pyspark import SparkContext, RDD from pyspark.mllib.linalg import SparseVector -from pyspark.serializers import Serializer +from pyspark.serializers import FramedSerializer """ @@ -451,18 +451,16 @@ def _serialize_rating(r): return ba -class RatingDeserializer(Serializer): +class RatingDeserializer(FramedSerializer): - def loads(self, stream): - length = struct.unpack("!i", stream.read(4))[0] - ba = stream.read(length) - res = ndarray(shape=(3, ), buffer=ba, dtype=float64, offset=4) + def loads(self, string): + res = ndarray(shape=(3, ), buffer=string, dtype=float64, offset=4) return int(res[0]), int(res[1]), res[2] def load_stream(self, stream): while True: try: - yield self.loads(stream) + yield self._read_with_length(stream) except struct.error: return except EOFError: diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index a5f9341e819a9..ec3c6f055441d 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -144,6 +144,8 @@ def _write_with_length(self, obj, stream): def _read_with_length(self, stream): length = read_int(stream) + if length == SpecialLengths.END_OF_DATA_SECTION: + raise EOFError obj = stream.read(length) if obj == "": raise EOFError @@ -438,6 +440,8 @@ def __init__(self, use_unicode=False): def loads(self, stream): length = read_int(stream) + if length == SpecialLengths.END_OF_DATA_SECTION: + raise EOFError s = stream.read(length) return s.decode("utf-8") if self.use_unicode else s diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index c98c56293ee50..933c611cc63d3 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -1247,11 +1247,46 @@ def run(): except OSError: self.fail("daemon had been killed") + # run a normal job + rdd = self.sc.parallelize(range(100), 1) + self.assertEqual(100, rdd.map(str).count()) + def test_fd_leak(self): N = 1100 # fd limit is 1024 by default rdd = self.sc.parallelize(range(N), N) self.assertEquals(N, rdd.count()) + def test_after_exception(self): + def raise_exception(_): + raise Exception() + rdd = self.sc.parallelize(range(100), 1) + self.assertRaises(Exception, lambda: rdd.foreach(raise_exception)) + self.assertEqual(100, rdd.map(str).count()) + + def test_after_jvm_exception(self): + tempFile = tempfile.NamedTemporaryFile(delete=False) + tempFile.write("Hello World!") + tempFile.close() + data = self.sc.textFile(tempFile.name, 1) + filtered_data = data.filter(lambda x: True) + self.assertEqual(1, filtered_data.count()) + os.unlink(tempFile.name) + self.assertRaises(Exception, lambda: filtered_data.count()) + + rdd = self.sc.parallelize(range(100), 1) + self.assertEqual(100, rdd.map(str).count()) + + def test_accumulator_when_reuse_worker(self): + from pyspark.accumulators import INT_ACCUMULATOR_PARAM + acc1 = self.sc.accumulator(0, INT_ACCUMULATOR_PARAM) + self.sc.parallelize(range(100), 20).foreach(lambda x: acc1.add(x)) + self.assertEqual(sum(range(100)), acc1.value) + + acc2 = self.sc.accumulator(0, INT_ACCUMULATOR_PARAM) + self.sc.parallelize(range(100), 20).foreach(lambda x: acc2.add(x)) + self.assertEqual(sum(range(100)), acc2.value) + self.assertEqual(sum(range(100)), acc1.value) + class TestSparkSubmit(unittest.TestCase): diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 6e5c1e17d2647..2985ed2767c57 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -72,9 +72,14 @@ def main(infile, outfile): ser = CompressedSerializer(pickleSer) for _ in range(num_broadcast_variables): bid = read_long(infile) - value = ser._read_with_length(infile) - _broadcastRegistry[bid] = Broadcast(bid, value) - + if bid >= 0: + value = ser._read_with_length(infile) + _broadcastRegistry[bid] = Broadcast(bid, value) + else: + bid = - bid - 1 + _broadcastRegistry.remove(bid) + + _accumulatorRegistry.clear() command = pickleSer._read_with_length(infile) (func, stats, deserializer, serializer) = command init_time = time.time() diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificRow.scala index 088f11ee4aa53..9cbab3d5d0d0d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificRow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificRow.scala @@ -171,7 +171,7 @@ final class MutableByte extends MutableValue { } final class MutableAny extends MutableValue { - var value: Any = 0 + var value: Any = _ def boxed = if (isNull) null else value def update(v: Any) = value = { isNull = false diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnAccessor.scala index 42a5a9a84f362..c9faf0852142a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnAccessor.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnAccessor.scala @@ -50,11 +50,13 @@ private[sql] abstract class BasicColumnAccessor[T <: DataType, JvmType]( def hasNext = buffer.hasRemaining - def extractTo(row: MutableRow, ordinal: Int) { - columnType.setField(row, ordinal, extractSingle(buffer)) + def extractTo(row: MutableRow, ordinal: Int): Unit = { + extractSingle(row, ordinal) } - def extractSingle(buffer: ByteBuffer): JvmType = columnType.extract(buffer) + def extractSingle(row: MutableRow, ordinal: Int): Unit = { + columnType.extract(buffer, row, ordinal) + } protected def underlyingBuffer = buffer } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala index b3ec5ded22422..2e61a981375aa 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala @@ -68,10 +68,9 @@ private[sql] class BasicColumnBuilder[T <: DataType, JvmType]( buffer.order(ByteOrder.nativeOrder()).putInt(columnType.typeId) } - override def appendFrom(row: Row, ordinal: Int) { - val field = columnType.getField(row, ordinal) - buffer = ensureFreeSpace(buffer, columnType.actualSize(field)) - columnType.append(field, buffer) + override def appendFrom(row: Row, ordinal: Int): Unit = { + buffer = ensureFreeSpace(buffer, columnType.actualSize(row, ordinal)) + columnType.append(row, ordinal, buffer) } override def build() = { @@ -142,16 +141,16 @@ private[sql] object ColumnBuilder { useCompression: Boolean = false): ColumnBuilder = { val builder = (typeId match { - case INT.typeId => new IntColumnBuilder - case LONG.typeId => new LongColumnBuilder - case FLOAT.typeId => new FloatColumnBuilder - case DOUBLE.typeId => new DoubleColumnBuilder - case BOOLEAN.typeId => new BooleanColumnBuilder - case BYTE.typeId => new ByteColumnBuilder - case SHORT.typeId => new ShortColumnBuilder - case STRING.typeId => new StringColumnBuilder - case BINARY.typeId => new BinaryColumnBuilder - case GENERIC.typeId => new GenericColumnBuilder + case INT.typeId => new IntColumnBuilder + case LONG.typeId => new LongColumnBuilder + case FLOAT.typeId => new FloatColumnBuilder + case DOUBLE.typeId => new DoubleColumnBuilder + case BOOLEAN.typeId => new BooleanColumnBuilder + case BYTE.typeId => new ByteColumnBuilder + case SHORT.typeId => new ShortColumnBuilder + case STRING.typeId => new StringColumnBuilder + case BINARY.typeId => new BinaryColumnBuilder + case GENERIC.typeId => new GenericColumnBuilder case TIMESTAMP.typeId => new TimestampColumnBuilder }).asInstanceOf[ColumnBuilder] diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala index fc343ccb995c2..203a714e03c97 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala @@ -69,7 +69,7 @@ private[sql] class ByteColumnStats extends ColumnStats { var lower = Byte.MaxValue var nullCount = 0 - override def gatherStats(row: Row, ordinal: Int) { + override def gatherStats(row: Row, ordinal: Int): Unit = { if (!row.isNullAt(ordinal)) { val value = row.getByte(ordinal) if (value > upper) upper = value @@ -87,7 +87,7 @@ private[sql] class ShortColumnStats extends ColumnStats { var lower = Short.MaxValue var nullCount = 0 - override def gatherStats(row: Row, ordinal: Int) { + override def gatherStats(row: Row, ordinal: Int): Unit = { if (!row.isNullAt(ordinal)) { val value = row.getShort(ordinal) if (value > upper) upper = value @@ -105,7 +105,7 @@ private[sql] class LongColumnStats extends ColumnStats { var lower = Long.MaxValue var nullCount = 0 - override def gatherStats(row: Row, ordinal: Int) { + override def gatherStats(row: Row, ordinal: Int): Unit = { if (!row.isNullAt(ordinal)) { val value = row.getLong(ordinal) if (value > upper) upper = value @@ -123,7 +123,7 @@ private[sql] class DoubleColumnStats extends ColumnStats { var lower = Double.MaxValue var nullCount = 0 - override def gatherStats(row: Row, ordinal: Int) { + override def gatherStats(row: Row, ordinal: Int): Unit = { if (!row.isNullAt(ordinal)) { val value = row.getDouble(ordinal) if (value > upper) upper = value @@ -141,7 +141,7 @@ private[sql] class FloatColumnStats extends ColumnStats { var lower = Float.MaxValue var nullCount = 0 - override def gatherStats(row: Row, ordinal: Int) { + override def gatherStats(row: Row, ordinal: Int): Unit = { if (!row.isNullAt(ordinal)) { val value = row.getFloat(ordinal) if (value > upper) upper = value @@ -159,7 +159,7 @@ private[sql] class IntColumnStats extends ColumnStats { var lower = Int.MaxValue var nullCount = 0 - override def gatherStats(row: Row, ordinal: Int) { + override def gatherStats(row: Row, ordinal: Int): Unit = { if (!row.isNullAt(ordinal)) { val value = row.getInt(ordinal) if (value > upper) upper = value @@ -177,7 +177,7 @@ private[sql] class StringColumnStats extends ColumnStats { var lower: String = null var nullCount = 0 - override def gatherStats(row: Row, ordinal: Int) { + override def gatherStats(row: Row, ordinal: Int): Unit = { if (!row.isNullAt(ordinal)) { val value = row.getString(ordinal) if (upper == null || value.compareTo(upper) > 0) upper = value @@ -195,7 +195,7 @@ private[sql] class TimestampColumnStats extends ColumnStats { var lower: Timestamp = null var nullCount = 0 - override def gatherStats(row: Row, ordinal: Int) { + override def gatherStats(row: Row, ordinal: Int): Unit = { if (!row.isNullAt(ordinal)) { val value = row(ordinal).asInstanceOf[Timestamp] if (upper == null || value.compareTo(upper) > 0) upper = value diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala index 9a61600115872..198b5756676aa 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala @@ -18,11 +18,10 @@ package org.apache.spark.sql.columnar import java.nio.ByteBuffer +import java.sql.Timestamp import scala.reflect.runtime.universe.TypeTag -import java.sql.Timestamp - import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.expressions.MutableRow import org.apache.spark.sql.catalyst.types._ @@ -46,16 +45,33 @@ private[sql] sealed abstract class ColumnType[T <: DataType, JvmType]( */ def extract(buffer: ByteBuffer): JvmType + /** + * Extracts a value out of the buffer at the buffer's current position and stores in + * `row(ordinal)`. Subclasses should override this method to avoid boxing/unboxing costs whenever + * possible. + */ + def extract(buffer: ByteBuffer, row: MutableRow, ordinal: Int): Unit = { + setField(row, ordinal, extract(buffer)) + } + /** * Appends the given value v of type T into the given ByteBuffer. */ - def append(v: JvmType, buffer: ByteBuffer) + def append(v: JvmType, buffer: ByteBuffer): Unit + + /** + * Appends `row(ordinal)` of type T into the given ByteBuffer. Subclasses should override this + * method to avoid boxing/unboxing costs whenever possible. + */ + def append(row: Row, ordinal: Int, buffer: ByteBuffer): Unit = { + append(getField(row, ordinal), buffer) + } /** - * Returns the size of the value. This is used to calculate the size of variable length types - * such as byte arrays and strings. + * Returns the size of the value `row(ordinal)`. This is used to calculate the size of variable + * length types such as byte arrays and strings. */ - def actualSize(v: JvmType): Int = defaultSize + def actualSize(row: Row, ordinal: Int): Int = defaultSize /** * Returns `row(ordinal)`. Subclasses should override this method to avoid boxing/unboxing costs @@ -67,7 +83,15 @@ private[sql] sealed abstract class ColumnType[T <: DataType, JvmType]( * Sets `row(ordinal)` to `field`. Subclasses should override this method to avoid boxing/unboxing * costs whenever possible. */ - def setField(row: MutableRow, ordinal: Int, value: JvmType) + def setField(row: MutableRow, ordinal: Int, value: JvmType): Unit + + /** + * Copies `from(fromOrdinal)` to `to(toOrdinal)`. Subclasses should override this method to avoid + * boxing/unboxing costs whenever possible. + */ + def copyField(from: Row, fromOrdinal: Int, to: MutableRow, toOrdinal: Int): Unit = { + to(toOrdinal) = from(fromOrdinal) + } /** * Creates a duplicated copy of the value. @@ -90,119 +114,205 @@ private[sql] abstract class NativeColumnType[T <: NativeType]( } private[sql] object INT extends NativeColumnType(IntegerType, 0, 4) { - def append(v: Int, buffer: ByteBuffer) { + def append(v: Int, buffer: ByteBuffer): Unit = { buffer.putInt(v) } + override def append(row: Row, ordinal: Int, buffer: ByteBuffer): Unit = { + buffer.putInt(row.getInt(ordinal)) + } + def extract(buffer: ByteBuffer) = { buffer.getInt() } - override def setField(row: MutableRow, ordinal: Int, value: Int) { + override def extract(buffer: ByteBuffer, row: MutableRow, ordinal: Int): Unit = { + row.setInt(ordinal, buffer.getInt()) + } + + override def setField(row: MutableRow, ordinal: Int, value: Int): Unit = { row.setInt(ordinal, value) } override def getField(row: Row, ordinal: Int) = row.getInt(ordinal) + + override def copyField(from: Row, fromOrdinal: Int, to: MutableRow, toOrdinal: Int): Unit = { + to.setInt(toOrdinal, from.getInt(fromOrdinal)) + } } private[sql] object LONG extends NativeColumnType(LongType, 1, 8) { - override def append(v: Long, buffer: ByteBuffer) { + override def append(v: Long, buffer: ByteBuffer): Unit = { buffer.putLong(v) } + override def append(row: Row, ordinal: Int, buffer: ByteBuffer): Unit = { + buffer.putLong(row.getLong(ordinal)) + } + override def extract(buffer: ByteBuffer) = { buffer.getLong() } - override def setField(row: MutableRow, ordinal: Int, value: Long) { + override def extract(buffer: ByteBuffer, row: MutableRow, ordinal: Int): Unit = { + row.setLong(ordinal, buffer.getLong()) + } + + override def setField(row: MutableRow, ordinal: Int, value: Long): Unit = { row.setLong(ordinal, value) } override def getField(row: Row, ordinal: Int) = row.getLong(ordinal) + + override def copyField(from: Row, fromOrdinal: Int, to: MutableRow, toOrdinal: Int): Unit = { + to.setLong(toOrdinal, from.getLong(fromOrdinal)) + } } private[sql] object FLOAT extends NativeColumnType(FloatType, 2, 4) { - override def append(v: Float, buffer: ByteBuffer) { + override def append(v: Float, buffer: ByteBuffer): Unit = { buffer.putFloat(v) } + override def append(row: Row, ordinal: Int, buffer: ByteBuffer): Unit = { + buffer.putFloat(row.getFloat(ordinal)) + } + override def extract(buffer: ByteBuffer) = { buffer.getFloat() } - override def setField(row: MutableRow, ordinal: Int, value: Float) { + override def extract(buffer: ByteBuffer, row: MutableRow, ordinal: Int): Unit = { + row.setFloat(ordinal, buffer.getFloat()) + } + + override def setField(row: MutableRow, ordinal: Int, value: Float): Unit = { row.setFloat(ordinal, value) } override def getField(row: Row, ordinal: Int) = row.getFloat(ordinal) + + override def copyField(from: Row, fromOrdinal: Int, to: MutableRow, toOrdinal: Int): Unit = { + to.setFloat(toOrdinal, from.getFloat(fromOrdinal)) + } } private[sql] object DOUBLE extends NativeColumnType(DoubleType, 3, 8) { - override def append(v: Double, buffer: ByteBuffer) { + override def append(v: Double, buffer: ByteBuffer): Unit = { buffer.putDouble(v) } + override def append(row: Row, ordinal: Int, buffer: ByteBuffer): Unit = { + buffer.putDouble(row.getDouble(ordinal)) + } + override def extract(buffer: ByteBuffer) = { buffer.getDouble() } - override def setField(row: MutableRow, ordinal: Int, value: Double) { + override def extract(buffer: ByteBuffer, row: MutableRow, ordinal: Int): Unit = { + row.setDouble(ordinal, buffer.getDouble()) + } + + override def setField(row: MutableRow, ordinal: Int, value: Double): Unit = { row.setDouble(ordinal, value) } override def getField(row: Row, ordinal: Int) = row.getDouble(ordinal) + + override def copyField(from: Row, fromOrdinal: Int, to: MutableRow, toOrdinal: Int): Unit = { + to.setDouble(toOrdinal, from.getDouble(fromOrdinal)) + } } private[sql] object BOOLEAN extends NativeColumnType(BooleanType, 4, 1) { - override def append(v: Boolean, buffer: ByteBuffer) { - buffer.put(if (v) 1.toByte else 0.toByte) + override def append(v: Boolean, buffer: ByteBuffer): Unit = { + buffer.put(if (v) 1: Byte else 0: Byte) + } + + override def append(row: Row, ordinal: Int, buffer: ByteBuffer): Unit = { + buffer.put(if (row.getBoolean(ordinal)) 1: Byte else 0: Byte) } override def extract(buffer: ByteBuffer) = buffer.get() == 1 - override def setField(row: MutableRow, ordinal: Int, value: Boolean) { + override def extract(buffer: ByteBuffer, row: MutableRow, ordinal: Int): Unit = { + row.setBoolean(ordinal, buffer.get() == 1) + } + + override def setField(row: MutableRow, ordinal: Int, value: Boolean): Unit = { row.setBoolean(ordinal, value) } override def getField(row: Row, ordinal: Int) = row.getBoolean(ordinal) + + override def copyField(from: Row, fromOrdinal: Int, to: MutableRow, toOrdinal: Int): Unit = { + to.setBoolean(toOrdinal, from.getBoolean(fromOrdinal)) + } } private[sql] object BYTE extends NativeColumnType(ByteType, 5, 1) { - override def append(v: Byte, buffer: ByteBuffer) { + override def append(v: Byte, buffer: ByteBuffer): Unit = { buffer.put(v) } + override def append(row: Row, ordinal: Int, buffer: ByteBuffer): Unit = { + buffer.put(row.getByte(ordinal)) + } + override def extract(buffer: ByteBuffer) = { buffer.get() } - override def setField(row: MutableRow, ordinal: Int, value: Byte) { + override def extract(buffer: ByteBuffer, row: MutableRow, ordinal: Int): Unit = { + row.setByte(ordinal, buffer.get()) + } + + override def setField(row: MutableRow, ordinal: Int, value: Byte): Unit = { row.setByte(ordinal, value) } override def getField(row: Row, ordinal: Int) = row.getByte(ordinal) + + override def copyField(from: Row, fromOrdinal: Int, to: MutableRow, toOrdinal: Int): Unit = { + to.setByte(toOrdinal, from.getByte(fromOrdinal)) + } } private[sql] object SHORT extends NativeColumnType(ShortType, 6, 2) { - override def append(v: Short, buffer: ByteBuffer) { + override def append(v: Short, buffer: ByteBuffer): Unit = { buffer.putShort(v) } + override def append(row: Row, ordinal: Int, buffer: ByteBuffer): Unit = { + buffer.putShort(row.getShort(ordinal)) + } + override def extract(buffer: ByteBuffer) = { buffer.getShort() } - override def setField(row: MutableRow, ordinal: Int, value: Short) { + override def extract(buffer: ByteBuffer, row: MutableRow, ordinal: Int): Unit = { + row.setShort(ordinal, buffer.getShort()) + } + + override def setField(row: MutableRow, ordinal: Int, value: Short): Unit = { row.setShort(ordinal, value) } override def getField(row: Row, ordinal: Int) = row.getShort(ordinal) + + override def copyField(from: Row, fromOrdinal: Int, to: MutableRow, toOrdinal: Int): Unit = { + to.setShort(toOrdinal, from.getShort(fromOrdinal)) + } } private[sql] object STRING extends NativeColumnType(StringType, 7, 8) { - override def actualSize(v: String): Int = v.getBytes("utf-8").length + 4 + override def actualSize(row: Row, ordinal: Int): Int = { + row.getString(ordinal).getBytes("utf-8").length + 4 + } - override def append(v: String, buffer: ByteBuffer) { + override def append(v: String, buffer: ByteBuffer): Unit = { val stringBytes = v.getBytes("utf-8") buffer.putInt(stringBytes.length).put(stringBytes, 0, stringBytes.length) } @@ -214,11 +324,15 @@ private[sql] object STRING extends NativeColumnType(StringType, 7, 8) { new String(stringBytes, "utf-8") } - override def setField(row: MutableRow, ordinal: Int, value: String) { + override def setField(row: MutableRow, ordinal: Int, value: String): Unit = { row.setString(ordinal, value) } override def getField(row: Row, ordinal: Int) = row.getString(ordinal) + + override def copyField(from: Row, fromOrdinal: Int, to: MutableRow, toOrdinal: Int): Unit = { + to.setString(toOrdinal, from.getString(fromOrdinal)) + } } private[sql] object TIMESTAMP extends NativeColumnType(TimestampType, 8, 12) { @@ -228,7 +342,7 @@ private[sql] object TIMESTAMP extends NativeColumnType(TimestampType, 8, 12) { timestamp } - override def append(v: Timestamp, buffer: ByteBuffer) { + override def append(v: Timestamp, buffer: ByteBuffer): Unit = { buffer.putLong(v.getTime).putInt(v.getNanos) } @@ -236,7 +350,7 @@ private[sql] object TIMESTAMP extends NativeColumnType(TimestampType, 8, 12) { row(ordinal).asInstanceOf[Timestamp] } - override def setField(row: MutableRow, ordinal: Int, value: Timestamp) { + override def setField(row: MutableRow, ordinal: Int, value: Timestamp): Unit = { row(ordinal) = value } } @@ -246,9 +360,11 @@ private[sql] sealed abstract class ByteArrayColumnType[T <: DataType]( defaultSize: Int) extends ColumnType[T, Array[Byte]](typeId, defaultSize) { - override def actualSize(v: Array[Byte]) = v.length + 4 + override def actualSize(row: Row, ordinal: Int) = { + getField(row, ordinal).length + 4 + } - override def append(v: Array[Byte], buffer: ByteBuffer) { + override def append(v: Array[Byte], buffer: ByteBuffer): Unit = { buffer.putInt(v.length).put(v, 0, v.length) } @@ -261,7 +377,7 @@ private[sql] sealed abstract class ByteArrayColumnType[T <: DataType]( } private[sql] object BINARY extends ByteArrayColumnType[BinaryType.type](9, 16) { - override def setField(row: MutableRow, ordinal: Int, value: Array[Byte]) { + override def setField(row: MutableRow, ordinal: Int, value: Array[Byte]): Unit = { row(ordinal) = value } @@ -272,7 +388,7 @@ private[sql] object BINARY extends ByteArrayColumnType[BinaryType.type](9, 16) { // serialized first before appending to the column `ByteBuffer`, and is also extracted as serialized // byte array. private[sql] object GENERIC extends ByteArrayColumnType[DataType](10, 16) { - override def setField(row: MutableRow, ordinal: Int, value: Array[Byte]) { + override def setField(row: MutableRow, ordinal: Int, value: Array[Byte]): Unit = { row(ordinal) = SparkSqlSerializer.deserialize[Any](value) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala index 6eab2f23c18e1..8a3612cdf19be 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala @@ -52,7 +52,7 @@ private[sql] case class InMemoryRelation( // As in Spark, the actual work of caching is lazy. if (_cachedColumnBuffers == null) { val output = child.output - val cached = child.execute().mapPartitions { baseIterator => + val cached = child.execute().mapPartitions { rowIterator => new Iterator[CachedBatch] { def next() = { val columnBuilders = output.map { attribute => @@ -61,11 +61,9 @@ private[sql] case class InMemoryRelation( ColumnBuilder(columnType.typeId, initialBufferSize, attribute.name, useCompression) }.toArray - var row: Row = null var rowCount = 0 - - while (baseIterator.hasNext && rowCount < batchSize) { - row = baseIterator.next() + while (rowIterator.hasNext && rowCount < batchSize) { + val row = rowIterator.next() var i = 0 while (i < row.length) { columnBuilders(i).appendFrom(row, i) @@ -80,7 +78,7 @@ private[sql] case class InMemoryRelation( CachedBatch(columnBuilders.map(_.build()), stats) } - def hasNext = baseIterator.hasNext + def hasNext = rowIterator.hasNext } }.cache() @@ -182,6 +180,7 @@ private[sql] case class InMemoryColumnarTableScan( } } + // Accumulators used for testing purposes val readPartitions = sparkContext.accumulator(0) val readBatches = sparkContext.accumulator(0) @@ -191,40 +190,36 @@ private[sql] case class InMemoryColumnarTableScan( readPartitions.setValue(0) readBatches.setValue(0) - relation.cachedColumnBuffers.mapPartitions { iterator => + relation.cachedColumnBuffers.mapPartitions { cachedBatchIterator => val partitionFilter = newPredicate( partitionFilters.reduceOption(And).getOrElse(Literal(true)), relation.partitionStatistics.schema) - // Find the ordinals of the requested columns. If none are requested, use the first. - val requestedColumns = if (attributes.isEmpty) { - Seq(0) + // Find the ordinals and data types of the requested columns. If none are requested, use the + // narrowest (the field with minimum default element size). + val (requestedColumnIndices, requestedColumnDataTypes) = if (attributes.isEmpty) { + val (narrowestOrdinal, narrowestDataType) = + relation.output.zipWithIndex.map { case (a, ordinal) => + ordinal -> a.dataType + } minBy { case (_, dataType) => + ColumnType(dataType).defaultSize + } + Seq(narrowestOrdinal) -> Seq(narrowestDataType) } else { - attributes.map(a => relation.output.indexWhere(_.exprId == a.exprId)) + attributes.map { a => + relation.output.indexWhere(_.exprId == a.exprId) -> a.dataType + }.unzip } - val rows = iterator - // Skip pruned batches - .filter { cachedBatch => - if (inMemoryPartitionPruningEnabled && !partitionFilter(cachedBatch.stats)) { - def statsString = relation.partitionStatistics.schema - .zip(cachedBatch.stats) - .map { case (a, s) => s"${a.name}: $s" } - .mkString(", ") - logInfo(s"Skipping partition based on stats $statsString") - false - } else { - readBatches += 1 - true - } - } - // Build column accessors - .map { cachedBatch => - requestedColumns.map(cachedBatch.buffers(_)).map(ColumnAccessor(_)) - } - // Extract rows via column accessors - .flatMap { columnAccessors => - val nextRow = new GenericMutableRow(columnAccessors.length) + val nextRow = new SpecificMutableRow(requestedColumnDataTypes) + + def cachedBatchesToRows(cacheBatches: Iterator[CachedBatch]) = { + val rows = cacheBatches.flatMap { cachedBatch => + // Build column accessors + val columnAccessors = + requestedColumnIndices.map(cachedBatch.buffers(_)).map(ColumnAccessor(_)) + + // Extract rows via column accessors new Iterator[Row] { override def next() = { var i = 0 @@ -235,15 +230,38 @@ private[sql] case class InMemoryColumnarTableScan( nextRow } - override def hasNext = columnAccessors.head.hasNext + override def hasNext = columnAccessors(0).hasNext } } - if (rows.hasNext) { - readPartitions += 1 + if (rows.hasNext) { + readPartitions += 1 + } + + rows } - rows + // Do partition batch pruning if enabled + val cachedBatchesToScan = + if (inMemoryPartitionPruningEnabled) { + cachedBatchIterator.filter { cachedBatch => + if (!partitionFilter(cachedBatch.stats)) { + def statsString = relation.partitionStatistics.schema + .zip(cachedBatch.stats) + .map { case (a, s) => s"${a.name}: $s" } + .mkString(", ") + logInfo(s"Skipping partition based on stats $statsString") + false + } else { + readBatches += 1 + true + } + } + } else { + cachedBatchIterator + } + + cachedBatchesToRows(cachedBatchesToScan) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnAccessor.scala index b7f8826861a2c..965782a40031b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnAccessor.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnAccessor.scala @@ -29,7 +29,7 @@ private[sql] trait NullableColumnAccessor extends ColumnAccessor { private var nextNullIndex: Int = _ private var pos: Int = 0 - abstract override protected def initialize() { + abstract override protected def initialize(): Unit = { nullsBuffer = underlyingBuffer.duplicate().order(ByteOrder.nativeOrder()) nullCount = nullsBuffer.getInt() nextNullIndex = if (nullCount > 0) nullsBuffer.getInt() else -1 @@ -39,7 +39,7 @@ private[sql] trait NullableColumnAccessor extends ColumnAccessor { super.initialize() } - abstract override def extractTo(row: MutableRow, ordinal: Int) { + abstract override def extractTo(row: MutableRow, ordinal: Int): Unit = { if (pos == nextNullIndex) { seenNulls += 1 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnBuilder.scala index a72970eef7aa4..f1f494ac26d0c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnBuilder.scala @@ -40,7 +40,11 @@ private[sql] trait NullableColumnBuilder extends ColumnBuilder { protected var nullCount: Int = _ private var pos: Int = _ - abstract override def initialize(initialSize: Int, columnName: String, useCompression: Boolean) { + abstract override def initialize( + initialSize: Int, + columnName: String, + useCompression: Boolean): Unit = { + nulls = ByteBuffer.allocate(1024) nulls.order(ByteOrder.nativeOrder()) pos = 0 @@ -48,7 +52,7 @@ private[sql] trait NullableColumnBuilder extends ColumnBuilder { super.initialize(initialSize, columnName, useCompression) } - abstract override def appendFrom(row: Row, ordinal: Int) { + abstract override def appendFrom(row: Row, ordinal: Int): Unit = { columnStats.gatherStats(row, ordinal) if (row.isNullAt(ordinal)) { nulls = ColumnBuilder.ensureFreeSpace(nulls, 4) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnAccessor.scala index b4120a3d4368b..27ac5f4dbdbbc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnAccessor.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnAccessor.scala @@ -17,8 +17,7 @@ package org.apache.spark.sql.columnar.compression -import java.nio.ByteBuffer - +import org.apache.spark.sql.catalyst.expressions.MutableRow import org.apache.spark.sql.catalyst.types.NativeType import org.apache.spark.sql.columnar.{ColumnAccessor, NativeColumnAccessor} @@ -34,5 +33,7 @@ private[sql] trait CompressibleColumnAccessor[T <: NativeType] extends ColumnAcc abstract override def hasNext = super.hasNext || decoder.hasNext - override def extractSingle(buffer: ByteBuffer): T#JvmType = decoder.next() + override def extractSingle(row: MutableRow, ordinal: Int): Unit = { + decoder.next(row, ordinal) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnBuilder.scala index a5826bb033e41..628d9cec41d6b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnBuilder.scala @@ -48,12 +48,16 @@ private[sql] trait CompressibleColumnBuilder[T <: NativeType] var compressionEncoders: Seq[Encoder[T]] = _ - abstract override def initialize(initialSize: Int, columnName: String, useCompression: Boolean) { + abstract override def initialize( + initialSize: Int, + columnName: String, + useCompression: Boolean): Unit = { + compressionEncoders = if (useCompression) { - schemes.filter(_.supports(columnType)).map(_.encoder[T]) + schemes.filter(_.supports(columnType)).map(_.encoder[T](columnType)) } else { - Seq(PassThrough.encoder) + Seq(PassThrough.encoder(columnType)) } super.initialize(initialSize, columnName, useCompression) } @@ -62,17 +66,15 @@ private[sql] trait CompressibleColumnBuilder[T <: NativeType] encoder.compressionRatio < 0.8 } - private def gatherCompressibilityStats(row: Row, ordinal: Int) { - val field = columnType.getField(row, ordinal) - + private def gatherCompressibilityStats(row: Row, ordinal: Int): Unit = { var i = 0 while (i < compressionEncoders.length) { - compressionEncoders(i).gatherCompressibilityStats(field, columnType) + compressionEncoders(i).gatherCompressibilityStats(row, ordinal) i += 1 } } - abstract override def appendFrom(row: Row, ordinal: Int) { + abstract override def appendFrom(row: Row, ordinal: Int): Unit = { super.appendFrom(row, ordinal) if (!row.isNullAt(ordinal)) { gatherCompressibilityStats(row, ordinal) @@ -84,7 +86,7 @@ private[sql] trait CompressibleColumnBuilder[T <: NativeType] val typeId = nonNullBuffer.getInt() val encoder: Encoder[T] = { val candidate = compressionEncoders.minBy(_.compressionRatio) - if (isWorthCompressing(candidate)) candidate else PassThrough.encoder + if (isWorthCompressing(candidate)) candidate else PassThrough.encoder(columnType) } // Header = column type ID + null count + null positions @@ -104,7 +106,7 @@ private[sql] trait CompressibleColumnBuilder[T <: NativeType] .putInt(nullCount) .put(nulls) - logInfo(s"Compressor for [$columnName]: $encoder, ratio: ${encoder.compressionRatio}") - encoder.compress(nonNullBuffer, compressedBuffer, columnType) + logDebug(s"Compressor for [$columnName]: $encoder, ratio: ${encoder.compressionRatio}") + encoder.compress(nonNullBuffer, compressedBuffer) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressionScheme.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressionScheme.scala index 7797f75177893..acb06cb5376b4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressionScheme.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressionScheme.scala @@ -17,13 +17,15 @@ package org.apache.spark.sql.columnar.compression -import java.nio.{ByteOrder, ByteBuffer} +import java.nio.{ByteBuffer, ByteOrder} +import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.expressions.MutableRow import org.apache.spark.sql.catalyst.types.NativeType import org.apache.spark.sql.columnar.{ColumnType, NativeColumnType} private[sql] trait Encoder[T <: NativeType] { - def gatherCompressibilityStats(value: T#JvmType, columnType: NativeColumnType[T]) {} + def gatherCompressibilityStats(row: Row, ordinal: Int): Unit = {} def compressedSize: Int @@ -33,17 +35,21 @@ private[sql] trait Encoder[T <: NativeType] { if (uncompressedSize > 0) compressedSize.toDouble / uncompressedSize else 1.0 } - def compress(from: ByteBuffer, to: ByteBuffer, columnType: NativeColumnType[T]): ByteBuffer + def compress(from: ByteBuffer, to: ByteBuffer): ByteBuffer } -private[sql] trait Decoder[T <: NativeType] extends Iterator[T#JvmType] +private[sql] trait Decoder[T <: NativeType] { + def next(row: MutableRow, ordinal: Int): Unit + + def hasNext: Boolean +} private[sql] trait CompressionScheme { def typeId: Int def supports(columnType: ColumnType[_, _]): Boolean - def encoder[T <: NativeType]: Encoder[T] + def encoder[T <: NativeType](columnType: NativeColumnType[T]): Encoder[T] def decoder[T <: NativeType](buffer: ByteBuffer, columnType: NativeColumnType[T]): Decoder[T] } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/compressionSchemes.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/compressionSchemes.scala index 8cf9ec74ca2de..29edcf17242c5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/compressionSchemes.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/compressionSchemes.scala @@ -23,7 +23,8 @@ import scala.collection.mutable import scala.reflect.ClassTag import scala.reflect.runtime.universe.runtimeMirror -import org.apache.spark.sql.catalyst.expressions.GenericMutableRow +import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.expressions.{MutableRow, SpecificMutableRow} import org.apache.spark.sql.catalyst.types._ import org.apache.spark.sql.columnar._ import org.apache.spark.util.Utils @@ -33,18 +34,20 @@ private[sql] case object PassThrough extends CompressionScheme { override def supports(columnType: ColumnType[_, _]) = true - override def encoder[T <: NativeType] = new this.Encoder[T] + override def encoder[T <: NativeType](columnType: NativeColumnType[T]) = { + new this.Encoder[T](columnType) + } override def decoder[T <: NativeType](buffer: ByteBuffer, columnType: NativeColumnType[T]) = { new this.Decoder(buffer, columnType) } - class Encoder[T <: NativeType] extends compression.Encoder[T] { + class Encoder[T <: NativeType](columnType: NativeColumnType[T]) extends compression.Encoder[T] { override def uncompressedSize = 0 override def compressedSize = 0 - override def compress(from: ByteBuffer, to: ByteBuffer, columnType: NativeColumnType[T]) = { + override def compress(from: ByteBuffer, to: ByteBuffer) = { // Writes compression type ID and copies raw contents to.putInt(PassThrough.typeId).put(from).rewind() to @@ -54,7 +57,9 @@ private[sql] case object PassThrough extends CompressionScheme { class Decoder[T <: NativeType](buffer: ByteBuffer, columnType: NativeColumnType[T]) extends compression.Decoder[T] { - override def next() = columnType.extract(buffer) + override def next(row: MutableRow, ordinal: Int): Unit = { + columnType.extract(buffer, row, ordinal) + } override def hasNext = buffer.hasRemaining } @@ -63,7 +68,9 @@ private[sql] case object PassThrough extends CompressionScheme { private[sql] case object RunLengthEncoding extends CompressionScheme { override val typeId = 1 - override def encoder[T <: NativeType] = new this.Encoder[T] + override def encoder[T <: NativeType](columnType: NativeColumnType[T]) = { + new this.Encoder[T](columnType) + } override def decoder[T <: NativeType](buffer: ByteBuffer, columnType: NativeColumnType[T]) = { new this.Decoder(buffer, columnType) @@ -74,24 +81,25 @@ private[sql] case object RunLengthEncoding extends CompressionScheme { case _ => false } - class Encoder[T <: NativeType] extends compression.Encoder[T] { + class Encoder[T <: NativeType](columnType: NativeColumnType[T]) extends compression.Encoder[T] { private var _uncompressedSize = 0 private var _compressedSize = 0 // Using `MutableRow` to store the last value to avoid boxing/unboxing cost. - private val lastValue = new GenericMutableRow(1) + private val lastValue = new SpecificMutableRow(Seq(columnType.dataType)) private var lastRun = 0 override def uncompressedSize = _uncompressedSize override def compressedSize = _compressedSize - override def gatherCompressibilityStats(value: T#JvmType, columnType: NativeColumnType[T]) { - val actualSize = columnType.actualSize(value) + override def gatherCompressibilityStats(row: Row, ordinal: Int): Unit = { + val value = columnType.getField(row, ordinal) + val actualSize = columnType.actualSize(row, ordinal) _uncompressedSize += actualSize if (lastValue.isNullAt(0)) { - columnType.setField(lastValue, 0, value) + columnType.copyField(row, ordinal, lastValue, 0) lastRun = 1 _compressedSize += actualSize + 4 } else { @@ -99,37 +107,40 @@ private[sql] case object RunLengthEncoding extends CompressionScheme { lastRun += 1 } else { _compressedSize += actualSize + 4 - columnType.setField(lastValue, 0, value) + columnType.copyField(row, ordinal, lastValue, 0) lastRun = 1 } } } - override def compress(from: ByteBuffer, to: ByteBuffer, columnType: NativeColumnType[T]) = { + override def compress(from: ByteBuffer, to: ByteBuffer) = { to.putInt(RunLengthEncoding.typeId) if (from.hasRemaining) { - var currentValue = columnType.extract(from) + val currentValue = new SpecificMutableRow(Seq(columnType.dataType)) var currentRun = 1 + val value = new SpecificMutableRow(Seq(columnType.dataType)) + + columnType.extract(from, currentValue, 0) while (from.hasRemaining) { - val value = columnType.extract(from) + columnType.extract(from, value, 0) - if (value == currentValue) { + if (value.head == currentValue.head) { currentRun += 1 } else { // Writes current run - columnType.append(currentValue, to) + columnType.append(currentValue, 0, to) to.putInt(currentRun) // Resets current run - currentValue = value + columnType.copyField(value, 0, currentValue, 0) currentRun = 1 } } // Writes the last run - columnType.append(currentValue, to) + columnType.append(currentValue, 0, to) to.putInt(currentRun) } @@ -145,7 +156,7 @@ private[sql] case object RunLengthEncoding extends CompressionScheme { private var valueCount = 0 private var currentValue: T#JvmType = _ - override def next() = { + override def next(row: MutableRow, ordinal: Int): Unit = { if (valueCount == run) { currentValue = columnType.extract(buffer) run = buffer.getInt() @@ -154,7 +165,7 @@ private[sql] case object RunLengthEncoding extends CompressionScheme { valueCount += 1 } - currentValue + columnType.setField(row, ordinal, currentValue) } override def hasNext = valueCount < run || buffer.hasRemaining @@ -171,14 +182,16 @@ private[sql] case object DictionaryEncoding extends CompressionScheme { new this.Decoder(buffer, columnType) } - override def encoder[T <: NativeType] = new this.Encoder[T] + override def encoder[T <: NativeType](columnType: NativeColumnType[T]) = { + new this.Encoder[T](columnType) + } override def supports(columnType: ColumnType[_, _]) = columnType match { case INT | LONG | STRING => true case _ => false } - class Encoder[T <: NativeType] extends compression.Encoder[T] { + class Encoder[T <: NativeType](columnType: NativeColumnType[T]) extends compression.Encoder[T] { // Size of the input, uncompressed, in bytes. Note that we only count until the dictionary // overflows. private var _uncompressedSize = 0 @@ -200,9 +213,11 @@ private[sql] case object DictionaryEncoding extends CompressionScheme { // to store dictionary element count. private var dictionarySize = 4 - override def gatherCompressibilityStats(value: T#JvmType, columnType: NativeColumnType[T]) { + override def gatherCompressibilityStats(row: Row, ordinal: Int): Unit = { + val value = columnType.getField(row, ordinal) + if (!overflow) { - val actualSize = columnType.actualSize(value) + val actualSize = columnType.actualSize(row, ordinal) count += 1 _uncompressedSize += actualSize @@ -221,7 +236,7 @@ private[sql] case object DictionaryEncoding extends CompressionScheme { } } - override def compress(from: ByteBuffer, to: ByteBuffer, columnType: NativeColumnType[T]) = { + override def compress(from: ByteBuffer, to: ByteBuffer) = { if (overflow) { throw new IllegalStateException( "Dictionary encoding should not be used because of dictionary overflow.") @@ -264,7 +279,9 @@ private[sql] case object DictionaryEncoding extends CompressionScheme { } } - override def next() = dictionary(buffer.getShort()) + override def next(row: MutableRow, ordinal: Int): Unit = { + columnType.setField(row, ordinal, dictionary(buffer.getShort())) + } override def hasNext = buffer.hasRemaining } @@ -279,25 +296,20 @@ private[sql] case object BooleanBitSet extends CompressionScheme { new this.Decoder(buffer).asInstanceOf[compression.Decoder[T]] } - override def encoder[T <: NativeType] = (new this.Encoder).asInstanceOf[compression.Encoder[T]] + override def encoder[T <: NativeType](columnType: NativeColumnType[T]) = { + (new this.Encoder).asInstanceOf[compression.Encoder[T]] + } override def supports(columnType: ColumnType[_, _]) = columnType == BOOLEAN class Encoder extends compression.Encoder[BooleanType.type] { private var _uncompressedSize = 0 - override def gatherCompressibilityStats( - value: Boolean, - columnType: NativeColumnType[BooleanType.type]) { - + override def gatherCompressibilityStats(row: Row, ordinal: Int): Unit = { _uncompressedSize += BOOLEAN.defaultSize } - override def compress( - from: ByteBuffer, - to: ByteBuffer, - columnType: NativeColumnType[BooleanType.type]) = { - + override def compress(from: ByteBuffer, to: ByteBuffer) = { to.putInt(BooleanBitSet.typeId) // Total element count (1 byte per Boolean value) .putInt(from.remaining) @@ -349,7 +361,7 @@ private[sql] case object BooleanBitSet extends CompressionScheme { private var visited: Int = 0 - override def next(): Boolean = { + override def next(row: MutableRow, ordinal: Int): Unit = { val bit = visited % BITS_PER_LONG visited += 1 @@ -357,123 +369,167 @@ private[sql] case object BooleanBitSet extends CompressionScheme { currentWord = buffer.getLong() } - ((currentWord >> bit) & 1) != 0 + row.setBoolean(ordinal, ((currentWord >> bit) & 1) != 0) } override def hasNext: Boolean = visited < count } } -private[sql] sealed abstract class IntegralDelta[I <: IntegralType] extends CompressionScheme { +private[sql] case object IntDelta extends CompressionScheme { + override def typeId: Int = 4 + override def decoder[T <: NativeType](buffer: ByteBuffer, columnType: NativeColumnType[T]) = { - new this.Decoder(buffer, columnType.asInstanceOf[NativeColumnType[I]]) - .asInstanceOf[compression.Decoder[T]] + new Decoder(buffer, INT).asInstanceOf[compression.Decoder[T]] } - override def encoder[T <: NativeType] = (new this.Encoder).asInstanceOf[compression.Encoder[T]] - - /** - * Computes `delta = x - y`, returns `(true, delta)` if `delta` can fit into a single byte, or - * `(false, 0: Byte)` otherwise. - */ - protected def byteSizedDelta(x: I#JvmType, y: I#JvmType): (Boolean, Byte) + override def encoder[T <: NativeType](columnType: NativeColumnType[T]) = { + (new Encoder).asInstanceOf[compression.Encoder[T]] + } - /** - * Simply computes `x + delta` - */ - protected def addDelta(x: I#JvmType, delta: Byte): I#JvmType + override def supports(columnType: ColumnType[_, _]) = columnType == INT - class Encoder extends compression.Encoder[I] { - private var _compressedSize: Int = 0 + class Encoder extends compression.Encoder[IntegerType.type] { + protected var _compressedSize: Int = 0 + protected var _uncompressedSize: Int = 0 - private var _uncompressedSize: Int = 0 + override def compressedSize = _compressedSize + override def uncompressedSize = _uncompressedSize - private var prev: I#JvmType = _ + private var prevValue: Int = _ - private var initial = true + override def gatherCompressibilityStats(row: Row, ordinal: Int): Unit = { + val value = row.getInt(ordinal) + val delta = value - prevValue - override def gatherCompressibilityStats(value: I#JvmType, columnType: NativeColumnType[I]) { - _uncompressedSize += columnType.defaultSize + _compressedSize += 1 - if (initial) { - initial = false - _compressedSize += 1 + columnType.defaultSize - } else { - val (smallEnough, _) = byteSizedDelta(value, prev) - _compressedSize += (if (smallEnough) 1 else 1 + columnType.defaultSize) + // If this is the first integer to be compressed, or the delta is out of byte range, then give + // up compressing this integer. + if (_uncompressedSize == 0 || delta <= Byte.MinValue || delta > Byte.MaxValue) { + _compressedSize += INT.defaultSize } - prev = value + _uncompressedSize += INT.defaultSize + prevValue = value } - override def compress(from: ByteBuffer, to: ByteBuffer, columnType: NativeColumnType[I]) = { + override def compress(from: ByteBuffer, to: ByteBuffer): ByteBuffer = { to.putInt(typeId) if (from.hasRemaining) { - var prev = columnType.extract(from) + var prev = from.getInt() to.put(Byte.MinValue) - columnType.append(prev, to) + to.putInt(prev) while (from.hasRemaining) { - val current = columnType.extract(from) - val (smallEnough, delta) = byteSizedDelta(current, prev) + val current = from.getInt() + val delta = current - prev prev = current - if (smallEnough) { - to.put(delta) + if (Byte.MinValue < delta && delta <= Byte.MaxValue) { + to.put(delta.toByte) } else { to.put(Byte.MinValue) - columnType.append(current, to) + to.putInt(current) } } } - to.rewind() - to + to.rewind().asInstanceOf[ByteBuffer] } - - override def uncompressedSize = _uncompressedSize - - override def compressedSize = _compressedSize } - class Decoder(buffer: ByteBuffer, columnType: NativeColumnType[I]) - extends compression.Decoder[I] { + class Decoder(buffer: ByteBuffer, columnType: NativeColumnType[IntegerType.type]) + extends compression.Decoder[IntegerType.type] { + + private var prev: Int = _ - private var prev: I#JvmType = _ + override def hasNext: Boolean = buffer.hasRemaining - override def next() = { + override def next(row: MutableRow, ordinal: Int): Unit = { val delta = buffer.get() - prev = if (delta > Byte.MinValue) addDelta(prev, delta) else columnType.extract(buffer) - prev + prev = if (delta > Byte.MinValue) prev + delta else buffer.getInt() + row.setInt(ordinal, prev) } - - override def hasNext = buffer.hasRemaining } } -private[sql] case object IntDelta extends IntegralDelta[IntegerType.type] { - override val typeId = 4 +private[sql] case object LongDelta extends CompressionScheme { + override def typeId: Int = 5 - override def supports(columnType: ColumnType[_, _]) = columnType == INT + override def decoder[T <: NativeType](buffer: ByteBuffer, columnType: NativeColumnType[T]) = { + new Decoder(buffer, LONG).asInstanceOf[compression.Decoder[T]] + } + + override def encoder[T <: NativeType](columnType: NativeColumnType[T]) = { + (new Encoder).asInstanceOf[compression.Encoder[T]] + } - override protected def addDelta(x: Int, delta: Byte) = x + delta + override def supports(columnType: ColumnType[_, _]) = columnType == LONG + + class Encoder extends compression.Encoder[LongType.type] { + protected var _compressedSize: Int = 0 + protected var _uncompressedSize: Int = 0 + + override def compressedSize = _compressedSize + override def uncompressedSize = _uncompressedSize + + private var prevValue: Long = _ + + override def gatherCompressibilityStats(row: Row, ordinal: Int): Unit = { + val value = row.getLong(ordinal) + val delta = value - prevValue + + _compressedSize += 1 - override protected def byteSizedDelta(x: Int, y: Int): (Boolean, Byte) = { - val delta = x - y - if (math.abs(delta) <= Byte.MaxValue) (true, delta.toByte) else (false, 0: Byte) + // If this is the first long integer to be compressed, or the delta is out of byte range, then + // give up compressing this long integer. + if (_uncompressedSize == 0 || delta <= Byte.MinValue || delta > Byte.MaxValue) { + _compressedSize += LONG.defaultSize + } + + _uncompressedSize += LONG.defaultSize + prevValue = value + } + + override def compress(from: ByteBuffer, to: ByteBuffer): ByteBuffer = { + to.putInt(typeId) + + if (from.hasRemaining) { + var prev = from.getLong() + to.put(Byte.MinValue) + to.putLong(prev) + + while (from.hasRemaining) { + val current = from.getLong() + val delta = current - prev + prev = current + + if (Byte.MinValue < delta && delta <= Byte.MaxValue) { + to.put(delta.toByte) + } else { + to.put(Byte.MinValue) + to.putLong(current) + } + } + } + + to.rewind().asInstanceOf[ByteBuffer] + } } -} -private[sql] case object LongDelta extends IntegralDelta[LongType.type] { - override val typeId = 5 + class Decoder(buffer: ByteBuffer, columnType: NativeColumnType[LongType.type]) + extends compression.Decoder[LongType.type] { - override def supports(columnType: ColumnType[_, _]) = columnType == LONG + private var prev: Long = _ - override protected def addDelta(x: Long, delta: Byte) = x + delta + override def hasNext: Boolean = buffer.hasRemaining - override protected def byteSizedDelta(x: Long, y: Long): (Boolean, Byte) = { - val delta = x - y - if (math.abs(delta) <= Byte.MaxValue) (true, delta.toByte) else (false, 0: Byte) + override def next(row: MutableRow, ordinal: Int): Unit = { + val delta = buffer.get() + prev = if (delta > Byte.MinValue) prev + delta else buffer.getLong() + row.setLong(ordinal, prev) + } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/test/TestSQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/test/TestSQLContext.scala index f2389f8f0591e..265b67737c475 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/test/TestSQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/test/TestSQLContext.scala @@ -18,8 +18,13 @@ package org.apache.spark.sql.test import org.apache.spark.{SparkConf, SparkContext} -import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.{SQLConf, SQLContext} /** A SQLContext that can be used for local testing. */ object TestSQLContext - extends SQLContext(new SparkContext("local", "TestSQLContext", new SparkConf())) + extends SQLContext(new SparkContext("local[2]", "TestSQLContext", new SparkConf())) { + + /** Fewer partitions to speed up testing. */ + override private[spark] def numShufflePartitions: Int = + getConf(SQLConf.SHUFFLE_PARTITIONS, "5").toInt +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala index cde91ceb68c98..0cdbb3167ce36 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala @@ -35,7 +35,7 @@ class ColumnStatsSuite extends FunSuite { def testColumnStats[T <: NativeType, U <: ColumnStats]( columnStatsClass: Class[U], columnType: NativeColumnType[T], - initialStatistics: Row) { + initialStatistics: Row): Unit = { val columnStatsName = columnStatsClass.getSimpleName diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala index 75f653f3280bd..4fb1ecf1d532b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala @@ -23,6 +23,7 @@ import java.sql.Timestamp import org.scalatest.FunSuite import org.apache.spark.Logging +import org.apache.spark.sql.catalyst.expressions.GenericMutableRow import org.apache.spark.sql.catalyst.types._ import org.apache.spark.sql.columnar.ColumnarTestUtils._ import org.apache.spark.sql.execution.SparkSqlSerializer @@ -46,10 +47,12 @@ class ColumnTypeSuite extends FunSuite with Logging { def checkActualSize[T <: DataType, JvmType]( columnType: ColumnType[T, JvmType], value: JvmType, - expected: Int) { + expected: Int): Unit = { assertResult(expected, s"Wrong actualSize for $columnType") { - columnType.actualSize(value) + val row = new GenericMutableRow(1) + columnType.setField(row, 0, value) + columnType.actualSize(row, 0) } } @@ -147,7 +150,7 @@ class ColumnTypeSuite extends FunSuite with Logging { def testNativeColumnType[T <: NativeType]( columnType: NativeColumnType[T], putter: (ByteBuffer, T#JvmType) => Unit, - getter: (ByteBuffer) => T#JvmType) { + getter: (ByteBuffer) => T#JvmType): Unit = { testColumnType[T, T#JvmType](columnType, putter, getter) } @@ -155,7 +158,7 @@ class ColumnTypeSuite extends FunSuite with Logging { def testColumnType[T <: DataType, JvmType]( columnType: ColumnType[T, JvmType], putter: (ByteBuffer, JvmType) => Unit, - getter: (ByteBuffer) => JvmType) { + getter: (ByteBuffer) => JvmType): Unit = { val buffer = ByteBuffer.allocate(DEFAULT_BUFFER_SIZE) val seq = (0 until 4).map(_ => makeRandomValue(columnType)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala index 0e3c67f5eed29..c1278248ef655 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.columnar import org.apache.spark.sql.catalyst.expressions.Row import org.apache.spark.sql.test.TestSQLContext -import org.apache.spark.sql.{SQLConf, QueryTest, TestData} +import org.apache.spark.sql.{QueryTest, TestData} class InMemoryColumnarQuerySuite extends QueryTest { import org.apache.spark.sql.TestData._ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala index 3baa6f8ec0c83..6c9a9ab6c3418 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala @@ -45,7 +45,9 @@ class NullableColumnAccessorSuite extends FunSuite { testNullableColumnAccessor(_) } - def testNullableColumnAccessor[T <: DataType, JvmType](columnType: ColumnType[T, JvmType]) { + def testNullableColumnAccessor[T <: DataType, JvmType]( + columnType: ColumnType[T, JvmType]): Unit = { + val typeName = columnType.getClass.getSimpleName.stripSuffix("$") val nullRow = makeNullRow(1) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala index a77262534a352..f54a21eb4fbb1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala @@ -41,7 +41,9 @@ class NullableColumnBuilderSuite extends FunSuite { testNullableColumnBuilder(_) } - def testNullableColumnBuilder[T <: DataType, JvmType](columnType: ColumnType[T, JvmType]) { + def testNullableColumnBuilder[T <: DataType, JvmType]( + columnType: ColumnType[T, JvmType]): Unit = { + val typeName = columnType.getClass.getSimpleName.stripSuffix("$") test(s"$typeName column builder: empty column") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala index 5d2fd4959197c..69e0adbd3ee0d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala @@ -28,7 +28,7 @@ class PartitionBatchPruningSuite extends FunSuite with BeforeAndAfterAll with Be val originalColumnBatchSize = columnBatchSize val originalInMemoryPartitionPruning = inMemoryPartitionPruning - override protected def beforeAll() { + override protected def beforeAll(): Unit = { // Make a table with 5 partitions, 2 batches per partition, 10 elements per batch setConf(SQLConf.COLUMN_BATCH_SIZE, "10") val rawData = sparkContext.makeRDD(1 to 100, 5).map(IntegerData) @@ -38,7 +38,7 @@ class PartitionBatchPruningSuite extends FunSuite with BeforeAndAfterAll with Be setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, "true") } - override protected def afterAll() { + override protected def afterAll(): Unit = { setConf(SQLConf.COLUMN_BATCH_SIZE, originalColumnBatchSize.toString) setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, originalInMemoryPartitionPruning.toString) } @@ -76,7 +76,7 @@ class PartitionBatchPruningSuite extends FunSuite with BeforeAndAfterAll with Be filter: String, expectedQueryResult: Seq[Int], expectedReadPartitions: Int, - expectedReadBatches: Int) { + expectedReadBatches: Int): Unit = { test(filter) { val query = sql(s"SELECT * FROM intData WHERE $filter") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/BooleanBitSetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/BooleanBitSetSuite.scala index e01cc8b4d20f2..d9e488e0ffd16 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/BooleanBitSetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/BooleanBitSetSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.columnar.compression import org.scalatest.FunSuite import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.expressions.GenericMutableRow import org.apache.spark.sql.columnar.{NoopColumnStats, BOOLEAN} import org.apache.spark.sql.columnar.ColumnarTestUtils._ @@ -72,10 +73,14 @@ class BooleanBitSetSuite extends FunSuite { buffer.rewind().position(headerSize + 4) val decoder = BooleanBitSet.decoder(buffer, BOOLEAN) + val mutableRow = new GenericMutableRow(1) if (values.nonEmpty) { values.foreach { assert(decoder.hasNext) - assertResult(_, "Wrong decoded value")(decoder.next()) + assertResult(_, "Wrong decoded value") { + decoder.next(mutableRow, 0) + mutableRow.getBoolean(0) + } } } assert(!decoder.hasNext) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/DictionaryEncodingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/DictionaryEncodingSuite.scala index d2969d906c943..1cdb909146d57 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/DictionaryEncodingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/DictionaryEncodingSuite.scala @@ -21,6 +21,7 @@ import java.nio.ByteBuffer import org.scalatest.FunSuite +import org.apache.spark.sql.catalyst.expressions.GenericMutableRow import org.apache.spark.sql.catalyst.types.NativeType import org.apache.spark.sql.columnar._ import org.apache.spark.sql.columnar.ColumnarTestUtils._ @@ -67,7 +68,7 @@ class DictionaryEncodingSuite extends FunSuite { val buffer = builder.build() val headerSize = CompressionScheme.columnHeaderSize(buffer) // 4 extra bytes for dictionary size - val dictionarySize = 4 + values.map(columnType.actualSize).sum + val dictionarySize = 4 + rows.map(columnType.actualSize(_, 0)).sum // 2 bytes for each `Short` val compressedSize = 4 + dictionarySize + 2 * inputSeq.length // 4 extra bytes for compression scheme type ID @@ -97,11 +98,15 @@ class DictionaryEncodingSuite extends FunSuite { buffer.rewind().position(headerSize + 4) val decoder = DictionaryEncoding.decoder(buffer, columnType) + val mutableRow = new GenericMutableRow(1) if (inputSeq.nonEmpty) { inputSeq.foreach { i => assert(decoder.hasNext) - assertResult(values(i), "Wrong decoded value")(decoder.next()) + assertResult(values(i), "Wrong decoded value") { + decoder.next(mutableRow, 0) + columnType.getField(mutableRow, 0) + } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/IntegralDeltaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/IntegralDeltaSuite.scala index 322f447c24840..73f31c0233343 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/IntegralDeltaSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/IntegralDeltaSuite.scala @@ -31,7 +31,7 @@ class IntegralDeltaSuite extends FunSuite { def testIntegralDelta[I <: IntegralType]( columnStats: ColumnStats, columnType: NativeColumnType[I], - scheme: IntegralDelta[I]) { + scheme: CompressionScheme) { def skeleton(input: Seq[I#JvmType]) { // ------------- @@ -96,10 +96,15 @@ class IntegralDeltaSuite extends FunSuite { buffer.rewind().position(headerSize + 4) val decoder = scheme.decoder(buffer, columnType) + val mutableRow = new GenericMutableRow(1) + if (input.nonEmpty) { input.foreach{ assert(decoder.hasNext) - assertResult(_, "Wrong decoded value")(decoder.next()) + assertResult(_, "Wrong decoded value") { + decoder.next(mutableRow, 0) + columnType.getField(mutableRow, 0) + } } } assert(!decoder.hasNext) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/RunLengthEncodingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/RunLengthEncodingSuite.scala index 218c09ac26362..4ce2552112c92 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/RunLengthEncodingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/RunLengthEncodingSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.columnar.compression import org.scalatest.FunSuite +import org.apache.spark.sql.catalyst.expressions.GenericMutableRow import org.apache.spark.sql.catalyst.types.NativeType import org.apache.spark.sql.columnar._ import org.apache.spark.sql.columnar.ColumnarTestUtils._ @@ -57,7 +58,7 @@ class RunLengthEncodingSuite extends FunSuite { // Compression scheme ID + compressed contents val compressedSize = 4 + inputRuns.map { case (index, _) => // 4 extra bytes each run for run length - columnType.actualSize(values(index)) + 4 + columnType.actualSize(rows(index), 0) + 4 }.sum // 4 extra bytes for compression scheme type ID @@ -80,11 +81,15 @@ class RunLengthEncodingSuite extends FunSuite { buffer.rewind().position(headerSize + 4) val decoder = RunLengthEncoding.decoder(buffer, columnType) + val mutableRow = new GenericMutableRow(1) if (inputSeq.nonEmpty) { inputSeq.foreach { i => assert(decoder.hasNext) - assertResult(values(i), "Wrong decoded value")(decoder.next()) + assertResult(values(i), "Wrong decoded value") { + decoder.next(mutableRow, 0) + columnType.getField(mutableRow, 0) + } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala index b0a06cd3ca090..08f7358446b29 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala @@ -58,8 +58,7 @@ case class AllDataTypes( doubleField: Double, shortField: Short, byteField: Byte, - booleanField: Boolean, - binaryField: Array[Byte]) + booleanField: Boolean) case class AllDataTypesWithNonPrimitiveType( stringField: String, @@ -70,13 +69,14 @@ case class AllDataTypesWithNonPrimitiveType( shortField: Short, byteField: Byte, booleanField: Boolean, - binaryField: Array[Byte], array: Seq[Int], arrayContainsNull: Seq[Option[Int]], map: Map[Int, Long], mapValueContainsNull: Map[Int, Option[Long]], data: Data) +case class BinaryData(binaryData: Array[Byte]) + class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterAll { TestData // Load test data tables. @@ -108,26 +108,26 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA test("Read/Write All Types") { val tempDir = getTempFilePath("parquetTest").getCanonicalPath val range = (0 to 255) - TestSQLContext.sparkContext.parallelize(range) - .map(x => AllDataTypes(s"$x", x, x.toLong, x.toFloat, x.toDouble, x.toShort, x.toByte, x % 2 == 0, - (0 to x).map(_.toByte).toArray)) - .saveAsParquetFile(tempDir) - val result = parquetFile(tempDir).collect() - range.foreach { - i => - assert(result(i).getString(0) == s"$i", s"row $i String field did not match, got ${result(i).getString(0)}") - assert(result(i).getInt(1) === i) - assert(result(i).getLong(2) === i.toLong) - assert(result(i).getFloat(3) === i.toFloat) - assert(result(i).getDouble(4) === i.toDouble) - assert(result(i).getShort(5) === i.toShort) - assert(result(i).getByte(6) === i.toByte) - assert(result(i).getBoolean(7) === (i % 2 == 0)) - assert(result(i)(8) === (0 to i).map(_.toByte).toArray) - } + val data = sparkContext.parallelize(range) + .map(x => AllDataTypes(s"$x", x, x.toLong, x.toFloat, x.toDouble, x.toShort, x.toByte, x % 2 == 0)) + + data.saveAsParquetFile(tempDir) + + checkAnswer( + parquetFile(tempDir), + data.toSchemaRDD.collect().toSeq) } - test("Treat binary as string") { + test("read/write binary data") { + // Since equality for Array[Byte] is broken we test this separately. + val tempDir = getTempFilePath("parquetTest").getCanonicalPath + sparkContext.parallelize(BinaryData("test".getBytes("utf8")) :: Nil).saveAsParquetFile(tempDir) + parquetFile(tempDir) + .map(r => new String(r(0).asInstanceOf[Array[Byte]], "utf8")) + .collect().toSeq == Seq("test") + } + + ignore("Treat binary as string") { val oldIsParquetBinaryAsString = TestSQLContext.isParquetBinaryAsString // Create the test file. @@ -142,37 +142,16 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA StructField("c2", BinaryType, false) :: Nil) val schemaRDD1 = applySchema(rowRDD, schema) schemaRDD1.saveAsParquetFile(path) - val resultWithBinary = parquetFile(path).collect - range.foreach { - i => - assert(resultWithBinary(i).getInt(0) === i) - assert(resultWithBinary(i)(1) === s"val_$i".getBytes) - } - - TestSQLContext.setConf(SQLConf.PARQUET_BINARY_AS_STRING, "true") - // This ParquetRelation always use Parquet types to derive output. - val parquetRelation = new ParquetRelation( - path.toString, - Some(TestSQLContext.sparkContext.hadoopConfiguration), - TestSQLContext) { - override val output = - ParquetTypesConverter.convertToAttributes( - ParquetTypesConverter.readMetaData(new Path(path), conf).getFileMetaData.getSchema, - TestSQLContext.isParquetBinaryAsString) - } - val schemaRDD = new SchemaRDD(TestSQLContext, parquetRelation) - val resultWithString = schemaRDD.collect - range.foreach { - i => - assert(resultWithString(i).getInt(0) === i) - assert(resultWithString(i)(1) === s"val_$i") - } + checkAnswer( + parquetFile(path).select('c1, 'c2.cast(StringType)), + schemaRDD1.select('c1, 'c2.cast(StringType)).collect().toSeq) - schemaRDD.registerTempTable("tmp") + setConf(SQLConf.PARQUET_BINARY_AS_STRING, "true") + parquetFile(path).printSchema() checkAnswer( - sql("SELECT c1, c2 FROM tmp WHERE c2 = 'val_5' OR c2 = 'val_7'"), - (5, "val_5") :: - (7, "val_7") :: Nil) + parquetFile(path), + schemaRDD1.select('c1, 'c2.cast(StringType)).collect().toSeq) + // Set it back. TestSQLContext.setConf(SQLConf.PARQUET_BINARY_AS_STRING, oldIsParquetBinaryAsString.toString) @@ -275,34 +254,19 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA test("Read/Write All Types with non-primitive type") { val tempDir = getTempFilePath("parquetTest").getCanonicalPath val range = (0 to 255) - TestSQLContext.sparkContext.parallelize(range) + val data = sparkContext.parallelize(range) .map(x => AllDataTypesWithNonPrimitiveType( s"$x", x, x.toLong, x.toFloat, x.toDouble, x.toShort, x.toByte, x % 2 == 0, - (0 to x).map(_.toByte).toArray, (0 until x), (0 until x).map(Option(_).filter(_ % 3 == 0)), (0 until x).map(i => i -> i.toLong).toMap, (0 until x).map(i => i -> Option(i.toLong)).toMap + (x -> None), Data((0 until x), Nested(x, s"$x")))) - .saveAsParquetFile(tempDir) - val result = parquetFile(tempDir).collect() - range.foreach { - i => - assert(result(i).getString(0) == s"$i", s"row $i String field did not match, got ${result(i).getString(0)}") - assert(result(i).getInt(1) === i) - assert(result(i).getLong(2) === i.toLong) - assert(result(i).getFloat(3) === i.toFloat) - assert(result(i).getDouble(4) === i.toDouble) - assert(result(i).getShort(5) === i.toShort) - assert(result(i).getByte(6) === i.toByte) - assert(result(i).getBoolean(7) === (i % 2 == 0)) - assert(result(i)(8) === (0 to i).map(_.toByte).toArray) - assert(result(i)(9) === (0 until i)) - assert(result(i)(10) === (0 until i).map(i => if (i % 3 == 0) i else null)) - assert(result(i)(11) === (0 until i).map(i => i -> i.toLong).toMap) - assert(result(i)(12) === (0 until i).map(i => i -> i.toLong).toMap + (i -> null)) - assert(result(i)(13) === new GenericRow(Array[Any]((0 until i), new GenericRow(Array[Any](i, s"$i"))))) - } + data.saveAsParquetFile(tempDir) + + checkAnswer( + parquetFile(tempDir), + data.toSchemaRDD.collect().toSeq) } test("self-join parquet files") { @@ -399,23 +363,6 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA } } - test("Saving case class RDD table to file and reading it back in") { - val file = getTempFilePath("parquet") - val path = file.toString - val rdd = TestSQLContext.sparkContext.parallelize((1 to 100)) - .map(i => TestRDDEntry(i, s"val_$i")) - rdd.saveAsParquetFile(path) - val readFile = parquetFile(path) - readFile.registerTempTable("tmpx") - val rdd_copy = sql("SELECT * FROM tmpx").collect() - val rdd_orig = rdd.collect() - for(i <- 0 to 99) { - assert(rdd_copy(i).apply(0) === rdd_orig(i).key, s"key error in line $i") - assert(rdd_copy(i).apply(1) === rdd_orig(i).value, s"value error in line $i") - } - Utils.deleteRecursively(file) - } - test("Read a parquet file instead of a directory") { val file = getTempFilePath("parquet") val path = file.toString @@ -448,32 +395,19 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA sql("INSERT OVERWRITE INTO dest SELECT * FROM source").collect() val rdd_copy1 = sql("SELECT * FROM dest").collect() assert(rdd_copy1.size === 100) - assert(rdd_copy1(0).apply(0) === 1) - assert(rdd_copy1(0).apply(1) === "val_1") - // TODO: why does collecting break things? It seems InsertIntoParquet::execute() is - // executed twice otherwise?! + sql("INSERT INTO dest SELECT * FROM source") - val rdd_copy2 = sql("SELECT * FROM dest").collect() + val rdd_copy2 = sql("SELECT * FROM dest").collect().sortBy(_.getInt(0)) assert(rdd_copy2.size === 200) - assert(rdd_copy2(0).apply(0) === 1) - assert(rdd_copy2(0).apply(1) === "val_1") - assert(rdd_copy2(99).apply(0) === 100) - assert(rdd_copy2(99).apply(1) === "val_100") - assert(rdd_copy2(100).apply(0) === 1) - assert(rdd_copy2(100).apply(1) === "val_1") Utils.deleteRecursively(dirname) } test("Insert (appending) to same table via Scala API") { - // TODO: why does collecting break things? It seems InsertIntoParquet::execute() is - // executed twice otherwise?! sql("INSERT INTO testsource SELECT * FROM testsource") val double_rdd = sql("SELECT * FROM testsource").collect() assert(double_rdd != null) assert(double_rdd.size === 30) - for(i <- (0 to 14)) { - assert(double_rdd(i) === double_rdd(i+15), s"error: lines $i and ${i+15} to not match") - } + // let's restore the original test data Utils.deleteRecursively(ParquetTestData.testDir) ParquetTestData.writeFile() diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala index 329f80cad471e..84fafcde63d05 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala @@ -25,16 +25,14 @@ import org.apache.hadoop.hive.ql.metadata.{Partition => HivePartition, Table => import org.apache.hadoop.hive.ql.plan.{PlanUtils, TableDesc} import org.apache.hadoop.hive.serde2.Deserializer import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector - +import org.apache.hadoop.hive.serde2.objectinspector.primitive._ import org.apache.hadoop.io.Writable import org.apache.hadoop.mapred.{FileInputFormat, InputFormat, JobConf} import org.apache.spark.SerializableWritable import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.{EmptyRDD, HadoopRDD, RDD, UnionRDD} - -import org.apache.spark.sql.catalyst.expressions.{Attribute, Row, GenericMutableRow, Literal, Cast} -import org.apache.spark.sql.catalyst.types.DataType +import org.apache.spark.sql.catalyst.expressions._ /** * A trait for subclasses that handle table scans. @@ -108,12 +106,12 @@ class HadoopTableReader( val hadoopRDD = createHadoopRdd(tableDesc, inputPathStr, ifc) val attrsWithIndex = attributes.zipWithIndex - val mutableRow = new GenericMutableRow(attrsWithIndex.length) + val mutableRow = new SpecificMutableRow(attributes.map(_.dataType)) + val deserializedHadoopRDD = hadoopRDD.mapPartitions { iter => val hconf = broadcastedHiveConf.value.value val deserializer = deserializerClass.newInstance() deserializer.initialize(hconf, tableDesc.getProperties) - HadoopTableReader.fillObject(iter, deserializer, attrsWithIndex, mutableRow) } @@ -164,33 +162,32 @@ class HadoopTableReader( val tableDesc = relation.tableDesc val broadcastedHiveConf = _broadcastedHiveConf val localDeserializer = partDeserializer - val mutableRow = new GenericMutableRow(attributes.length) - - // split the attributes (output schema) into 2 categories: - // (partition keys, ordinal), (normal attributes, ordinal), the ordinal mean the - // index of the attribute in the output Row. - val (partitionKeys, attrs) = attributes.zipWithIndex.partition(attr => { - relation.partitionKeys.indexOf(attr._1) >= 0 - }) - - def fillPartitionKeys(parts: Array[String], row: GenericMutableRow) = { - partitionKeys.foreach { case (attr, ordinal) => - // get partition key ordinal for a given attribute - val partOridinal = relation.partitionKeys.indexOf(attr) - row(ordinal) = Cast(Literal(parts(partOridinal)), attr.dataType).eval(null) + val mutableRow = new SpecificMutableRow(attributes.map(_.dataType)) + + // Splits all attributes into two groups, partition key attributes and those that are not. + // Attached indices indicate the position of each attribute in the output schema. + val (partitionKeyAttrs, nonPartitionKeyAttrs) = + attributes.zipWithIndex.partition { case (attr, _) => + relation.partitionKeys.contains(attr) + } + + def fillPartitionKeys(rawPartValues: Array[String], row: MutableRow) = { + partitionKeyAttrs.foreach { case (attr, ordinal) => + val partOrdinal = relation.partitionKeys.indexOf(attr) + row(ordinal) = Cast(Literal(rawPartValues(partOrdinal)), attr.dataType).eval(null) } } - // fill the partition key for the given MutableRow Object + + // Fill all partition keys to the given MutableRow object fillPartitionKeys(partValues, mutableRow) - val hivePartitionRDD = createHadoopRdd(tableDesc, inputPathStr, ifc) - hivePartitionRDD.mapPartitions { iter => + createHadoopRdd(tableDesc, inputPathStr, ifc).mapPartitions { iter => val hconf = broadcastedHiveConf.value.value val deserializer = localDeserializer.newInstance() deserializer.initialize(hconf, partProps) - // fill the non partition key attributes - HadoopTableReader.fillObject(iter, deserializer, attrs, mutableRow) + // fill the non partition key attributes + HadoopTableReader.fillObject(iter, deserializer, nonPartitionKeyAttrs, mutableRow) } }.toSeq @@ -257,38 +254,64 @@ private[hive] object HadoopTableReader extends HiveInspectors { } /** - * Transform the raw data(Writable object) into the Row object for an iterable input - * @param iter Iterable input which represented as Writable object - * @param deserializer Deserializer associated with the input writable object - * @param attrs Represents the row attribute names and its zero-based position in the MutableRow - * @param row reusable MutableRow object - * - * @return Iterable Row object that transformed from the given iterable input. + * Transform all given raw `Writable`s into `Row`s. + * + * @param iterator Iterator of all `Writable`s to be transformed + * @param deserializer The `Deserializer` associated with the input `Writable` + * @param nonPartitionKeyAttrs Attributes that should be filled together with their corresponding + * positions in the output schema + * @param mutableRow A reusable `MutableRow` that should be filled + * @return An `Iterator[Row]` transformed from `iterator` */ def fillObject( - iter: Iterator[Writable], + iterator: Iterator[Writable], deserializer: Deserializer, - attrs: Seq[(Attribute, Int)], - row: GenericMutableRow): Iterator[Row] = { + nonPartitionKeyAttrs: Seq[(Attribute, Int)], + mutableRow: MutableRow): Iterator[Row] = { + val soi = deserializer.getObjectInspector().asInstanceOf[StructObjectInspector] - // get the field references according to the attributes(output of the reader) required - val fieldRefs = attrs.map { case (attr, idx) => (soi.getStructFieldRef(attr.name), idx) } + val (fieldRefs, fieldOrdinals) = nonPartitionKeyAttrs.map { case (attr, ordinal) => + soi.getStructFieldRef(attr.name) -> ordinal + }.unzip + + // Builds specific unwrappers ahead of time according to object inspector types to avoid pattern + // matching and branching costs per row. + val unwrappers: Seq[(Any, MutableRow, Int) => Unit] = fieldRefs.map { + _.getFieldObjectInspector match { + case oi: BooleanObjectInspector => + (value: Any, row: MutableRow, ordinal: Int) => row.setBoolean(ordinal, oi.get(value)) + case oi: ByteObjectInspector => + (value: Any, row: MutableRow, ordinal: Int) => row.setByte(ordinal, oi.get(value)) + case oi: ShortObjectInspector => + (value: Any, row: MutableRow, ordinal: Int) => row.setShort(ordinal, oi.get(value)) + case oi: IntObjectInspector => + (value: Any, row: MutableRow, ordinal: Int) => row.setInt(ordinal, oi.get(value)) + case oi: LongObjectInspector => + (value: Any, row: MutableRow, ordinal: Int) => row.setLong(ordinal, oi.get(value)) + case oi: FloatObjectInspector => + (value: Any, row: MutableRow, ordinal: Int) => row.setFloat(ordinal, oi.get(value)) + case oi: DoubleObjectInspector => + (value: Any, row: MutableRow, ordinal: Int) => row.setDouble(ordinal, oi.get(value)) + case oi => + (value: Any, row: MutableRow, ordinal: Int) => row(ordinal) = unwrapData(value, oi) + } + } // Map each tuple to a row object - iter.map { value => + iterator.map { value => val raw = deserializer.deserialize(value) - var idx = 0; - while (idx < fieldRefs.length) { - val fieldRef = fieldRefs(idx)._1 - val fieldIdx = fieldRefs(idx)._2 - val fieldValue = soi.getStructFieldData(raw, fieldRef) - - row(fieldIdx) = unwrapData(fieldValue, fieldRef.getFieldObjectInspector()) - - idx += 1 + var i = 0 + while (i < fieldRefs.length) { + val fieldValue = soi.getStructFieldData(raw, fieldRefs(i)) + if (fieldValue == null) { + mutableRow.setNullAt(fieldOrdinals(i)) + } else { + unwrappers(i)(fieldValue, mutableRow, fieldOrdinals(i)) + } + i += 1 } - row: Row + mutableRow: Row } } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala index a3bfd3a8f1fd2..70fb15259e7d7 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala @@ -35,12 +35,13 @@ import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.plans.logical.{CacheCommand, LogicalPlan, NativeCommand} import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.hive._ +import org.apache.spark.sql.SQLConf /* Implicit conversions */ import scala.collection.JavaConversions._ object TestHive - extends TestHiveContext(new SparkContext("local", "TestSQLContext", new SparkConf())) + extends TestHiveContext(new SparkContext("local[2]", "TestSQLContext", new SparkConf())) /** * A locally running test instance of Spark's Hive execution engine. @@ -90,6 +91,10 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { override def executePlan(plan: LogicalPlan): this.QueryExecution = new this.QueryExecution { val logical = plan } + /** Fewer partitions to speed up testing. */ + override private[spark] def numShufflePartitions: Int = + getConf(SQLConf.SHUFFLE_PARTITIONS, "5").toInt + /** * Returns the value of specified environmental variable as a [[java.io.File]] after checking * to ensure it exists diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala index 671c3b162f875..79cc7a3fcc7d6 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala @@ -250,9 +250,9 @@ abstract class HiveComparisonTest } try { - // MINOR HACK: You must run a query before calling reset the first time. - TestHive.sql("SHOW TABLES") - if (reset) { TestHive.reset() } + if (reset) { + TestHive.reset() + } val hiveCacheFiles = queryList.zipWithIndex.map { case (queryString, i) => diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index 6bf8d18a5c32c..8c8a8b124ac69 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -295,8 +295,16 @@ class HiveQuerySuite extends HiveComparisonTest { "SELECT (CASE WHEN key > 2 THEN 3 WHEN 2 > key THEN 2 ELSE 0 END) FROM src WHERE key < 15") test("implement identity function using case statement") { - val actual = sql("SELECT (CASE key WHEN key THEN key END) FROM src").collect().toSet - val expected = sql("SELECT key FROM src").collect().toSet + val actual = sql("SELECT (CASE key WHEN key THEN key END) FROM src") + .map { case Row(i: Int) => i } + .collect() + .toSet + + val expected = sql("SELECT key FROM src") + .map { case Row(i: Int) => i } + .collect() + .toSet + assert(actual === expected) } @@ -559,9 +567,9 @@ class HiveQuerySuite extends HiveComparisonTest { val testVal = "test.val.0" val nonexistentKey = "nonexistent" val KV = "([^=]+)=([^=]*)".r - def collectResults(rdd: SchemaRDD): Set[(String, String)] = - rdd.collect().map { - case Row(key: String, value: String) => key -> value + def collectResults(rdd: SchemaRDD): Set[(String, String)] = + rdd.collect().map { + case Row(key: String, value: String) => key -> value case Row(KV(key, value)) => key -> value }.toSet clear() diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/parquet/ParquetMetastoreSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/parquet/ParquetMetastoreSuite.scala index 0723be7298e15..e380280f301c1 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/parquet/ParquetMetastoreSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/parquet/ParquetMetastoreSuite.scala @@ -20,14 +20,10 @@ package org.apache.spark.sql.parquet import java.io.File -import org.apache.spark.sql.hive.execution.HiveTableScan import org.scalatest.BeforeAndAfterAll -import scala.reflect.ClassTag - -import org.apache.spark.sql.{SQLConf, QueryTest} -import org.apache.spark.sql.execution.{BroadcastHashJoin, ShuffledHashJoin} -import org.apache.spark.sql.hive.test.TestHive +import org.apache.spark.sql.QueryTest +import org.apache.spark.sql.hive.execution.HiveTableScan import org.apache.spark.sql.hive.test.TestHive._ case class ParquetData(intField: Int, stringField: String) @@ -36,27 +32,19 @@ case class ParquetData(intField: Int, stringField: String) * Tests for our SerDe -> Native parquet scan conversion. */ class ParquetMetastoreSuite extends QueryTest with BeforeAndAfterAll { - override def beforeAll(): Unit = { - setConf("spark.sql.hive.convertMetastoreParquet", "true") - } - - override def afterAll(): Unit = { - setConf("spark.sql.hive.convertMetastoreParquet", "false") - } - - val partitionedTableDir = File.createTempFile("parquettests", "sparksql") - partitionedTableDir.delete() - partitionedTableDir.mkdir() - - (1 to 10).foreach { p => - val partDir = new File(partitionedTableDir, s"p=$p") - sparkContext.makeRDD(1 to 10) - .map(i => ParquetData(i, s"part-$p")) - .saveAsParquetFile(partDir.getCanonicalPath) - } - - sql(s""" + val partitionedTableDir = File.createTempFile("parquettests", "sparksql") + partitionedTableDir.delete() + partitionedTableDir.mkdir() + + (1 to 10).foreach { p => + val partDir = new File(partitionedTableDir, s"p=$p") + sparkContext.makeRDD(1 to 10) + .map(i => ParquetData(i, s"part-$p")) + .saveAsParquetFile(partDir.getCanonicalPath) + } + + sql(s""" create external table partitioned_parquet ( intField INT, @@ -70,7 +58,7 @@ class ParquetMetastoreSuite extends QueryTest with BeforeAndAfterAll { location '${partitionedTableDir.getCanonicalPath}' """) - sql(s""" + sql(s""" create external table normal_parquet ( intField INT, @@ -83,8 +71,15 @@ class ParquetMetastoreSuite extends QueryTest with BeforeAndAfterAll { location '${new File(partitionedTableDir, "p=1").getCanonicalPath}' """) - (1 to 10).foreach { p => - sql(s"ALTER TABLE partitioned_parquet ADD PARTITION (p=$p)") + (1 to 10).foreach { p => + sql(s"ALTER TABLE partitioned_parquet ADD PARTITION (p=$p)") + } + + setConf("spark.sql.hive.convertMetastoreParquet", "true") + } + + override def afterAll(): Unit = { + setConf("spark.sql.hive.convertMetastoreParquet", "false") } test("project the partitioning column") { diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala index 18605cac7006c..9dc26dc6b32a1 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala @@ -21,7 +21,7 @@ package org.apache.spark.streaming.api.java import scala.collection.JavaConversions._ import scala.reflect.ClassTag -import java.io.InputStream +import java.io.{Closeable, InputStream} import java.util.{List => JList, Map => JMap} import akka.actor.{Props, SupervisorStrategy} @@ -49,7 +49,7 @@ import org.apache.spark.streaming.receiver.Receiver * respectively. `context.awaitTransformation()` allows the current thread to wait for the * termination of a context by `stop()` or by an exception. */ -class JavaStreamingContext(val ssc: StreamingContext) { +class JavaStreamingContext(val ssc: StreamingContext) extends Closeable { /** * Create a StreamingContext. @@ -540,6 +540,9 @@ class JavaStreamingContext(val ssc: StreamingContext) { def stop(stopSparkContext: Boolean, stopGracefully: Boolean) = { ssc.stop(stopSparkContext, stopGracefully) } + + override def close(): Unit = stop() + } /**