diff --git a/bin/spark-class b/bin/spark-class index e8201c18d52de..91d858bc063d0 100755 --- a/bin/spark-class +++ b/bin/spark-class @@ -105,7 +105,7 @@ else exit 1 fi fi -JAVA_VERSION=$("$RUNNER" -version 2>&1 | sed 's/.* version "\(.*\)\.\(.*\)\..*"/\1\2/; 1q') +JAVA_VERSION=$("$RUNNER" -version 2>&1 | grep 'version' | sed 's/.* version "\(.*\)\.\(.*\)\..*"/\1\2/; 1q') # Set JAVA_OPTS to be able to load native libraries and to set heap size if [ "$JAVA_VERSION" -ge 18 ]; then diff --git a/bin/spark-shell.cmd b/bin/spark-shell.cmd index 2ee60b4e2a2b3..8f90ba5a0b3b8 100755 --- a/bin/spark-shell.cmd +++ b/bin/spark-shell.cmd @@ -17,6 +17,7 @@ rem See the License for the specific language governing permissions and rem limitations under the License. rem -set SPARK_HOME=%~dp0.. +rem This is the entry point for running Spark shell. To avoid polluting the +rem environment, it just launches a new cmd to do the real work. -cmd /V /E /C %SPARK_HOME%\bin\spark-submit.cmd --class org.apache.spark.repl.Main %* spark-shell +cmd /V /E /C %~dp0spark-shell2.cmd %* diff --git a/bin/spark-shell2.cmd b/bin/spark-shell2.cmd new file mode 100644 index 0000000000000..2ee60b4e2a2b3 --- /dev/null +++ b/bin/spark-shell2.cmd @@ -0,0 +1,22 @@ +@echo off + +rem +rem Licensed to the Apache Software Foundation (ASF) under one or more +rem contributor license agreements. See the NOTICE file distributed with +rem this work for additional information regarding copyright ownership. +rem The ASF licenses this file to You under the Apache License, Version 2.0 +rem (the "License"); you may not use this file except in compliance with +rem the License. You may obtain a copy of the License at +rem +rem http://www.apache.org/licenses/LICENSE-2.0 +rem +rem Unless required by applicable law or agreed to in writing, software +rem distributed under the License is distributed on an "AS IS" BASIS, +rem WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +rem See the License for the specific language governing permissions and +rem limitations under the License. +rem + +set SPARK_HOME=%~dp0.. + +cmd /V /E /C %SPARK_HOME%\bin\spark-submit.cmd --class org.apache.spark.repl.Main %* spark-shell diff --git a/bin/spark-submit.cmd b/bin/spark-submit.cmd index cf6046d1547ad..8f3b84c7b971d 100644 --- a/bin/spark-submit.cmd +++ b/bin/spark-submit.cmd @@ -17,52 +17,7 @@ rem See the License for the specific language governing permissions and rem limitations under the License. rem -rem NOTE: Any changes in this file must be reflected in SparkSubmitDriverBootstrapper.scala! +rem This is the entry point for running Spark submit. To avoid polluting the +rem environment, it just launches a new cmd to do the real work. -set SPARK_HOME=%~dp0.. 
-set ORIG_ARGS=%* - -rem Reset the values of all variables used -set SPARK_SUBMIT_DEPLOY_MODE=client -set SPARK_SUBMIT_PROPERTIES_FILE=%SPARK_HOME%\conf\spark-defaults.conf -set SPARK_SUBMIT_DRIVER_MEMORY= -set SPARK_SUBMIT_LIBRARY_PATH= -set SPARK_SUBMIT_CLASSPATH= -set SPARK_SUBMIT_OPTS= -set SPARK_SUBMIT_BOOTSTRAP_DRIVER= - -:loop -if [%1] == [] goto continue - if [%1] == [--deploy-mode] ( - set SPARK_SUBMIT_DEPLOY_MODE=%2 - ) else if [%1] == [--properties-file] ( - set SPARK_SUBMIT_PROPERTIES_FILE=%2 - ) else if [%1] == [--driver-memory] ( - set SPARK_SUBMIT_DRIVER_MEMORY=%2 - ) else if [%1] == [--driver-library-path] ( - set SPARK_SUBMIT_LIBRARY_PATH=%2 - ) else if [%1] == [--driver-class-path] ( - set SPARK_SUBMIT_CLASSPATH=%2 - ) else if [%1] == [--driver-java-options] ( - set SPARK_SUBMIT_OPTS=%2 - ) - shift -goto loop -:continue - -rem For client mode, the driver will be launched in the same JVM that launches -rem SparkSubmit, so we may need to read the properties file for any extra class -rem paths, library paths, java options and memory early on. Otherwise, it will -rem be too late by the time the driver JVM has started. - -if [%SPARK_SUBMIT_DEPLOY_MODE%] == [client] ( - if exist %SPARK_SUBMIT_PROPERTIES_FILE% ( - rem Parse the properties file only if the special configs exist - for /f %%i in ('findstr /r /c:"^[\t ]*spark.driver.memory" /c:"^[\t ]*spark.driver.extra" ^ - %SPARK_SUBMIT_PROPERTIES_FILE%') do ( - set SPARK_SUBMIT_BOOTSTRAP_DRIVER=1 - ) - ) -) - -cmd /V /E /C %SPARK_HOME%\bin\spark-class.cmd org.apache.spark.deploy.SparkSubmit %ORIG_ARGS% +cmd /V /E /C %~dp0spark-submit2.cmd %* diff --git a/bin/spark-submit2.cmd b/bin/spark-submit2.cmd new file mode 100644 index 0000000000000..cf6046d1547ad --- /dev/null +++ b/bin/spark-submit2.cmd @@ -0,0 +1,68 @@ +@echo off + +rem +rem Licensed to the Apache Software Foundation (ASF) under one or more +rem contributor license agreements. See the NOTICE file distributed with +rem this work for additional information regarding copyright ownership. +rem The ASF licenses this file to You under the Apache License, Version 2.0 +rem (the "License"); you may not use this file except in compliance with +rem the License. You may obtain a copy of the License at +rem +rem http://www.apache.org/licenses/LICENSE-2.0 +rem +rem Unless required by applicable law or agreed to in writing, software +rem distributed under the License is distributed on an "AS IS" BASIS, +rem WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +rem See the License for the specific language governing permissions and +rem limitations under the License. +rem + +rem NOTE: Any changes in this file must be reflected in SparkSubmitDriverBootstrapper.scala! + +set SPARK_HOME=%~dp0.. 
+set ORIG_ARGS=%* + +rem Reset the values of all variables used +set SPARK_SUBMIT_DEPLOY_MODE=client +set SPARK_SUBMIT_PROPERTIES_FILE=%SPARK_HOME%\conf\spark-defaults.conf +set SPARK_SUBMIT_DRIVER_MEMORY= +set SPARK_SUBMIT_LIBRARY_PATH= +set SPARK_SUBMIT_CLASSPATH= +set SPARK_SUBMIT_OPTS= +set SPARK_SUBMIT_BOOTSTRAP_DRIVER= + +:loop +if [%1] == [] goto continue + if [%1] == [--deploy-mode] ( + set SPARK_SUBMIT_DEPLOY_MODE=%2 + ) else if [%1] == [--properties-file] ( + set SPARK_SUBMIT_PROPERTIES_FILE=%2 + ) else if [%1] == [--driver-memory] ( + set SPARK_SUBMIT_DRIVER_MEMORY=%2 + ) else if [%1] == [--driver-library-path] ( + set SPARK_SUBMIT_LIBRARY_PATH=%2 + ) else if [%1] == [--driver-class-path] ( + set SPARK_SUBMIT_CLASSPATH=%2 + ) else if [%1] == [--driver-java-options] ( + set SPARK_SUBMIT_OPTS=%2 + ) + shift +goto loop +:continue + +rem For client mode, the driver will be launched in the same JVM that launches +rem SparkSubmit, so we may need to read the properties file for any extra class +rem paths, library paths, java options and memory early on. Otherwise, it will +rem be too late by the time the driver JVM has started. + +if [%SPARK_SUBMIT_DEPLOY_MODE%] == [client] ( + if exist %SPARK_SUBMIT_PROPERTIES_FILE% ( + rem Parse the properties file only if the special configs exist + for /f %%i in ('findstr /r /c:"^[\t ]*spark.driver.memory" /c:"^[\t ]*spark.driver.extra" ^ + %SPARK_SUBMIT_PROPERTIES_FILE%') do ( + set SPARK_SUBMIT_BOOTSTRAP_DRIVER=1 + ) + ) +) + +cmd /V /E /C %SPARK_HOME%\bin\spark-class.cmd org.apache.spark.deploy.SparkSubmit %ORIG_ARGS% diff --git a/core/src/main/java/org/apache/spark/TaskContext.java b/core/src/main/java/org/apache/spark/TaskContext.java index 4e6d708af0ea7..2d998d4c7a5d9 100644 --- a/core/src/main/java/org/apache/spark/TaskContext.java +++ b/core/src/main/java/org/apache/spark/TaskContext.java @@ -18,131 +18,55 @@ package org.apache.spark; import java.io.Serializable; -import java.util.ArrayList; -import java.util.Collections; -import java.util.List; import scala.Function0; import scala.Function1; import scala.Unit; -import scala.collection.JavaConversions; import org.apache.spark.annotation.DeveloperApi; import org.apache.spark.executor.TaskMetrics; import org.apache.spark.util.TaskCompletionListener; -import org.apache.spark.util.TaskCompletionListenerException; /** -* :: DeveloperApi :: -* Contextual information about a task which can be read or mutated during execution. -*/ -@DeveloperApi -public class TaskContext implements Serializable { - - private int stageId; - private int partitionId; - private long attemptId; - private boolean runningLocally; - private TaskMetrics taskMetrics; - - /** - * :: DeveloperApi :: - * Contextual information about a task which can be read or mutated during execution. - * - * @param stageId stage id - * @param partitionId index of the partition - * @param attemptId the number of attempts to execute this task - * @param runningLocally whether the task is running locally in the driver JVM - * @param taskMetrics performance metrics of the task - */ - @DeveloperApi - public TaskContext(int stageId, int partitionId, long attemptId, boolean runningLocally, - TaskMetrics taskMetrics) { - this.attemptId = attemptId; - this.partitionId = partitionId; - this.runningLocally = runningLocally; - this.stageId = stageId; - this.taskMetrics = taskMetrics; - } - - /** - * :: DeveloperApi :: - * Contextual information about a task which can be read or mutated during execution. 
- * - * @param stageId stage id - * @param partitionId index of the partition - * @param attemptId the number of attempts to execute this task - * @param runningLocally whether the task is running locally in the driver JVM - */ - @DeveloperApi - public TaskContext(int stageId, int partitionId, long attemptId, boolean runningLocally) { - this.attemptId = attemptId; - this.partitionId = partitionId; - this.runningLocally = runningLocally; - this.stageId = stageId; - this.taskMetrics = TaskMetrics.empty(); - } - + * Contextual information about a task which can be read or mutated during + * execution. To access the TaskContext for a running task use + * TaskContext.get(). + */ +public abstract class TaskContext implements Serializable { /** - * :: DeveloperApi :: - * Contextual information about a task which can be read or mutated during execution. - * - * @param stageId stage id - * @param partitionId index of the partition - * @param attemptId the number of attempts to execute this task + * Return the currently active TaskContext. This can be called inside of + * user functions to access contextual information about running tasks. */ - @DeveloperApi - public TaskContext(int stageId, int partitionId, long attemptId) { - this.attemptId = attemptId; - this.partitionId = partitionId; - this.runningLocally = false; - this.stageId = stageId; - this.taskMetrics = TaskMetrics.empty(); + public static TaskContext get() { + return taskContext.get(); } private static ThreadLocal taskContext = new ThreadLocal(); - /** - * :: Internal API :: - * This is spark internal API, not intended to be called from user programs. - */ - public static void setTaskContext(TaskContext tc) { + static void setTaskContext(TaskContext tc) { taskContext.set(tc); } - public static TaskContext get() { - return taskContext.get(); - } - - /** :: Internal API :: */ - public static void unset() { + static void unset() { taskContext.remove(); } - // List of callback functions to execute when the task completes. - private transient List onCompleteCallbacks = - new ArrayList(); - - // Whether the corresponding task has been killed. - private volatile boolean interrupted = false; - - // Whether the task has completed. - private volatile boolean completed = false; - /** - * Checks whether the task has completed. + * Whether the task has completed. */ - public boolean isCompleted() { - return completed; - } + public abstract boolean isCompleted(); /** - * Checks whether the task has been killed. + * Whether the task has been killed. */ - public boolean isInterrupted() { - return interrupted; - } + public abstract boolean isInterrupted(); + + /** @deprecated: use isRunningLocally() */ + @Deprecated + public abstract boolean runningLocally(); + + public abstract boolean isRunningLocally(); /** * Add a (Java friendly) listener to be executed on task completion. @@ -150,10 +74,7 @@ public boolean isInterrupted() { *

* An example use is for HadoopRDD to register a callback to close the input stream. */ - public TaskContext addTaskCompletionListener(TaskCompletionListener listener) { - onCompleteCallbacks.add(listener); - return this; - } + public abstract TaskContext addTaskCompletionListener(TaskCompletionListener listener); /** * Add a listener in the form of a Scala closure to be executed on task completion. @@ -161,109 +82,27 @@ public TaskContext addTaskCompletionListener(TaskCompletionListener listener) { *

* An example use is for HadoopRDD to register a callback to close the input stream. */ - public TaskContext addTaskCompletionListener(final Function1 f) { - onCompleteCallbacks.add(new TaskCompletionListener() { - @Override - public void onTaskCompletion(TaskContext context) { - f.apply(context); - } - }); - return this; - } + public abstract TaskContext addTaskCompletionListener(final Function1 f); /** * Add a callback function to be executed on task completion. An example use * is for HadoopRDD to register a callback to close the input stream. * Will be called in any situation - success, failure, or cancellation. * - * Deprecated: use addTaskCompletionListener - * + * @deprecated: use addTaskCompletionListener + * * @param f Callback function. */ @Deprecated - public void addOnCompleteCallback(final Function0 f) { - onCompleteCallbacks.add(new TaskCompletionListener() { - @Override - public void onTaskCompletion(TaskContext context) { - f.apply(); - } - }); - } - - /** - * ::Internal API:: - * Marks the task as completed and triggers the listeners. - */ - public void markTaskCompleted() throws TaskCompletionListenerException { - completed = true; - List errorMsgs = new ArrayList(2); - // Process complete callbacks in the reverse order of registration - List revlist = - new ArrayList(onCompleteCallbacks); - Collections.reverse(revlist); - for (TaskCompletionListener tcl: revlist) { - try { - tcl.onTaskCompletion(this); - } catch (Throwable e) { - errorMsgs.add(e.getMessage()); - } - } - - if (!errorMsgs.isEmpty()) { - throw new TaskCompletionListenerException(JavaConversions.asScalaBuffer(errorMsgs)); - } - } - - /** - * ::Internal API:: - * Marks the task for interruption, i.e. cancellation. - */ - public void markInterrupted() { - interrupted = true; - } - - @Deprecated - /** Deprecated: use getStageId() */ - public int stageId() { - return stageId; - } - - @Deprecated - /** Deprecated: use getPartitionId() */ - public int partitionId() { - return partitionId; - } - - @Deprecated - /** Deprecated: use getAttemptId() */ - public long attemptId() { - return attemptId; - } - - @Deprecated - /** Deprecated: use isRunningLocally() */ - public boolean runningLocally() { - return runningLocally; - } - - public boolean isRunningLocally() { - return runningLocally; - } + public abstract void addOnCompleteCallback(final Function0 f); - public int getStageId() { - return stageId; - } + public abstract int stageId(); - public int getPartitionId() { - return partitionId; - } + public abstract int partitionId(); - public long getAttemptId() { - return attemptId; - } + public abstract long attemptId(); - /** ::Internal API:: */ - public TaskMetrics taskMetrics() { - return taskMetrics; - } + /** ::DeveloperApi:: */ + @DeveloperApi + public abstract TaskMetrics taskMetrics(); } diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 396cdd1247e07..dd3157990ef2d 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -21,6 +21,7 @@ import scala.language.implicitConversions import java.io._ import java.net.URI +import java.util.Arrays import java.util.concurrent.atomic.AtomicInteger import java.util.{Properties, UUID} import java.util.UUID.randomUUID @@ -237,7 +238,6 @@ class SparkContext(config: SparkConf) extends Logging { // For tests, do not enable the UI None } - ui.foreach(_.bind()) /** A default Hadoop Configuration for the Hadoop code (e.g. 
file systems) that we reuse. */ val hadoopConfiguration = SparkHadoopUtil.get.newConfiguration(conf) @@ -341,6 +341,10 @@ class SparkContext(config: SparkConf) extends Logging { postEnvironmentUpdate() postApplicationStart() + // Bind the SparkUI after starting the task scheduler + // because certain pages and listeners depend on it + ui.foreach(_.bind()) + private[spark] var checkpointDir: Option[String] = None // Thread Local variable that can be used by users to pass information down the stack @@ -814,6 +818,8 @@ class SparkContext(config: SparkConf) extends Logging { */ def broadcast[T: ClassTag](value: T): Broadcast[T] = { val bc = env.broadcastManager.newBroadcast[T](value, isLocal) + val callSite = getCallSite + logInfo("Created broadcast " + bc.id + " from " + callSite.shortForm) cleaner.foreach(_.registerBroadcastForCleanup(bc)) bc } @@ -1429,7 +1435,10 @@ object SparkContext extends Logging { simpleWritableConverter[Boolean, BooleanWritable](_.get) implicit def bytesWritableConverter(): WritableConverter[Array[Byte]] = { - simpleWritableConverter[Array[Byte], BytesWritable](_.getBytes) + simpleWritableConverter[Array[Byte], BytesWritable](bw => + // getBytes method returns array which is longer then data to be returned + Arrays.copyOfRange(bw.getBytes, 0, bw.getLength) + ) } implicit def stringWritableConverter(): WritableConverter[String] = diff --git a/core/src/main/scala/org/apache/spark/TaskContextHelper.scala b/core/src/main/scala/org/apache/spark/TaskContextHelper.scala new file mode 100644 index 0000000000000..4636c4600a01a --- /dev/null +++ b/core/src/main/scala/org/apache/spark/TaskContextHelper.scala @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +/** + * This class exists to restrict the visibility of TaskContext setters. + */ +private [spark] object TaskContextHelper { + + def setTaskContext(tc: TaskContext): Unit = TaskContext.setTaskContext(tc) + + def unset(): Unit = TaskContext.unset() + +} diff --git a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala new file mode 100644 index 0000000000000..afd2b85d33a77 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +import org.apache.spark.executor.TaskMetrics +import org.apache.spark.util.{TaskCompletionListener, TaskCompletionListenerException} + +import scala.collection.mutable.ArrayBuffer + +private[spark] class TaskContextImpl(val stageId: Int, + val partitionId: Int, + val attemptId: Long, + val runningLocally: Boolean = false, + val taskMetrics: TaskMetrics = TaskMetrics.empty) + extends TaskContext + with Logging { + + // List of callback functions to execute when the task completes. + @transient private val onCompleteCallbacks = new ArrayBuffer[TaskCompletionListener] + + // Whether the corresponding task has been killed. + @volatile private var interrupted: Boolean = false + + // Whether the task has completed. + @volatile private var completed: Boolean = false + + override def addTaskCompletionListener(listener: TaskCompletionListener): this.type = { + onCompleteCallbacks += listener + this + } + + override def addTaskCompletionListener(f: TaskContext => Unit): this.type = { + onCompleteCallbacks += new TaskCompletionListener { + override def onTaskCompletion(context: TaskContext): Unit = f(context) + } + this + } + + @deprecated("use addTaskCompletionListener", "1.1.0") + override def addOnCompleteCallback(f: () => Unit) { + onCompleteCallbacks += new TaskCompletionListener { + override def onTaskCompletion(context: TaskContext): Unit = f() + } + } + + /** Marks the task as completed and triggers the listeners. */ + private[spark] def markTaskCompleted(): Unit = { + completed = true + val errorMsgs = new ArrayBuffer[String](2) + // Process complete callbacks in the reverse order of registration + onCompleteCallbacks.reverse.foreach { listener => + try { + listener.onTaskCompletion(this) + } catch { + case e: Throwable => + errorMsgs += e.getMessage + logError("Error in TaskCompletionListener", e) + } + } + if (errorMsgs.nonEmpty) { + throw new TaskCompletionListenerException(errorMsgs) + } + } + + /** Marks the task for interruption, i.e. cancellation. 
*/ + private[spark] def markInterrupted(): Unit = { + interrupted = true + } + + override def isCompleted: Boolean = completed + + override def isRunningLocally: Boolean = runningLocally + + override def isInterrupted: Boolean = interrupted +} + diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index c74f86548ef85..29ca751519abd 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -23,10 +23,9 @@ import java.nio.charset.Charset import java.util.{List => JList, ArrayList => JArrayList, Map => JMap, Collections} import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.collection.mutable import scala.language.existentials -import scala.reflect.ClassTag -import scala.util.{Try, Success, Failure} import net.razorvine.pickle.{Pickler, Unpickler} @@ -42,7 +41,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.util.Utils private[spark] class PythonRDD( - parent: RDD[_], + @transient parent: RDD[_], command: Array[Byte], envVars: JMap[String, String], pythonIncludes: JList[String], @@ -55,9 +54,9 @@ private[spark] class PythonRDD( val bufferSize = conf.getInt("spark.buffer.size", 65536) val reuse_worker = conf.getBoolean("spark.python.worker.reuse", true) - override def getPartitions = parent.partitions + override def getPartitions = firstParent.partitions - override val partitioner = if (preservePartitoning) parent.partitioner else None + override val partitioner = if (preservePartitoning) firstParent.partitioner else None override def compute(split: Partition, context: TaskContext): Iterator[Array[Byte]] = { val startTime = System.currentTimeMillis @@ -234,7 +233,7 @@ private[spark] class PythonRDD( dataOut.writeInt(command.length) dataOut.write(command) // Data values - PythonRDD.writeIteratorToStream(parent.iterator(split, context), dataOut) + PythonRDD.writeIteratorToStream(firstParent.iterator(split, context), dataOut) dataOut.writeInt(SpecialLengths.END_OF_DATA_SECTION) dataOut.flush() } catch { @@ -748,6 +747,7 @@ private[spark] object PythonRDD extends Logging { def pythonToJavaMap(pyRDD: JavaRDD[Array[Byte]]): JavaRDD[Map[String, _]] = { pyRDD.rdd.mapPartitions { iter => val unpickle = new Unpickler + SerDeUtil.initialize() iter.flatMap { row => unpickle.loads(row) match { // in case of objects are pickled in batch mode @@ -787,7 +787,7 @@ private[spark] object PythonRDD extends Logging { }.toJavaRDD() } - private class AutoBatchedPickler(iter: Iterator[Any]) extends Iterator[Array[Byte]] { + private[spark] class AutoBatchedPickler(iter: Iterator[Any]) extends Iterator[Array[Byte]] { private val pickle = new Pickler() private var batch = 1 private val buffer = new mutable.ArrayBuffer[Any] @@ -824,11 +824,12 @@ private[spark] object PythonRDD extends Logging { */ def pythonToJava(pyRDD: JavaRDD[Array[Byte]], batched: Boolean): JavaRDD[Any] = { pyRDD.rdd.mapPartitions { iter => + SerDeUtil.initialize() val unpickle = new Unpickler iter.flatMap { row => val obj = unpickle.loads(row) if (batched) { - obj.asInstanceOf[JArrayList[_]] + obj.asInstanceOf[JArrayList[_]].asScala } else { Seq(obj) } diff --git a/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala b/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala index 7903457b17e13..ebdc3533e0992 100644 --- a/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala +++ 
b/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala @@ -29,7 +29,7 @@ import org.apache.spark.{Logging, SparkException} import org.apache.spark.rdd.RDD /** Utilities for serialization / deserialization between Python and Java, using Pickle. */ -private[python] object SerDeUtil extends Logging { +private[spark] object SerDeUtil extends Logging { // Unpickle array.array generated by Python 2.6 class ArrayConstructor extends net.razorvine.pickle.objects.ArrayConstructor { // /* Description of types */ @@ -76,9 +76,18 @@ private[python] object SerDeUtil extends Logging { } } + private var initialized = false + // This should be called before trying to unpickle array.array from Python + // In cluster mode, this should be put in closure def initialize() = { - Unpickler.registerConstructor("array", "array", new ArrayConstructor()) + synchronized{ + if (!initialized) { + Unpickler.registerConstructor("array", "array", new ArrayConstructor()) + initialized = true + } + } } + initialize() private def checkPickle(t: (Any, Any)): (Boolean, Boolean) = { val pickle = new Pickler @@ -143,6 +152,7 @@ private[python] object SerDeUtil extends Logging { obj.asInstanceOf[Array[_]].length == 2 } pyRDD.mapPartitions { iter => + initialize() val unpickle = new Unpickler val unpickled = if (batchSerialized) { diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala index 57b251ff47714..72a452e0aefb5 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala @@ -17,14 +17,11 @@ package org.apache.spark.deploy -import java.io.{File, FileInputStream, IOException} -import java.util.Properties import java.util.jar.JarFile import scala.collection.JavaConversions._ import scala.collection.mutable.{ArrayBuffer, HashMap} -import org.apache.spark.SparkException import org.apache.spark.util.Utils /** @@ -63,9 +60,8 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St val defaultProperties = new HashMap[String, String]() if (verbose) SparkSubmit.printStream.println(s"Using properties file: $propertiesFile") Option(propertiesFile).foreach { filename => - val file = new File(filename) - SparkSubmitArguments.getPropertiesFromFile(file).foreach { case (k, v) => - if (k.startsWith("spark")) { + Utils.getPropertiesFromFile(filename).foreach { case (k, v) => + if (k.startsWith("spark.")) { defaultProperties(k) = v if (verbose) SparkSubmit.printStream.println(s"Adding default property: $k=$v") } else { @@ -90,19 +86,7 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St */ private def mergeSparkProperties(): Unit = { // Use common defaults file, if not specified by user - if (propertiesFile == null) { - val sep = File.separator - val sparkHomeConfig = env.get("SPARK_HOME").map(sparkHome => s"${sparkHome}${sep}conf") - val confDir = env.get("SPARK_CONF_DIR").orElse(sparkHomeConfig) - - confDir.foreach { sparkConfDir => - val defaultPath = s"${sparkConfDir}${sep}spark-defaults.conf" - val file = new File(defaultPath) - if (file.exists()) { - propertiesFile = file.getAbsolutePath - } - } - } + propertiesFile = Option(propertiesFile).getOrElse(Utils.getDefaultPropertiesFile(env)) val properties = HashMap[String, String]() properties.putAll(defaultSparkProperties) @@ -397,23 +381,3 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St 
SparkSubmit.exitFn() } } - -object SparkSubmitArguments { - /** Load properties present in the given file. */ - def getPropertiesFromFile(file: File): Seq[(String, String)] = { - require(file.exists(), s"Properties file $file does not exist") - require(file.isFile(), s"Properties file $file is not a normal file") - val inputStream = new FileInputStream(file) - try { - val properties = new Properties() - properties.load(inputStream) - properties.stringPropertyNames().toSeq.map(k => (k, properties(k).trim)) - } catch { - case e: IOException => - val message = s"Failed when loading Spark properties file $file" - throw new SparkException(message, e) - } finally { - inputStream.close() - } - } -} diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitDriverBootstrapper.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitDriverBootstrapper.scala index a64170a47bc1c..0125330589da5 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitDriverBootstrapper.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitDriverBootstrapper.scala @@ -68,7 +68,7 @@ private[spark] object SparkSubmitDriverBootstrapper { assume(bootstrapDriver != null, "SPARK_SUBMIT_BOOTSTRAP_DRIVER must be set") // Parse the properties file for the equivalent spark.driver.* configs - val properties = SparkSubmitArguments.getPropertiesFromFile(new File(propertiesFile)).toMap + val properties = Utils.getPropertiesFromFile(propertiesFile) val confDriverMemory = properties.get("spark.driver.memory") val confLibraryPath = properties.get("spark.driver.extraLibraryPath") val confClasspath = properties.get("spark.driver.extraClassPath") diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerArguments.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerArguments.scala index 25fc76c23e0fb..5bce32a04d16d 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerArguments.scala @@ -18,12 +18,14 @@ package org.apache.spark.deploy.history import org.apache.spark.SparkConf +import org.apache.spark.util.Utils /** * Command-line parser for the master. */ private[spark] class HistoryServerArguments(conf: SparkConf, args: Array[String]) { private var logDir: String = null + private var propertiesFile: String = null parse(args.toList) @@ -32,11 +34,16 @@ private[spark] class HistoryServerArguments(conf: SparkConf, args: Array[String] case ("--dir" | "-d") :: value :: tail => logDir = value conf.set("spark.history.fs.logDirectory", value) + System.setProperty("spark.history.fs.logDirectory", value) parse(tail) case ("--help" | "-h") :: tail => printUsageAndExit(0) + case ("--properties-file") :: value :: tail => + propertiesFile = value + parse(tail) + case Nil => case _ => @@ -44,10 +51,17 @@ private[spark] class HistoryServerArguments(conf: SparkConf, args: Array[String] } } + // This mutates the SparkConf, so all accesses to it must be made after this line + Utils.loadDefaultSparkProperties(conf, propertiesFile) + private def printUsageAndExit(exitCode: Int) { System.err.println( """ - |Usage: HistoryServer + |Usage: HistoryServer [options] + | + |Options: + | --properties-file FILE Path to a custom Spark properties file. + | Default is conf/spark-defaults.conf. | |Configuration options can be set by setting the corresponding JVM system property. |History Server options are always available; additional options depend on the provider. 
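For context on the new --properties-file option handled above: a minimal, illustrative Scala sketch of how a daemon argument parser can load defaults through the Utils.loadDefaultSparkProperties helper introduced later in this patch. This is not part of the patch itself; the object name and the println are hypothetical, and the code assumes it lives under the org.apache.spark package so that the private[spark] Utils object is visible.

    package org.apache.spark.deploy.history

    import org.apache.spark.SparkConf
    import org.apache.spark.util.Utils

    // Hypothetical example, mirroring how HistoryServerArguments consumes --properties-file.
    object PropertiesFileExample {
      def main(args: Array[String]): Unit = {
        val conf = new SparkConf()
        // A null path falls back to spark-defaults.conf under SPARK_CONF_DIR or SPARK_HOME/conf.
        val propertiesFile: String = if (args.nonEmpty) args(0) else null
        // Copies any "spark.*" keys from that file into `conf` (and into system properties)
        // when they are not already set, then returns the path that was actually loaded.
        val loadedFrom = Utils.loadDefaultSparkProperties(conf, propertiesFile)
        println(s"Loaded defaults from: $loadedFrom, " +
          s"logDir=${conf.getOption("spark.history.fs.logDirectory")}")
      }
    }

As in the patched HistoryServerArguments, the call mutates the SparkConf, so any reads of the affected settings should happen only after it returns.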
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/MasterArguments.scala b/core/src/main/scala/org/apache/spark/deploy/master/MasterArguments.scala index 4b0dbbe543d3f..e34bee7854292 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/MasterArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/MasterArguments.scala @@ -27,6 +27,7 @@ private[spark] class MasterArguments(args: Array[String], conf: SparkConf) { var host = Utils.localHostName() var port = 7077 var webUiPort = 8080 + var propertiesFile: String = null // Check for settings in environment variables if (System.getenv("SPARK_MASTER_HOST") != null) { @@ -38,12 +39,16 @@ private[spark] class MasterArguments(args: Array[String], conf: SparkConf) { if (System.getenv("SPARK_MASTER_WEBUI_PORT") != null) { webUiPort = System.getenv("SPARK_MASTER_WEBUI_PORT").toInt } + + parse(args.toList) + + // This mutates the SparkConf, so all accesses to it must be made after this line + propertiesFile = Utils.loadDefaultSparkProperties(conf, propertiesFile) + if (conf.contains("spark.master.ui.port")) { webUiPort = conf.get("spark.master.ui.port").toInt } - parse(args.toList) - def parse(args: List[String]): Unit = args match { case ("--ip" | "-i") :: value :: tail => Utils.checkHost(value, "ip no longer supported, please use hostname " + value) @@ -63,7 +68,11 @@ private[spark] class MasterArguments(args: Array[String], conf: SparkConf) { webUiPort = value parse(tail) - case ("--help" | "-h") :: tail => + case ("--properties-file") :: value :: tail => + propertiesFile = value + parse(tail) + + case ("--help") :: tail => printUsageAndExit(0) case Nil => {} @@ -83,7 +92,9 @@ private[spark] class MasterArguments(args: Array[String], conf: SparkConf) { " -i HOST, --ip HOST Hostname to listen on (deprecated, please use --host or -h) \n" + " -h HOST, --host HOST Hostname to listen on\n" + " -p PORT, --port PORT Port to listen on (default: 7077)\n" + - " --webui-port PORT Port for web UI (default: 8080)") + " --webui-port PORT Port for web UI (default: 8080)\n" + + " --properties-file FILE Path to a custom Spark properties file.\n" + + " Default is conf/spark-defaults.conf.") System.exit(exitCode) } } diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala index 71650cd773bcf..71d7385b08eb9 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala @@ -111,13 +111,14 @@ private[spark] class ExecutorRunner( case "{{EXECUTOR_ID}}" => execId.toString case "{{HOSTNAME}}" => host case "{{CORES}}" => cores.toString + case "{{APP_ID}}" => appId case other => other } def getCommandSeq = { val command = Command( appDesc.command.mainClass, - appDesc.command.arguments.map(substituteVariables) ++ Seq(appId), + appDesc.command.arguments.map(substituteVariables), appDesc.command.environment, appDesc.command.classPathEntries, appDesc.command.libraryPathEntries, diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala index 1e295aaa48c30..019cd70f2a229 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala @@ -33,6 +33,7 @@ private[spark] class WorkerArguments(args: Array[String], conf: SparkConf) { var memory = 
inferDefaultMemory() var masters: Array[String] = null var workDir: String = null + var propertiesFile: String = null // Check for settings in environment variables if (System.getenv("SPARK_WORKER_PORT") != null) { @@ -41,21 +42,27 @@ private[spark] class WorkerArguments(args: Array[String], conf: SparkConf) { if (System.getenv("SPARK_WORKER_CORES") != null) { cores = System.getenv("SPARK_WORKER_CORES").toInt } - if (System.getenv("SPARK_WORKER_MEMORY") != null) { - memory = Utils.memoryStringToMb(System.getenv("SPARK_WORKER_MEMORY")) + if (conf.getenv("SPARK_WORKER_MEMORY") != null) { + memory = Utils.memoryStringToMb(conf.getenv("SPARK_WORKER_MEMORY")) } if (System.getenv("SPARK_WORKER_WEBUI_PORT") != null) { webUiPort = System.getenv("SPARK_WORKER_WEBUI_PORT").toInt } - if (conf.contains("spark.worker.ui.port")) { - webUiPort = conf.get("spark.worker.ui.port").toInt - } if (System.getenv("SPARK_WORKER_DIR") != null) { workDir = System.getenv("SPARK_WORKER_DIR") } parse(args.toList) + // This mutates the SparkConf, so all accesses to it must be made after this line + propertiesFile = Utils.loadDefaultSparkProperties(conf, propertiesFile) + + if (conf.contains("spark.worker.ui.port")) { + webUiPort = conf.get("spark.worker.ui.port").toInt + } + + checkWorkerMemory() + def parse(args: List[String]): Unit = args match { case ("--ip" | "-i") :: value :: tail => Utils.checkHost(value, "ip no longer supported, please use hostname " + value) @@ -87,7 +94,11 @@ private[spark] class WorkerArguments(args: Array[String], conf: SparkConf) { webUiPort = value parse(tail) - case ("--help" | "-h") :: tail => + case ("--properties-file") :: value :: tail => + propertiesFile = value + parse(tail) + + case ("--help") :: tail => printUsageAndExit(0) case value :: tail => @@ -122,7 +133,9 @@ private[spark] class WorkerArguments(args: Array[String], conf: SparkConf) { " -i HOST, --ip IP Hostname to listen on (deprecated, please use --host or -h)\n" + " -h HOST, --host HOST Hostname to listen on\n" + " -p PORT, --port PORT Port to listen on (default: random)\n" + - " --webui-port PORT Port for web UI (default: 8081)") + " --webui-port PORT Port for web UI (default: 8081)\n" + + " --properties-file FILE Path to a custom Spark properties file.\n" + + " Default is conf/spark-defaults.conf.") System.exit(exitCode) } @@ -153,4 +166,11 @@ private[spark] class WorkerArguments(args: Array[String], conf: SparkConf) { // Leave out 1 GB for the operating system, but don't return a negative memory size math.max(totalMb - 1024, 512) } + + def checkWorkerMemory(): Unit = { + if (memory <= 0) { + val message = "Memory can't be 0, missing a M or G on the end of the memory specification?" + throw new IllegalStateException(message) + } + } } diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala index 06061edfc0844..c40a3e16675ad 100644 --- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala @@ -152,6 +152,9 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging { "Usage: CoarseGrainedExecutorBackend " + " [] ") System.exit(1) + + // NB: These arguments are provided by SparkDeploySchedulerBackend (for standalone mode) + // and CoarseMesosSchedulerBackend (for mesos mode). 
case 5 => run(args(0), args(1), args(2), args(3).toInt, args(4), None) case x if x > 5 => diff --git a/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala b/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala index 6b00190c5eccc..bda4bf50932c3 100644 --- a/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala @@ -117,7 +117,16 @@ private[nio] class ConnectionManager( conf.getInt("spark.core.connection.connect.threads.max", 8), conf.getInt("spark.core.connection.connect.threads.keepalive", 60), TimeUnit.SECONDS, new LinkedBlockingDeque[Runnable](), - Utils.namedThreadFactory("handle-connect-executor")) + Utils.namedThreadFactory("handle-connect-executor")) { + + override def afterExecute(r: Runnable, t: Throwable): Unit = { + super.afterExecute(r, t) + if (t != null && NonFatal(t)) { + logError("Error in handleConnectExecutor is not handled properly", t) + } + } + + } private val serverChannel = ServerSocketChannel.open() // used to track the SendingConnections waiting to do SASL negotiation @@ -748,9 +757,7 @@ private[nio] class ConnectionManager( } catch { case e: Exception => { logError(s"Exception was thrown while processing message", e) - val m = Message.createBufferMessage(bufferMessage.id) - m.hasError = true - ackMessage = Some(m) + ackMessage = Some(Message.createErrorMessage(e, bufferMessage.id)) } } finally { sendMessage(connectionManagerId, ackMessage.getOrElse { @@ -913,8 +920,12 @@ private[nio] class ConnectionManager( } case scala.util.Success(ackMessage) => if (ackMessage.hasError) { + val errorMsgByteBuf = ackMessage.asInstanceOf[BufferMessage].buffers.head + val errorMsgBytes = new Array[Byte](errorMsgByteBuf.limit()) + errorMsgByteBuf.get(errorMsgBytes) + val errorMsg = new String(errorMsgBytes, "utf-8") val e = new IOException( - "sendMessageReliably failed with ACK that signalled a remote error") + s"sendMessageReliably failed with ACK that signalled a remote error: $errorMsg") if (!promise.tryFailure(e)) { logWarning("Ignore error because promise is completed", e) } diff --git a/core/src/main/scala/org/apache/spark/network/nio/Message.scala b/core/src/main/scala/org/apache/spark/network/nio/Message.scala index 0b874c2891255..3ad04591da658 100644 --- a/core/src/main/scala/org/apache/spark/network/nio/Message.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/Message.scala @@ -22,6 +22,7 @@ import java.nio.ByteBuffer import scala.collection.mutable.ArrayBuffer +import org.apache.spark.util.Utils private[nio] abstract class Message(val typ: Long, val id: Int) { var senderAddress: InetSocketAddress = null @@ -84,6 +85,19 @@ private[nio] object Message { createBufferMessage(new Array[ByteBuffer](0), ackId) } + /** + * Create a "negative acknowledgment" to notify a sender that an error occurred + * while processing its message. The exception's stacktrace will be formatted + * as a string, serialized into a byte array, and sent as the message payload. 
+ */ + def createErrorMessage(exception: Exception, ackId: Int): BufferMessage = { + val exceptionString = Utils.exceptionString(exception) + val serializedExceptionString = ByteBuffer.wrap(exceptionString.getBytes("utf-8")) + val errorMessage = createBufferMessage(serializedExceptionString, ackId) + errorMessage.hasError = true + errorMessage + } + def create(header: MessageChunkHeader): Message = { val newMessage: Message = header.typ match { case BUFFER_MESSAGE => new BufferMessage(header.id, diff --git a/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala index b389b9a2022c6..5add4fc433fb3 100644 --- a/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala @@ -151,17 +151,14 @@ final class NioBlockTransferService(conf: SparkConf, securityManager: SecurityMa } catch { case e: Exception => { logError("Exception handling buffer message", e) - val errorMessage = Message.createBufferMessage(msg.id) - errorMessage.hasError = true - Some(errorMessage) + Some(Message.createErrorMessage(e, msg.id)) } } case otherMessage: Any => - logError("Unknown type message received: " + otherMessage) - val errorMessage = Message.createBufferMessage(msg.id) - errorMessage.hasError = true - Some(errorMessage) + val errorMsg = s"Received unknown message type: ${otherMessage.getClass.getName}" + logError(errorMsg) + Some(Message.createErrorMessage(new UnsupportedOperationException(errorMsg), msg.id)) } } diff --git a/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala b/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala index b62f3fbdc4a15..ede5568493cc0 100644 --- a/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala @@ -78,16 +78,18 @@ class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Loggi // greater than totalParts because we actually cap it at totalParts in runJob. var numPartsToTry = 1 if (partsScanned > 0) { - // If we didn't find any rows after the first iteration, just try all partitions next. + // If we didn't find any rows after the previous iteration, quadruple and retry. // Otherwise, interpolate the number of partitions we need to try, but overestimate it - // by 50%. + // by 50%. We also cap the estimation in the end. 
if (results.size == 0) { - numPartsToTry = totalParts - 1 + numPartsToTry = partsScanned * 4 } else { - numPartsToTry = (1.5 * num * partsScanned / results.size).toInt + // the left side of max is >=1 whenever partsScanned >= 2 + numPartsToTry = Math.max(1, + (1.5 * num * partsScanned / results.size).toInt - partsScanned) + numPartsToTry = Math.min(numPartsToTry, partsScanned * 4) } } - numPartsToTry = math.max(0, numPartsToTry) // guard against negative num of partitions val left = num - results.size val p = partsScanned until math.min(partsScanned + numPartsToTry, totalParts) diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala index 6b63eb23e9ee1..8010dd90082f8 100644 --- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala @@ -196,7 +196,7 @@ class HadoopRDD[K, V]( val jobConf = getJobConf() val inputFormat = getInputFormat(jobConf) HadoopRDD.addLocalConfiguration(new SimpleDateFormat("yyyyMMddHHmm").format(createTime), - context.getStageId, theSplit.index, context.getAttemptId.toInt, jobConf) + context.stageId, theSplit.index, context.attemptId.toInt, jobConf) reader = inputFormat.getRecordReader(split.inputSplit.value, jobConf, Reporter.NULL) // Register an on-task-completion callback to close the input stream. diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala index 0d97506450a7f..ac96de86dd6d4 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala @@ -956,9 +956,9 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) val writeShard = (context: TaskContext, iter: Iterator[(K,V)]) => { // Hadoop wants a 32-bit task attempt ID, so if ours is bigger than Int.MaxValue, roll it // around by taking a mod. We expect that no task will be attempted 2 billion times. - val attemptNumber = (context.getAttemptId % Int.MaxValue).toInt + val attemptNumber = (context.attemptId % Int.MaxValue).toInt /* "reduce task" */ - val attemptId = newTaskAttemptID(jobtrackerID, stageId, isMap = false, context.getPartitionId, + val attemptId = newTaskAttemptID(jobtrackerID, stageId, isMap = false, context.partitionId, attemptNumber) val hadoopContext = newTaskAttemptContext(wrappedConf.value, attemptId) val format = outfmt.newInstance @@ -1027,15 +1027,13 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) val writeToFile = (context: TaskContext, iter: Iterator[(K, V)]) => { // Hadoop wants a 32-bit task attempt ID, so if ours is bigger than Int.MaxValue, roll it // around by taking a mod. We expect that no task will be attempted 2 billion times. 
- val attemptNumber = (context.getAttemptId % Int.MaxValue).toInt + val attemptNumber = (context.attemptId % Int.MaxValue).toInt - writer.setup(context.getStageId, context.getPartitionId, attemptNumber) + writer.setup(context.stageId, context.partitionId, attemptNumber) writer.open() try { - var count = 0 while (iter.hasNext) { val record = iter.next() - count += 1 writer.write(record._1.asInstanceOf[AnyRef], record._2.asInstanceOf[AnyRef]) } } finally { diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index 2aba40d152e3e..71cabf61d4ee0 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -1079,15 +1079,17 @@ abstract class RDD[T: ClassTag]( // greater than totalParts because we actually cap it at totalParts in runJob. var numPartsToTry = 1 if (partsScanned > 0) { - // If we didn't find any rows after the previous iteration, quadruple and retry. Otherwise, + // If we didn't find any rows after the previous iteration, quadruple and retry. Otherwise, // interpolate the number of partitions we need to try, but overestimate it by 50%. + // We also cap the estimation in the end. if (buf.size == 0) { numPartsToTry = partsScanned * 4 } else { - numPartsToTry = (1.5 * num * partsScanned / buf.size).toInt + // the left side of max is >=1 whenever partsScanned >= 2 + numPartsToTry = Math.max((1.5 * num * partsScanned / buf.size).toInt - partsScanned, 1) + numPartsToTry = Math.min(numPartsToTry, partsScanned * 4) } } - numPartsToTry = math.max(0, numPartsToTry) // guard against negative num of partitions val left = num - buf.size val p = partsScanned until math.min(partsScanned + numPartsToTry, totalParts) diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 788eb1ff4e455..f81fa6d8089fc 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -633,14 +633,14 @@ class DAGScheduler( val rdd = job.finalStage.rdd val split = rdd.partitions(job.partitions(0)) val taskContext = - new TaskContext(job.finalStage.id, job.partitions(0), 0, true) - TaskContext.setTaskContext(taskContext) + new TaskContextImpl(job.finalStage.id, job.partitions(0), 0, true) + TaskContextHelper.setTaskContext(taskContext) try { val result = job.func(taskContext, rdd.iterator(split, taskContext)) job.listener.taskSucceeded(0, result) } finally { taskContext.markTaskCompleted() - TaskContext.unset() + TaskContextHelper.unset() } } catch { case e: Exception => diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala index c6e47c84a0cb2..2552d03d18d06 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala @@ -22,7 +22,7 @@ import java.nio.ByteBuffer import scala.collection.mutable.HashMap -import org.apache.spark.TaskContext +import org.apache.spark.{TaskContextHelper, TaskContextImpl, TaskContext} import org.apache.spark.executor.TaskMetrics import org.apache.spark.serializer.SerializerInstance import org.apache.spark.util.ByteBufferInputStream @@ -45,8 +45,8 @@ import org.apache.spark.util.Utils private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) extends Serializable { final def run(attemptId: Long): T = { - context = new 
TaskContext(stageId, partitionId, attemptId, false) - TaskContext.setTaskContext(context) + context = new TaskContextImpl(stageId, partitionId, attemptId, false) + TaskContextHelper.setTaskContext(context) context.taskMetrics.hostname = Utils.localHostName() taskThread = Thread.currentThread() if (_killed) { @@ -56,7 +56,7 @@ private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) ex runTask(context) } finally { context.markTaskCompleted() - TaskContext.unset() + TaskContextHelper.unset() } } @@ -70,7 +70,7 @@ private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) ex var metrics: Option[TaskMetrics] = None // Task context, to be initialized in run(). - @transient protected var context: TaskContext = _ + @transient protected var context: TaskContextImpl = _ // The actual Thread on which the task is running, if any. Initialized in run(). @volatile @transient private var taskThread: Thread = _ diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala index ed209d195ec9d..8c7de75600b5f 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala @@ -51,7 +51,8 @@ private[spark] class SparkDeploySchedulerBackend( conf.get("spark.driver.host"), conf.get("spark.driver.port"), CoarseGrainedSchedulerBackend.ACTOR_NAME) - val args = Seq(driverUrl, "{{EXECUTOR_ID}}", "{{HOSTNAME}}", "{{CORES}}", "{{WORKER_URL}}") + val args = Seq(driverUrl, "{{EXECUTOR_ID}}", "{{HOSTNAME}}", "{{CORES}}", "{{APP_ID}}", + "{{WORKER_URL}}") val extraJavaOpts = sc.conf.getOption("spark.executor.extraJavaOptions") .map(Utils.splitCommandString).getOrElse(Seq.empty) val classPathEntries = sc.conf.getOption("spark.executor.extraClassPath").toSeq.flatMap { cp => diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala index 90828578cd88f..d7f88de4b40aa 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala @@ -150,17 +150,17 @@ private[spark] class CoarseMesosSchedulerBackend( if (uri == null) { val runScript = new File(executorSparkHome, "./bin/spark-class").getCanonicalPath command.setValue( - "\"%s\" org.apache.spark.executor.CoarseGrainedExecutorBackend %s %s %s %d".format( - runScript, driverUrl, offer.getSlaveId.getValue, offer.getHostname, numCores)) + "\"%s\" org.apache.spark.executor.CoarseGrainedExecutorBackend %s %s %s %d %s".format( + runScript, driverUrl, offer.getSlaveId.getValue, offer.getHostname, numCores, appId)) } else { // Grab everything to the first '.'. We'll use that and '*' to // glob the directory "correctly". 
val basename = uri.split('/').last.split('.').head command.setValue( ("cd %s*; " + - "./bin/spark-class org.apache.spark.executor.CoarseGrainedExecutorBackend %s %s %s %d") + "./bin/spark-class org.apache.spark.executor.CoarseGrainedExecutorBackend %s %s %s %d %s") .format(basename, driverUrl, offer.getSlaveId.getValue, - offer.getHostname, numCores)) + offer.getHostname, numCores, appId)) command.addUris(CommandInfo.URI.newBuilder().setValue(uri)) } command.build() diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala index 6a06257ed0c08..088f06e389d83 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala @@ -457,16 +457,18 @@ private[spark] class BlockManagerInfo( if (_blocks.containsKey(blockId)) { // The block exists on the slave already. - val originalLevel: StorageLevel = _blocks.get(blockId).storageLevel + val blockStatus: BlockStatus = _blocks.get(blockId) + val originalLevel: StorageLevel = blockStatus.storageLevel + val originalMemSize: Long = blockStatus.memSize if (originalLevel.useMemory) { - _remainingMem += memSize + _remainingMem += originalMemSize } } if (storageLevel.isValid) { /* isValid means it is either stored in-memory, on-disk or on-Tachyon. - * But the memSize here indicates the data size in or dropped from memory, + * The memSize here indicates the data size in or dropped from memory, * tachyonSize here indicates the data size in or dropped from Tachyon, * and the diskSize here indicates the data size in or dropped to disk. * They can be both larger than 0, when a block is dropped from memory to disk. 
@@ -493,7 +495,6 @@ private[spark] class BlockManagerInfo( val blockStatus: BlockStatus = _blocks.get(blockId) _blocks.remove(blockId) if (blockStatus.storageLevel.useMemory) { - _remainingMem += blockStatus.memSize logInfo("Removed %s on %s in memory (size: %s, free: %s)".format( blockId, blockManagerId.hostPort, Utils.bytesToString(blockStatus.memSize), Utils.bytesToString(_remainingMem))) diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala index 2987dc04494a5..f0e43fbf70976 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala @@ -71,19 +71,19 @@ private[ui] class ExecutorTable(stageId: Int, stageAttemptId: Int, parent: JobPr {k} {executorIdToAddress.getOrElse(k, "CANNOT FIND ADDRESS")} - {UIUtils.formatDuration(v.taskTime)} + {UIUtils.formatDuration(v.taskTime)} {v.failedTasks + v.succeededTasks} {v.failedTasks} {v.succeededTasks} - + {Utils.bytesToString(v.inputBytes)} - + {Utils.bytesToString(v.shuffleRead)} - + {Utils.bytesToString(v.shuffleWrite)} - + {Utils.bytesToString(v.memoryBytesSpilled)} - + {Utils.bytesToString(v.diskBytesSpilled)} } diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala index 2e67310594784..4ee7f08ab47a2 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala @@ -176,9 +176,9 @@ private[ui] class StageTableBase( {makeProgressBar(stageData.numActiveTasks, stageData.completedIndices.size, stageData.numFailedTasks, s.numTasks)} - {inputReadWithUnit} - {shuffleReadWithUnit} - {shuffleWriteWithUnit} + {inputReadWithUnit} + {shuffleReadWithUnit} + {shuffleWriteWithUnit} } /** Render an HTML row that represents a stage */ diff --git a/core/src/main/scala/org/apache/spark/ui/storage/StoragePage.scala b/core/src/main/scala/org/apache/spark/ui/storage/StoragePage.scala index 716591c9ed449..83489ca0679ee 100644 --- a/core/src/main/scala/org/apache/spark/ui/storage/StoragePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/storage/StoragePage.scala @@ -58,9 +58,9 @@ private[ui] class StoragePage(parent: StorageTab) extends WebUIPage("") { {rdd.numCachedPartitions} {"%.0f%%".format(rdd.numCachedPartitions * 100.0 / rdd.numPartitions)} - {Utils.bytesToString(rdd.memSize)} - {Utils.bytesToString(rdd.tachyonSize)} - {Utils.bytesToString(rdd.diskSize)} + {Utils.bytesToString(rdd.memSize)} + {Utils.bytesToString(rdd.tachyonSize)} + {Utils.bytesToString(rdd.diskSize)} // scalastyle:on } diff --git a/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala b/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala index e2d32c859bbda..f41c8d0315cb3 100644 --- a/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala +++ b/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala @@ -77,7 +77,7 @@ private[spark] object AkkaUtils extends Logging { val logAkkaConfig = if (conf.getBoolean("spark.akka.logAkkaConfig", false)) "on" else "off" - val akkaHeartBeatPauses = conf.getInt("spark.akka.heartbeat.pauses", 600) + val akkaHeartBeatPauses = conf.getInt("spark.akka.heartbeat.pauses", 6000) val akkaFailureDetector = conf.getDouble("spark.akka.failure-detector.threshold", 300.0) val akkaHeartBeatInterval = conf.getInt("spark.akka.heartbeat.interval", 1000) diff --git 
a/core/src/main/scala/org/apache/spark/util/FileLogger.scala b/core/src/main/scala/org/apache/spark/util/FileLogger.scala index 6d1fc05a15d2c..fdc73f08261a6 100644 --- a/core/src/main/scala/org/apache/spark/util/FileLogger.scala +++ b/core/src/main/scala/org/apache/spark/util/FileLogger.scala @@ -51,12 +51,27 @@ private[spark] class FileLogger( def this( logDir: String, sparkConf: SparkConf, - compress: Boolean = false, - overwrite: Boolean = true) = { + compress: Boolean, + overwrite: Boolean) = { this(logDir, sparkConf, SparkHadoopUtil.get.newConfiguration(sparkConf), compress = compress, overwrite = overwrite) } + def this( + logDir: String, + sparkConf: SparkConf, + compress: Boolean) = { + this(logDir, sparkConf, SparkHadoopUtil.get.newConfiguration(sparkConf), compress = compress, + overwrite = true) + } + + def this( + logDir: String, + sparkConf: SparkConf) = { + this(logDir, sparkConf, SparkHadoopUtil.get.newConfiguration(sparkConf), compress = false, + overwrite = true) + } + private val dateFormat = new ThreadLocal[SimpleDateFormat]() { override def initialValue(): SimpleDateFormat = new SimpleDateFormat("yyyy/MM/dd HH:mm:ss") } diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 07477dd460a4b..53a7512edd852 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -340,7 +340,7 @@ private[spark] object Utils extends Logging { val targetFile = new File(targetDir, filename) val uri = new URI(url) val fileOverwrite = conf.getBoolean("spark.files.overwrite", defaultValue = false) - uri.getScheme match { + Option(uri.getScheme).getOrElse("file") match { case "http" | "https" | "ftp" => logInfo("Fetching " + url + " to " + tempFile) @@ -374,7 +374,7 @@ private[spark] object Utils extends Logging { } } Files.move(tempFile, targetFile) - case "file" | null => + case "file" => // In the case of a local file, copy the local file to the target directory. // Note the difference between uri vs url. val sourceFile = if (uri.isAbsolute) new File(uri) else new File(url) @@ -1368,16 +1368,17 @@ private[spark] object Utils extends Logging { if (uri.getPath == null) { throw new IllegalArgumentException(s"Given path is malformed: $uri") } - uri.getScheme match { - case windowsDrive(d) if windows => + + Option(uri.getScheme) match { + case Some(windowsDrive(d)) if windows => new URI("file:/" + uri.toString.stripPrefix("/")) - case null => + case None => // Preserve fragments for HDFS file name substitution (denoted by "#") // For instance, in "abc.py#xyz.py", "xyz.py" is the name observed by the application val fragment = uri.getFragment val part = new File(uri.getPath).toURI new URI(part.getScheme, part.getPath, fragment) - case _ => + case Some(other) => uri } } @@ -1399,15 +1400,64 @@ private[spark] object Utils extends Logging { } else { paths.split(",").filter { p => val formattedPath = if (windows) formatWindowsPath(p) else p - new URI(formattedPath).getScheme match { + val uri = new URI(formattedPath) + Option(uri.getScheme).getOrElse("file") match { case windowsDrive(d) if windows => false - case "local" | "file" | null => false + case "local" | "file" => false case _ => true } } } } + /** + * Load default Spark properties from the given file. If no file is provided, + * use the common defaults file. This mutates state in the given SparkConf and + * in this JVM's system properties if the config specified in the file is not + * already set. 
Return the path of the properties file used. + */ + def loadDefaultSparkProperties(conf: SparkConf, filePath: String = null): String = { + val path = Option(filePath).getOrElse(getDefaultPropertiesFile()) + Option(path).foreach { confFile => + getPropertiesFromFile(confFile).filter { case (k, v) => + k.startsWith("spark.") + }.foreach { case (k, v) => + conf.setIfMissing(k, v) + sys.props.getOrElseUpdate(k, v) + } + } + path + } + + /** Load properties present in the given file. */ + def getPropertiesFromFile(filename: String): Map[String, String] = { + val file = new File(filename) + require(file.exists(), s"Properties file $file does not exist") + require(file.isFile(), s"Properties file $file is not a normal file") + + val inReader = new InputStreamReader(new FileInputStream(file), "UTF-8") + try { + val properties = new Properties() + properties.load(inReader) + properties.stringPropertyNames().map(k => (k, properties(k).trim)).toMap + } catch { + case e: IOException => + throw new SparkException(s"Failed when loading Spark properties from $filename", e) + } finally { + inReader.close() + } + } + + /** Return the path of the default Spark properties file. */ + def getDefaultPropertiesFile(env: Map[String, String] = sys.env): String = { + env.get("SPARK_CONF_DIR") + .orElse(env.get("SPARK_HOME").map { t => s"$t${File.separator}conf" }) + .map { t => new File(s"$t${File.separator}spark-defaults.conf")} + .filter(_.isFile) + .map(_.getAbsolutePath) + .orNull + } + /** Return a nice string representation of the exception, including the stack trace. */ def exceptionString(e: Exception): String = { if (e == null) "" else exceptionString(getFormattedClassName(e), e.getMessage, e.getStackTrace) diff --git a/core/src/test/java/org/apache/spark/JavaAPISuite.java b/core/src/test/java/org/apache/spark/JavaAPISuite.java index 4a078435447e5..b8fa822ae4bd8 100644 --- a/core/src/test/java/org/apache/spark/JavaAPISuite.java +++ b/core/src/test/java/org/apache/spark/JavaAPISuite.java @@ -776,7 +776,7 @@ public void persist() { @Test public void iterator() { JavaRDD rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5), 2); - TaskContext context = new TaskContext(0, 0, 0L, false, new TaskMetrics()); + TaskContext context = new TaskContextImpl(0, 0, 0L, false, new TaskMetrics()); Assert.assertEquals(1, rdd.iterator(rdd.partitions().get(0), context).next().intValue()); } diff --git a/core/src/test/java/org/apache/spark/util/JavaTaskCompletionListenerImpl.java b/core/src/test/java/org/apache/spark/util/JavaTaskCompletionListenerImpl.java index 0944bf8cd5c71..e9ec700e32e15 100644 --- a/core/src/test/java/org/apache/spark/util/JavaTaskCompletionListenerImpl.java +++ b/core/src/test/java/org/apache/spark/util/JavaTaskCompletionListenerImpl.java @@ -30,8 +30,8 @@ public class JavaTaskCompletionListenerImpl implements TaskCompletionListener { public void onTaskCompletion(TaskContext context) { context.isCompleted(); context.isInterrupted(); - context.getStageId(); - context.getPartitionId(); + context.stageId(); + context.partitionId(); context.isRunningLocally(); context.addTaskCompletionListener(this); } diff --git a/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala b/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala index d735010d7c9d5..c0735f448d193 100644 --- a/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala @@ -66,7 +66,7 @@ class CacheManagerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar // in 
blockManager.put is a losing battle. You have been warned. blockManager = sc.env.blockManager cacheManager = sc.env.cacheManager - val context = new TaskContext(0, 0, 0) + val context = new TaskContextImpl(0, 0, 0) val computeValue = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY) val getValue = blockManager.get(RDDBlockId(rdd.id, split.index)) assert(computeValue.toList === List(1, 2, 3, 4)) @@ -81,7 +81,7 @@ class CacheManagerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar } whenExecuting(blockManager) { - val context = new TaskContext(0, 0, 0) + val context = new TaskContextImpl(0, 0, 0) val value = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY) assert(value.toList === List(5, 6, 7)) } @@ -94,7 +94,7 @@ class CacheManagerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar } whenExecuting(blockManager) { - val context = new TaskContext(0, 0, 0, true) + val context = new TaskContextImpl(0, 0, 0, true) val value = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY) assert(value.toList === List(1, 2, 3, 4)) } @@ -102,7 +102,7 @@ class CacheManagerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar test("verify task metrics updated correctly") { cacheManager = sc.env.cacheManager - val context = new TaskContext(0, 0, 0) + val context = new TaskContextImpl(0, 0, 0) cacheManager.getOrCompute(rdd3, split, context, StorageLevel.MEMORY_ONLY) assert(context.taskMetrics.updatedBlocks.getOrElse(Seq()).size === 2) } diff --git a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala new file mode 100644 index 0000000000000..31edad1c56c73 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark + +import org.scalatest.FunSuite + +import org.apache.hadoop.io.BytesWritable + +class SparkContextSuite extends FunSuite { + //Regression test for SPARK-3121 + test("BytesWritable implicit conversion is correct") { + val bytesWritable = new BytesWritable() + val inputArray = (1 to 10).map(_.toByte).toArray + bytesWritable.set(inputArray, 0, 10) + bytesWritable.set(inputArray, 0, 5) + + val converter = SparkContext.bytesWritableConverter() + val byteArray = converter.convert(bytesWritable) + assert(byteArray.length === 5) + + bytesWritable.set(inputArray, 0, 0) + val byteArray2 = converter.convert(bytesWritable) + assert(byteArray2.length === 0) + } +} diff --git a/core/src/test/scala/org/apache/spark/deploy/worker/ExecutorRunnerTest.scala b/core/src/test/scala/org/apache/spark/deploy/worker/ExecutorRunnerTest.scala index 39ab53cf0b5b1..5e2592e8d2e8d 100644 --- a/core/src/test/scala/org/apache/spark/deploy/worker/ExecutorRunnerTest.scala +++ b/core/src/test/scala/org/apache/spark/deploy/worker/ExecutorRunnerTest.scala @@ -26,14 +26,12 @@ import org.apache.spark.SparkConf class ExecutorRunnerTest extends FunSuite { test("command includes appId") { - def f(s:String) = new File(s) + val appId = "12345-worker321-9876" val sparkHome = sys.props.getOrElse("spark.test.home", fail("spark.test.home is not set!")) val appDesc = new ApplicationDescription("app name", Some(8), 500, - Command("foo", Seq(), Map(), Seq(), Seq(), Seq()), "appUiUrl") - val appId = "12345-worker321-9876" - val er = new ExecutorRunner(appId, 1, appDesc, 8, 500, null, "blah", "worker321", f(sparkHome), - f("ooga"), "blah", new SparkConf, ExecutorState.RUNNING) - + Command("foo", Seq(appId), Map(), Seq(), Seq(), Seq()), "appUiUrl") + val er = new ExecutorRunner(appId, 1, appDesc, 8, 500, null, "blah", "worker321", + new File(sparkHome), new File("ooga"), "blah", new SparkConf, ExecutorState.RUNNING) assert(er.getCommandSeq.last === appId) } } diff --git a/core/src/test/scala/org/apache/spark/deploy/worker/WorkerArgumentsTest.scala b/core/src/test/scala/org/apache/spark/deploy/worker/WorkerArgumentsTest.scala new file mode 100644 index 0000000000000..1a28a9a187cd7 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/deploy/worker/WorkerArgumentsTest.scala @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + + +package org.apache.spark.deploy.worker + +import org.apache.spark.SparkConf +import org.scalatest.FunSuite + + +class WorkerArgumentsTest extends FunSuite { + + test("Memory can't be set to 0 when cmd line args leave off M or G") { + val conf = new SparkConf + val args = Array("-m", "10000", "spark://localhost:0000 ") + intercept[IllegalStateException] { + new WorkerArguments(args, conf) + } + } + + + test("Memory can't be set to 0 when SPARK_WORKER_MEMORY env property leaves off M or G") { + val args = Array("spark://localhost:0000 ") + + class MySparkConf extends SparkConf(false) { + override def getenv(name: String) = { + if (name == "SPARK_WORKER_MEMORY") "50000" + else super.getenv(name) + } + + override def clone: SparkConf = { + new MySparkConf().setAll(settings) + } + } + val conf = new MySparkConf() + intercept[IllegalStateException] { + new WorkerArguments(args, conf) + } + } + + test("Memory correctly set when SPARK_WORKER_MEMORY env property appends G") { + val args = Array("spark://localhost:0000 ") + + class MySparkConf extends SparkConf(false) { + override def getenv(name: String) = { + if (name == "SPARK_WORKER_MEMORY") "5G" + else super.getenv(name) + } + + override def clone: SparkConf = { + new MySparkConf().setAll(settings) + } + } + val conf = new MySparkConf() + val workerArgs = new WorkerArguments(args, conf) + assert(workerArgs.memory === 5120) + } + + test("Memory correctly set from args with M appended to memory value") { + val conf = new SparkConf + val args = Array("-m", "10000M", "spark://localhost:0000 ") + + val workerArgs = new WorkerArguments(args, conf) + assert(workerArgs.memory === 10000) + + } + +} diff --git a/core/src/test/scala/org/apache/spark/network/nio/ConnectionManagerSuite.scala b/core/src/test/scala/org/apache/spark/network/nio/ConnectionManagerSuite.scala index 9f49587cdc670..b70734dfe37cf 100644 --- a/core/src/test/scala/org/apache/spark/network/nio/ConnectionManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/network/nio/ConnectionManagerSuite.scala @@ -27,6 +27,7 @@ import scala.language.postfixOps import org.scalatest.FunSuite import org.apache.spark.{SecurityManager, SparkConf} +import org.apache.spark.util.Utils /** * Test the ConnectionManager with various security settings. 
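The new WorkerArgumentsTest cases above pin down how worker memory strings are interpreted: a value without an explicit unit must fail loudly rather than silently become 0 MB, while "5G" resolves to 5120 MB and "10000M" to 10000 MB. A hypothetical parser that mirrors just that observable behaviour (the real code path goes through Spark's own memory-string utilities, so this is only an illustration) could be:

{% highlight scala %}
// Hypothetical helper, not Spark's Utils.memoryStringToMb: unit-less values
// are rejected instead of being interpreted as bytes and truncated to 0.
def parseWorkerMemoryMb(str: String): Int = {
  val s = str.trim.toLowerCase
  if (s.endsWith("g")) s.dropRight(1).toInt * 1024
  else if (s.endsWith("m")) s.dropRight(1).toInt
  else throw new IllegalStateException(
    s"Memory value '$str' must carry a unit, e.g. 5G or 10000M")
}

assert(parseWorkerMemoryMb("5G") == 5120)
assert(parseWorkerMemoryMb("10000M") == 10000)
{% endhighlight %}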
@@ -236,7 +237,7 @@ class ConnectionManagerSuite extends FunSuite { val manager = new ConnectionManager(0, conf, securityManager) val managerServer = new ConnectionManager(0, conf, securityManager) managerServer.onReceiveMessage((msg: Message, id: ConnectionManagerId) => { - throw new Exception + throw new Exception("Custom exception text") }) val size = 10 * 1024 * 1024 @@ -246,9 +247,10 @@ class ConnectionManagerSuite extends FunSuite { val future = manager.sendMessageReliably(managerServer.id, bufferMessage) - intercept[IOException] { + val exception = intercept[IOException] { Await.result(future, 1 second) } + assert(Utils.exceptionString(exception).contains("Custom exception text")) manager.stop() managerServer.stop() diff --git a/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala index be972c5e97a7e..271a90c6646bb 100644 --- a/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala @@ -174,7 +174,7 @@ class PipedRDDSuite extends FunSuite with SharedSparkContext { } val hadoopPart1 = generateFakeHadoopPartition() val pipedRdd = new PipedRDD(nums, "printenv " + varName) - val tContext = new TaskContext(0, 0, 0) + val tContext = new TaskContextImpl(0, 0, 0) val rddIter = pipedRdd.compute(hadoopPart1, tContext) val arr = rddIter.toArray assert(arr(0) == "/some/path") diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala index faba5508c906c..561a5e9cd90c4 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala @@ -51,7 +51,7 @@ class TaskContextSuite extends FunSuite with BeforeAndAfter with LocalSparkConte } test("all TaskCompletionListeners should be called even if some fail") { - val context = new TaskContext(0, 0, 0) + val context = new TaskContextImpl(0, 0, 0) val listener = mock(classOf[TaskCompletionListener]) context.addTaskCompletionListener(_ => throw new Exception("blah")) context.addTaskCompletionListener(listener) diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala index 809bd70929656..a8c049d749015 100644 --- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.storage -import org.apache.spark.TaskContext +import org.apache.spark.{TaskContextImpl, TaskContext} import org.apache.spark.network.{BlockFetchingListener, BlockTransferService} import org.mockito.Mockito._ @@ -62,7 +62,7 @@ class ShuffleBlockFetcherIteratorSuite extends FunSuite { ) val iterator = new ShuffleBlockFetcherIterator( - new TaskContext(0, 0, 0), + new TaskContextImpl(0, 0, 0), transfer, blockManager, blocksByAddress, @@ -120,7 +120,7 @@ class ShuffleBlockFetcherIteratorSuite extends FunSuite { ) val iterator = new ShuffleBlockFetcherIterator( - new TaskContext(0, 0, 0), + new TaskContextImpl(0, 0, 0), transfer, blockManager, blocksByAddress, @@ -169,7 +169,7 @@ class ShuffleBlockFetcherIteratorSuite extends FunSuite { (bmId, Seq((blId1, 1L), (blId2, 1L)))) val iterator = new ShuffleBlockFetcherIterator( - new TaskContext(0, 0, 0), + new TaskContextImpl(0, 0, 0), 
transfer, blockManager, blocksByAddress, diff --git a/core/src/test/scala/org/apache/spark/util/FileLoggerSuite.scala b/core/src/test/scala/org/apache/spark/util/FileLoggerSuite.scala index dc2a05631d83d..72466a3aa1130 100644 --- a/core/src/test/scala/org/apache/spark/util/FileLoggerSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/FileLoggerSuite.scala @@ -74,13 +74,13 @@ class FileLoggerSuite extends FunSuite with BeforeAndAfter { test("Logging when directory already exists") { // Create the logging directory multiple times - new FileLogger(logDirPathString, new SparkConf, overwrite = true).start() - new FileLogger(logDirPathString, new SparkConf, overwrite = true).start() - new FileLogger(logDirPathString, new SparkConf, overwrite = true).start() + new FileLogger(logDirPathString, new SparkConf, compress = false, overwrite = true).start() + new FileLogger(logDirPathString, new SparkConf, compress = false, overwrite = true).start() + new FileLogger(logDirPathString, new SparkConf, compress = false, overwrite = true).start() // If overwrite is not enabled, an exception should be thrown intercept[IOException] { - new FileLogger(logDirPathString, new SparkConf, overwrite = false).start() + new FileLogger(logDirPathString, new SparkConf, compress = false, overwrite = false).start() } } diff --git a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala index 0344da60dae66..ea7ef0524d1e1 100644 --- a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala @@ -27,6 +27,8 @@ import com.google.common.base.Charsets import com.google.common.io.Files import org.scalatest.FunSuite +import org.apache.spark.SparkConf + class UtilsSuite extends FunSuite { test("bytesToString") { @@ -332,4 +334,21 @@ class UtilsSuite extends FunSuite { assert(!tempFile2.exists()) } + test("loading properties from file") { + val outFile = File.createTempFile("test-load-spark-properties", "test") + try { + System.setProperty("spark.test.fileNameLoadB", "2") + Files.write("spark.test.fileNameLoadA true\n" + + "spark.test.fileNameLoadB 1\n", outFile, Charsets.UTF_8) + val properties = Utils.getPropertiesFromFile(outFile.getAbsolutePath) + properties + .filter { case (k, v) => k.startsWith("spark.")} + .foreach { case (k, v) => sys.props.getOrElseUpdate(k, v)} + val sparkConf = new SparkConf + assert(sparkConf.getBoolean("spark.test.fileNameLoadA", false) === true) + assert(sparkConf.getInt("spark.test.fileNameLoadB", 1) === 2) + } finally { + outFile.delete() + } + } } diff --git a/docs/configuration.md b/docs/configuration.md index f311f0d2a6206..f0204c640bc89 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -161,14 +161,6 @@ Apart from these, the following properties are also available, and may be useful #### Runtime Environment - - - - - @@ -365,7 +357,7 @@ Apart from these, the following properties are also available, and may be useful @@ -725,7 +717,7 @@ Apart from these, the following properties are also available, and may be useful - + @@ -893,7 +885,7 @@ Apart from these, the following properties are also available, and may be useful to wait for before scheduling begins. Specified as a double between 0 and 1. 
Regardless of whether the minimum ratio of resources has been reached, the maximum amount of time it will wait before scheduling begins is controlled by config - spark.scheduler.maxRegisteredResourcesWaitingTime + spark.scheduler.maxRegisteredResourcesWaitingTime. diff --git a/docs/mllib-clustering.md b/docs/mllib-clustering.md index d10bd63746629..7978e934fb36b 100644 --- a/docs/mllib-clustering.md +++ b/docs/mllib-clustering.md @@ -69,7 +69,7 @@ println("Within Set Sum of Squared Errors = " + WSSSE) All of MLlib's methods use Java-friendly types, so you can import and call them there the same way you do in Scala. The only caveat is that the methods take Scala RDD objects, while the Spark Java API uses a separate `JavaRDD` class. You can convert a Java RDD to a Scala one by -calling `.rdd()` on your `JavaRDD` object. A standalone application example +calling `.rdd()` on your `JavaRDD` object. A self-contained application example that is equivalent to the provided example in Scala is given below: {% highlight java %} @@ -113,12 +113,6 @@ public class KMeansExample { } } {% endhighlight %} - -In order to run the above standalone application, follow the instructions -provided in the [Standalone -Applications](quick-start.html#standalone-applications) section of the Spark -quick-start guide. Be sure to also include *spark-mllib* to your build file as -a dependency.
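The clustering page edited above keeps its point that Java callers use the same MLlib entry points as Scala, only unwrapping `JavaRDD` via `.rdd()`. For orientation, a rough Scala-side sketch of the k-means flow that the Java `KMeansExample` mirrors — with an illustrative input path, illustrative parameters, and `sc` assumed to be an existing SparkContext — is:

{% highlight scala %}
import org.apache.spark.mllib.clustering.KMeans
import org.apache.spark.mllib.linalg.Vectors

// Parse whitespace-separated doubles into MLlib vectors (path is illustrative).
val data = sc.textFile("data/mllib/kmeans_data.txt")
  .map(line => Vectors.dense(line.split(' ').map(_.toDouble)))
  .cache()

val clusters = KMeans.train(data, 2, 20)   // k = 2, 20 iterations
println("Within Set Sum of Squared Errors = " + clusters.computeCost(data))
{% endhighlight %}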
@@ -153,3 +147,9 @@ print("Within Set Sum of Squared Error = " + str(WSSSE))
+ +In order to run the above application, follow the instructions +provided in the [Self-Contained Applications](quick-start.html#self-contained-applications) +section of the Spark +Quick Start guide. Be sure to also include *spark-mllib* to your build file as +a dependency. diff --git a/docs/mllib-collaborative-filtering.md b/docs/mllib-collaborative-filtering.md index d5c539db791be..2094963392295 100644 --- a/docs/mllib-collaborative-filtering.md +++ b/docs/mllib-collaborative-filtering.md @@ -110,7 +110,7 @@ val model = ALS.trainImplicit(ratings, rank, numIterations, alpha) All of MLlib's methods use Java-friendly types, so you can import and call them there the same way you do in Scala. The only caveat is that the methods take Scala RDD objects, while the Spark Java API uses a separate `JavaRDD` class. You can convert a Java RDD to a Scala one by -calling `.rdd()` on your `JavaRDD` object. A standalone application example +calling `.rdd()` on your `JavaRDD` object. A self-contained application example that is equivalent to the provided example in Scala is given bellow: {% highlight java %} @@ -184,12 +184,6 @@ public class CollaborativeFiltering { } } {% endhighlight %} - -In order to run the above standalone application, follow the instructions -provided in the [Standalone -Applications](quick-start.html#standalone-applications) section of the Spark -quick-start guide. Be sure to also include *spark-mllib* to your build file as -a dependency.
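As with clustering, the collaborative-filtering note now lives after the Python example; the API being described is `ALS.train` / `ALS.trainImplicit` on an `RDD[Rating]`. A compact Scala sketch of the explicit-feedback variant, with a placeholder file path and hyperparameters and `sc` assumed in scope:

{% highlight scala %}
import org.apache.spark.mllib.recommendation.{ALS, Rating}

// Each line: userId,productId,rating (placeholder path).
val ratings = sc.textFile("data/mllib/als/test.data").map { line =>
  val Array(user, item, rate) = line.split(',')
  Rating(user.toInt, item.toInt, rate.toDouble)
}

val model = ALS.train(ratings, 10, 10, 0.01)   // rank, iterations, lambda
val predictions = model.predict(ratings.map(r => (r.user, r.product)))
{% endhighlight %}

The implicit-feedback variant shown in the surrounding doc text additionally passes an `alpha` confidence parameter to `ALS.trainImplicit`.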
@@ -229,6 +223,12 @@ model = ALS.trainImplicit(ratings, rank, numIterations, alpha = 0.01)
+In order to run the above application, follow the instructions +provided in the [Self-Contained Applications](quick-start.html#self-contained-applications) +section of the Spark +Quick Start guide. Be sure to also include *spark-mllib* to your build file as +a dependency. + ## Tutorial The [training exercises](https://databricks-training.s3.amazonaws.com/index.html) from the Spark Summit 2014 include a hands-on tutorial for diff --git a/docs/mllib-dimensionality-reduction.md b/docs/mllib-dimensionality-reduction.md index 21cb35b4270ca..870fed6cc5024 100644 --- a/docs/mllib-dimensionality-reduction.md +++ b/docs/mllib-dimensionality-reduction.md @@ -121,9 +121,9 @@ public class SVD { The same code applies to `IndexedRowMatrix` if `U` is defined as an `IndexedRowMatrix`. -In order to run the above standalone application, follow the instructions -provided in the [Standalone -Applications](quick-start.html#standalone-applications) section of the Spark +In order to run the above application, follow the instructions +provided in the [Self-Contained +Applications](quick-start.html#self-contained-applications) section of the Spark quick-start guide. Be sure to also include *spark-mllib* to your build file as a dependency. @@ -200,10 +200,11 @@ public class PCA { } {% endhighlight %} -In order to run the above standalone application, follow the instructions -provided in the [Standalone -Applications](quick-start.html#standalone-applications) section of the Spark -quick-start guide. Be sure to also include *spark-mllib* to your build file as -a dependency. + +In order to run the above application, follow the instructions +provided in the [Self-Contained Applications](quick-start.html#self-contained-applications) +section of the Spark +quick-start guide. Be sure to also include *spark-mllib* to your build file as +a dependency. diff --git a/docs/mllib-linear-methods.md b/docs/mllib-linear-methods.md index d31bec3e1bd01..bc914a1899801 100644 --- a/docs/mllib-linear-methods.md +++ b/docs/mllib-linear-methods.md @@ -247,7 +247,7 @@ val modelL1 = svmAlg.run(training) All of MLlib's methods use Java-friendly types, so you can import and call them there the same way you do in Scala. The only caveat is that the methods take Scala RDD objects, while the Spark Java API uses a separate `JavaRDD` class. You can convert a Java RDD to a Scala one by -calling `.rdd()` on your `JavaRDD` object. A standalone application example +calling `.rdd()` on your `JavaRDD` object. A self-contained application example that is equivalent to the provided example in Scala is given bellow: {% highlight java %} @@ -323,9 +323,9 @@ svmAlg.optimizer() final SVMModel modelL1 = svmAlg.run(training.rdd()); {% endhighlight %} -In order to run the above standalone application, follow the instructions -provided in the [Standalone -Applications](quick-start.html#standalone-applications) section of the Spark +In order to run the above application, follow the instructions +provided in the [Self-Contained +Applications](quick-start.html#self-contained-applications) section of the Spark quick-start guide. Be sure to also include *spark-mllib* to your build file as a dependency. @@ -482,12 +482,6 @@ public class LinearRegression { } } {% endhighlight %} - -In order to run the above standalone application, follow the instructions -provided in the [Standalone -Applications](quick-start.html#standalone-applications) section of the Spark -quick-start guide. Be sure to also include *spark-mllib* to your build file as -a dependency.
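The linear-methods page keeps its L1-regularised SVM walkthrough; only the pointer to the renamed Quick Start section changes. For reference, the Scala configuration that the surrounding context line `val modelL1 = svmAlg.run(training)` refers to is along these lines, with the hyperparameters illustrative and `sc` assumed in scope:

{% highlight scala %}
import org.apache.spark.mllib.classification.SVMWithSGD
import org.apache.spark.mllib.optimization.L1Updater
import org.apache.spark.mllib.util.MLUtils

// Load labeled points in LIBSVM format (placeholder path).
val training = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").cache()

val svmAlg = new SVMWithSGD()
svmAlg.optimizer
  .setNumIterations(200)
  .setRegParam(0.1)
  .setUpdater(new L1Updater)     // L1 regularization instead of the default L2
val modelL1 = svmAlg.run(training)
{% endhighlight %}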
@@ -519,6 +513,12 @@ print("Mean Squared Error = " + str(MSE))
+In order to run the above application, follow the instructions +provided in the [Self-Contained Applications](quick-start.html#self-contained-applications) +section of the Spark +quick-start guide. Be sure to also include *spark-mllib* to your build file as +a dependency. + ## Streaming linear regression When data arrive in a streaming fashion, it is useful to fit regression models online, diff --git a/docs/monitoring.md b/docs/monitoring.md index d07ec4a57a2cc..e3f81a76acdbb 100644 --- a/docs/monitoring.md +++ b/docs/monitoring.md @@ -77,6 +77,13 @@ follows: one implementation, provided by Spark, which looks for application logs stored in the file system. + + + + + diff --git a/docs/quick-start.md b/docs/quick-start.md index 23313d8aa6152..6236de0e1f2c4 100644 --- a/docs/quick-start.md +++ b/docs/quick-start.md @@ -8,7 +8,7 @@ title: Quick Start This tutorial provides a quick introduction to using Spark. We will first introduce the API through Spark's interactive shell (in Python or Scala), -then show how to write standalone applications in Java, Scala, and Python. +then show how to write applications in Java, Scala, and Python. See the [programming guide](programming-guide.html) for a more complete reference. To follow along with this guide, first download a packaged release of Spark from the @@ -215,8 +215,8 @@ a cluster, as described in the [programming guide](programming-guide.html#initia -# Standalone Applications -Now say we wanted to write a standalone application using the Spark API. We will walk through a +# Self-Contained Applications +Now say we wanted to write a self-contained application using the Spark API. We will walk through a simple application in both Scala (with SBT), Java (with Maven), and Python.
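To make the renamed "Self-Contained Applications" section concrete, this is the kind of minimal SBT-built Scala application the Quick Start walks through (the README path is a placeholder to be pointed at your own installation); it produces output of the form shown in the next hunk:

{% highlight scala %}
import org.apache.spark.{SparkConf, SparkContext}

object SimpleApp {
  def main(args: Array[String]): Unit = {
    val logFile = "YOUR_SPARK_HOME/README.md"   // placeholder path
    val sc = new SparkContext(new SparkConf().setAppName("Simple Application"))
    val logData = sc.textFile(logFile, 2).cache()
    val numAs = logData.filter(_.contains("a")).count()
    val numBs = logData.filter(_.contains("b")).count()
    println(s"Lines with a: $numAs, Lines with b: $numBs")
    sc.stop()
  }
}
{% endhighlight %}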
@@ -387,7 +387,7 @@ Lines with a: 46, Lines with b: 23
-Now we will show how to write a standalone application using the Python API (PySpark). +Now we will show how to write an application using the Python API (PySpark). As an example, we'll create a simple Spark application, `SimpleApp.py`: diff --git a/docs/streaming-programming-guide.md b/docs/streaming-programming-guide.md index 5c21e912ea160..738309c668387 100644 --- a/docs/streaming-programming-guide.md +++ b/docs/streaming-programming-guide.md @@ -494,7 +494,7 @@ methods for creating DStreams from files and Akka actors as input sources. For simple text files, there is an easier method `streamingContext.textFileStream(dataDirectory)`. And file streams do not require running a receiver, hence does not require allocating cores. -- **Streams based on Custom Actors:** DStreams can be created with data streams received through Akka actors by using `streamingContext.actorStream(actorProps, actor-name)`. See the [Custom Receiver Guide](#implementing-and-using-a-custom-actor-based-receiver) for more details. +- **Streams based on Custom Actors:** DStreams can be created with data streams received through Akka actors by using `streamingContext.actorStream(actorProps, actor-name)`. See the [Custom Receiver Guide](streaming-custom-receivers.html#implementing-and-using-a-custom-actor-based-receiver) for more details. - **Queue of RDDs as a Stream:** For testing a Spark Streaming application with test data, one can also create a DStream based on a queue of RDDs, using `streamingContext.queueStream(queueOfRDDs)`. Each RDD pushed into the queue will be treated as a batch of data in the DStream, and processed like a stream. diff --git a/examples/src/main/python/sql.py b/examples/src/main/python/sql.py index eefa022f1927c..d2c5ca48c6cb8 100644 --- a/examples/src/main/python/sql.py +++ b/examples/src/main/python/sql.py @@ -48,7 +48,7 @@ # A JSON dataset is pointed to by path. # The path can be either a single text file or a directory storing text files. - path = os.environ['SPARK_HOME'] + "examples/src/main/resources/people.json" + path = os.path.join(os.environ['SPARK_HOME'], "examples/src/main/resources/people.json") # Create a SchemaRDD from the file(s) pointed to by path people = sqlContext.jsonFile(path) # root diff --git a/examples/src/main/python/streaming/hdfs_wordcount.py b/examples/src/main/python/streaming/hdfs_wordcount.py new file mode 100644 index 0000000000000..40faff0ccc7db --- /dev/null +++ b/examples/src/main/python/streaming/hdfs_wordcount.py @@ -0,0 +1,49 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" + Counts words in new text files created in the given directory + Usage: hdfs_wordcount.py + is the directory that Spark Streaming will use to find and read new text files. 
+ + To run this on your local machine on directory `localdir`, run this example + $ bin/spark-submit examples/src/main/python/streaming/network_wordcount.py localdir + + Then create a text file in `localdir` and the words in the file will get counted. +""" + +import sys + +from pyspark import SparkContext +from pyspark.streaming import StreamingContext + +if __name__ == "__main__": + if len(sys.argv) != 2: + print >> sys.stderr, "Usage: hdfs_wordcount.py " + exit(-1) + + sc = SparkContext(appName="PythonStreamingHDFSWordCount") + ssc = StreamingContext(sc, 1) + + lines = ssc.textFileStream(sys.argv[1]) + counts = lines.flatMap(lambda line: line.split(" "))\ + .map(lambda x: (x, 1))\ + .reduceByKey(lambda a, b: a+b) + counts.pprint() + + ssc.start() + ssc.awaitTermination() diff --git a/examples/src/main/python/streaming/network_wordcount.py b/examples/src/main/python/streaming/network_wordcount.py new file mode 100644 index 0000000000000..cfa9c1ff5bfbc --- /dev/null +++ b/examples/src/main/python/streaming/network_wordcount.py @@ -0,0 +1,48 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" + Counts words in UTF8 encoded, '\n' delimited text received from the network every second. + Usage: network_wordcount.py + and describe the TCP server that Spark Streaming would connect to receive data. + + To run this on your local machine, you need to first run a Netcat server + `$ nc -lk 9999` + and then run the example + `$ bin/spark-submit examples/src/main/python/streaming/network_wordcount.py localhost 9999` +""" + +import sys + +from pyspark import SparkContext +from pyspark.streaming import StreamingContext + +if __name__ == "__main__": + if len(sys.argv) != 3: + print >> sys.stderr, "Usage: network_wordcount.py " + exit(-1) + sc = SparkContext(appName="PythonStreamingNetworkWordCount") + ssc = StreamingContext(sc, 1) + + lines = ssc.socketTextStream(sys.argv[1], int(sys.argv[2])) + counts = lines.flatMap(lambda line: line.split(" "))\ + .map(lambda word: (word, 1))\ + .reduceByKey(lambda a, b: a+b) + counts.pprint() + + ssc.start() + ssc.awaitTermination() diff --git a/examples/src/main/python/streaming/stateful_network_wordcount.py b/examples/src/main/python/streaming/stateful_network_wordcount.py new file mode 100644 index 0000000000000..18a9a5a452ffb --- /dev/null +++ b/examples/src/main/python/streaming/stateful_network_wordcount.py @@ -0,0 +1,57 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" + Counts words in UTF8 encoded, '\n' delimited text received from the + network every second. + + Usage: stateful_network_wordcount.py + and describe the TCP server that Spark Streaming + would connect to receive data. + + To run this on your local machine, you need to first run a Netcat server + `$ nc -lk 9999` + and then run the example + `$ bin/spark-submit examples/src/main/python/streaming/stateful_network_wordcount.py \ + localhost 9999` +""" + +import sys + +from pyspark import SparkContext +from pyspark.streaming import StreamingContext + +if __name__ == "__main__": + if len(sys.argv) != 3: + print >> sys.stderr, "Usage: stateful_network_wordcount.py " + exit(-1) + sc = SparkContext(appName="PythonStreamingStatefulNetworkWordCount") + ssc = StreamingContext(sc, 1) + ssc.checkpoint("checkpoint") + + def updateFunc(new_values, last_sum): + return sum(new_values) + (last_sum or 0) + + lines = ssc.socketTextStream(sys.argv[1], int(sys.argv[2])) + running_counts = lines.flatMap(lambda line: line.split(" "))\ + .map(lambda word: (word, 1))\ + .updateStateByKey(updateFunc) + + running_counts.pprint() + + ssc.start() + ssc.awaitTermination() diff --git a/examples/src/main/scala/org/apache/spark/examples/graphx/Analytics.scala b/examples/src/main/scala/org/apache/spark/examples/graphx/Analytics.scala index c4317a6aec798..45527d9382fd0 100644 --- a/examples/src/main/scala/org/apache/spark/examples/graphx/Analytics.scala +++ b/examples/src/main/scala/org/apache/spark/examples/graphx/Analytics.scala @@ -46,17 +46,6 @@ object Analytics extends Logging { } val options = mutable.Map(optionsList: _*) - def pickPartitioner(v: String): PartitionStrategy = { - // TODO: Use reflection rather than listing all the partitioning strategies here. 
- v match { - case "RandomVertexCut" => RandomVertexCut - case "EdgePartition1D" => EdgePartition1D - case "EdgePartition2D" => EdgePartition2D - case "CanonicalRandomVertexCut" => CanonicalRandomVertexCut - case _ => throw new IllegalArgumentException("Invalid PartitionStrategy: " + v) - } - } - val conf = new SparkConf() .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") .set("spark.kryo.registrator", "org.apache.spark.graphx.GraphKryoRegistrator") @@ -67,7 +56,7 @@ object Analytics extends Logging { sys.exit(1) } val partitionStrategy: Option[PartitionStrategy] = options.remove("partStrategy") - .map(pickPartitioner(_)) + .map(PartitionStrategy.fromString(_)) val edgeStorageLevel = options.remove("edgeStorageLevel") .map(StorageLevel.fromString(_)).getOrElse(StorageLevel.MEMORY_ONLY) val vertexStorageLevel = options.remove("vertexStorageLevel") @@ -107,7 +96,7 @@ object Analytics extends Logging { if (!outFname.isEmpty) { logWarning("Saving pageranks of pages to " + outFname) - pr.map{case (id, r) => id + "\t" + r}.saveAsTextFile(outFname) + pr.map { case (id, r) => id + "\t" + r }.saveAsTextFile(outFname) } sc.stop() @@ -129,7 +118,7 @@ object Analytics extends Logging { val graph = partitionStrategy.foldLeft(unpartitionedGraph)(_.partitionBy(_)) val cc = ConnectedComponents.run(graph) - println("Components: " + cc.vertices.map{ case (vid,data) => data}.distinct()) + println("Components: " + cc.vertices.map { case (vid, data) => data }.distinct()) sc.stop() case "triangles" => @@ -147,7 +136,7 @@ object Analytics extends Logging { minEdgePartitions = numEPart, edgeStorageLevel = edgeStorageLevel, vertexStorageLevel = vertexStorageLevel) - // TriangleCount requires the graph to be partitioned + // TriangleCount requires the graph to be partitioned .partitionBy(partitionStrategy.getOrElse(RandomVertexCut)).cache() val triangles = TriangleCount.run(graph) println("Triangles: " + triangles.vertices.map { diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala index 837d0591478c5..0890e6263e165 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala @@ -189,9 +189,10 @@ object DecisionTreeRunner { // Create training, test sets. val splits = if (params.testInput != "") { // Load testInput. 
+ val numFeatures = examples.take(1)(0).features.size val origTestExamples = params.dataFormat match { case "dense" => MLUtils.loadLabeledPoints(sc, params.testInput) - case "libsvm" => MLUtils.loadLibSVMFile(sc, params.testInput) + case "libsvm" => MLUtils.loadLibSVMFile(sc, params.testInput, numFeatures) } params.algo match { case Classification => { diff --git a/examples/src/main/scala/org/apache/spark/examples/sql/hive/HiveFromSpark.scala b/examples/src/main/scala/org/apache/spark/examples/sql/hive/HiveFromSpark.scala index e26f213e8afa8..0c52ef8ed96ac 100644 --- a/examples/src/main/scala/org/apache/spark/examples/sql/hive/HiveFromSpark.scala +++ b/examples/src/main/scala/org/apache/spark/examples/sql/hive/HiveFromSpark.scala @@ -28,7 +28,7 @@ object HiveFromSpark { val sparkConf = new SparkConf().setAppName("HiveFromSpark") val sc = new SparkContext(sparkConf) - // A local hive context creates an instance of the Hive Metastore in process, storing the + // A local hive context creates an instance of the Hive Metastore in process, storing // the warehouse data in the current directory. This location can be overridden by // specifying a second parameter to the constructor. val hiveContext = new HiveContext(sc) diff --git a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala index 33235d150b4a5..13943ed5442b9 100644 --- a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala +++ b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala @@ -17,103 +17,141 @@ package org.apache.spark.streaming.flume -import scala.collection.JavaConversions._ -import scala.collection.mutable.{ArrayBuffer, SynchronizedBuffer} - -import java.net.InetSocketAddress +import java.net.{InetSocketAddress, ServerSocket} import java.nio.ByteBuffer import java.nio.charset.Charset +import scala.collection.JavaConversions._ +import scala.collection.mutable.{ArrayBuffer, SynchronizedBuffer} +import scala.concurrent.duration._ +import scala.language.postfixOps + import org.apache.avro.ipc.NettyTransceiver import org.apache.avro.ipc.specific.SpecificRequestor +import org.apache.flume.source.avro import org.apache.flume.source.avro.{AvroFlumeEvent, AvroSourceProtocol} +import org.jboss.netty.channel.ChannelPipeline +import org.jboss.netty.channel.socket.SocketChannel +import org.jboss.netty.channel.socket.nio.NioClientSocketChannelFactory +import org.jboss.netty.handler.codec.compression._ +import org.scalatest.{BeforeAndAfter, FunSuite, Matchers} +import org.scalatest.concurrent.Eventually._ +import org.apache.spark.{Logging, SparkConf} import org.apache.spark.storage.StorageLevel -import org.apache.spark.streaming.{TestOutputStream, StreamingContext, TestSuiteBase} -import org.apache.spark.streaming.util.ManualClock +import org.apache.spark.streaming.{Milliseconds, StreamingContext, TestOutputStream} +import org.apache.spark.streaming.scheduler.{StreamingListener, StreamingListenerReceiverStarted} import org.apache.spark.util.Utils -import org.jboss.netty.channel.ChannelPipeline -import org.jboss.netty.channel.socket.nio.NioClientSocketChannelFactory -import org.jboss.netty.channel.socket.SocketChannel -import org.jboss.netty.handler.codec.compression._ +class FlumeStreamSuite extends FunSuite with BeforeAndAfter with Matchers with Logging { + val conf = new SparkConf().setMaster("local[4]").setAppName("FlumeStreamSuite") + + var ssc: 
StreamingContext = null + var transceiver: NettyTransceiver = null -class FlumeStreamSuite extends TestSuiteBase { + after { + if (ssc != null) { + ssc.stop() + } + if (transceiver != null) { + transceiver.close() + } + } test("flume input stream") { - runFlumeStreamTest(false) + testFlumeStream(testCompression = false) } test("flume input compressed stream") { - runFlumeStreamTest(true) + testFlumeStream(testCompression = true) + } + + /** Run test on flume stream */ + private def testFlumeStream(testCompression: Boolean): Unit = { + val input = (1 to 100).map { _.toString } + val testPort = findFreePort() + val outputBuffer = startContext(testPort, testCompression) + writeAndVerify(input, testPort, outputBuffer, testCompression) + } + + /** Find a free port */ + private def findFreePort(): Int = { + Utils.startServiceOnPort(23456, (trialPort: Int) => { + val socket = new ServerSocket(trialPort) + socket.close() + (null, trialPort) + })._2 } - - def runFlumeStreamTest(enableDecompression: Boolean) { - // Set up the streaming context and input streams - val ssc = new StreamingContext(conf, batchDuration) - val (flumeStream, testPort) = - Utils.startServiceOnPort(9997, (trialPort: Int) => { - val dstream = FlumeUtils.createStream( - ssc, "localhost", trialPort, StorageLevel.MEMORY_AND_DISK, enableDecompression) - (dstream, trialPort) - }) + /** Setup and start the streaming context */ + private def startContext( + testPort: Int, testCompression: Boolean): (ArrayBuffer[Seq[SparkFlumeEvent]]) = { + ssc = new StreamingContext(conf, Milliseconds(200)) + val flumeStream = FlumeUtils.createStream( + ssc, "localhost", testPort, StorageLevel.MEMORY_AND_DISK, testCompression) val outputBuffer = new ArrayBuffer[Seq[SparkFlumeEvent]] with SynchronizedBuffer[Seq[SparkFlumeEvent]] val outputStream = new TestOutputStream(flumeStream, outputBuffer) outputStream.register() ssc.start() + outputBuffer + } - val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] - val input = Seq(1, 2, 3, 4, 5) - Thread.sleep(1000) - val transceiver = new NettyTransceiver(new InetSocketAddress("localhost", testPort)) - var client: AvroSourceProtocol = null - - if (enableDecompression) { - client = SpecificRequestor.getClient( - classOf[AvroSourceProtocol], - new NettyTransceiver(new InetSocketAddress("localhost", testPort), - new CompressionChannelFactory(6))) - } else { - client = SpecificRequestor.getClient( - classOf[AvroSourceProtocol], transceiver) - } + /** Send data to the flume receiver and verify whether the data was received */ + private def writeAndVerify( + input: Seq[String], + testPort: Int, + outputBuffer: ArrayBuffer[Seq[SparkFlumeEvent]], + enableCompression: Boolean + ) { + val testAddress = new InetSocketAddress("localhost", testPort) - for (i <- 0 until input.size) { + val inputEvents = input.map { item => val event = new AvroFlumeEvent - event.setBody(ByteBuffer.wrap(input(i).toString.getBytes("utf-8"))) + event.setBody(ByteBuffer.wrap(item.getBytes("UTF-8"))) event.setHeaders(Map[CharSequence, CharSequence]("test" -> "header")) - client.append(event) - Thread.sleep(500) - clock.addToTime(batchDuration.milliseconds) + event } - Thread.sleep(1000) - - val startTime = System.currentTimeMillis() - while (outputBuffer.size < input.size && System.currentTimeMillis() - startTime < maxWaitTimeMillis) { - logInfo("output.size = " + outputBuffer.size + ", input.size = " + input.size) - Thread.sleep(100) + eventually(timeout(10 seconds), interval(100 milliseconds)) { + // if last attempted transceiver had 
succeeded, close it + if (transceiver != null) { + transceiver.close() + transceiver = null + } + + // Create transceiver + transceiver = { + if (enableCompression) { + new NettyTransceiver(testAddress, new CompressionChannelFactory(6)) + } else { + new NettyTransceiver(testAddress) + } + } + + // Create Avro client with the transceiver + val client = SpecificRequestor.getClient(classOf[AvroSourceProtocol], transceiver) + client should not be null + + // Send data + val status = client.appendBatch(inputEvents.toList) + status should be (avro.Status.OK) } - Thread.sleep(1000) - val timeTaken = System.currentTimeMillis() - startTime - assert(timeTaken < maxWaitTimeMillis, "Operation timed out after " + timeTaken + " ms") - logInfo("Stopping context") - ssc.stop() - - val decoder = Charset.forName("UTF-8").newDecoder() - - assert(outputBuffer.size === input.length) - for (i <- 0 until outputBuffer.size) { - assert(outputBuffer(i).size === 1) - val str = decoder.decode(outputBuffer(i).head.event.getBody) - assert(str.toString === input(i).toString) - assert(outputBuffer(i).head.event.getHeaders.get("test") === "header") + + val decoder = Charset.forName("UTF-8").newDecoder() + eventually(timeout(10 seconds), interval(100 milliseconds)) { + val outputEvents = outputBuffer.flatten.map { _.event } + outputEvents.foreach { + event => + event.getHeaders.get("test") should be("header") + } + val output = outputEvents.map(event => decoder.decode(event.getBody()).toString) + output should be (input) } } - class CompressionChannelFactory(compressionLevel: Int) extends NioClientSocketChannelFactory { + /** Class to create socket channel with compression */ + private class CompressionChannelFactory(compressionLevel: Int) extends NioClientSocketChannelFactory { override def newChannel(pipeline: ChannelPipeline): SocketChannel = { val encoder = new ZlibEncoder(compressionLevel) pipeline.addFirst("deflater", encoder) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala index f7251e65e04f1..9a100170b75c6 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala @@ -18,6 +18,7 @@ package org.apache.spark.mllib.api.python import java.io.OutputStream +import java.util.{ArrayList => JArrayList} import scala.collection.JavaConverters._ import scala.language.existentials @@ -27,6 +28,7 @@ import net.razorvine.pickle._ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.api.java.{JavaRDD, JavaSparkContext} +import org.apache.spark.api.python.{PythonRDD, SerDeUtil} import org.apache.spark.mllib.classification._ import org.apache.spark.mllib.clustering._ import org.apache.spark.mllib.feature.Word2Vec @@ -639,13 +641,24 @@ private[spark] object SerDe extends Serializable { } } + var initialized = false + // This should be called before trying to serialize any above classes + // In cluster mode, this should be put in the closure def initialize(): Unit = { - new DenseVectorPickler().register() - new DenseMatrixPickler().register() - new SparseVectorPickler().register() - new LabeledPointPickler().register() - new RatingPickler().register() + SerDeUtil.initialize() + synchronized { + if (!initialized) { + new DenseVectorPickler().register() + new DenseMatrixPickler().register() + new SparseVectorPickler().register() + new LabeledPointPickler().register() + new 
RatingPickler().register() + initialized = true + } + } } + // will not called in Executor automatically + initialize() def dumps(obj: AnyRef): Array[Byte] = { new Pickler().dumps(obj) @@ -659,4 +672,33 @@ private[spark] object SerDe extends Serializable { def asTupleRDD(rdd: RDD[Array[Any]]): RDD[(Int, Int)] = { rdd.map(x => (x(0).asInstanceOf[Int], x(1).asInstanceOf[Int])) } + + /** + * Convert an RDD of Java objects to an RDD of serialized Python objects, that is usable by + * PySpark. + */ + def javaToPython(jRDD: JavaRDD[Any]): JavaRDD[Array[Byte]] = { + jRDD.rdd.mapPartitions { iter => + initialize() // let it called in executor + new PythonRDD.AutoBatchedPickler(iter) + } + } + + /** + * Convert an RDD of serialized Python objects to RDD of objects, that is usable by PySpark. + */ + def pythonToJava(pyRDD: JavaRDD[Array[Byte]], batched: Boolean): JavaRDD[Any] = { + pyRDD.rdd.mapPartitions { iter => + initialize() // let it called in executor + val unpickle = new Unpickler + iter.flatMap { row => + val obj = unpickle.loads(row) + if (batched) { + obj.asInstanceOf[JArrayList[_]].asScala + } else { + Seq(obj) + } + } + }.toJavaRDD() + } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala index 8380058cf9b41..ec2d481dccc22 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala @@ -111,7 +111,10 @@ class RowMatrix( */ def computeGramianMatrix(): Matrix = { val n = numCols().toInt - val nt: Int = n * (n + 1) / 2 + checkNumColumns(n) + // Computes n*(n+1)/2, avoiding overflow in the multiplication. + // This succeeds when n <= 65535, which is checked above + val nt: Int = if (n % 2 == 0) ((n / 2) * (n + 1)) else (n * ((n + 1) / 2)) // Compute the upper triangular part of the gram matrix. val GU = rows.treeAggregate(new BDV[Double](new Array[Double](nt)))( @@ -123,6 +126,16 @@ class RowMatrix( RowMatrix.triuToFull(n, GU.data) } + private def checkNumColumns(cols: Int): Unit = { + if (cols > 65535) { + throw new IllegalArgumentException(s"Argument with more than 65535 cols: $cols") + } + if (cols > 10000) { + val mem = cols * cols * 8 + logWarning(s"$cols columns will require at least $mem bytes of memory!") + } + } + /** * Computes singular value decomposition of this matrix. Denote this matrix by A (m x n). This * will compute matrices U, S, V such that A ~= U * S * V', where S contains the leading k @@ -301,12 +314,7 @@ class RowMatrix( */ def computeCovariance(): Matrix = { val n = numCols().toInt - - if (n > 10000) { - val mem = n * n * java.lang.Double.SIZE / java.lang.Byte.SIZE - logWarning(s"The number of columns $n is greater than 10000! 
" + - s"We need at least $mem bytes of memory.") - } + checkNumColumns(n) val (m, mean) = rows.treeAggregate[(Long, BDV[Double])]((0L, BDV.zeros[Double](n)))( seqOp = (s: (Long, BDV[Double]), v: Vector) => (s._1 + 1L, s._2 += v.toBreeze), diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala index fa7a26f17c3ca..ebbd8e0257209 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala @@ -176,6 +176,8 @@ private class RandomForest ( timer.stop("findBestSplits") } + baggedInput.unpersist() + timer.stop("total") logInfo("Internal timing for DecisionTree:") diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala index 55f422dff0d71..ce8825cc03229 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala @@ -64,12 +64,6 @@ private[tree] class DTStatsAggregator( numBins.scanLeft(0)((total, nBins) => total + statsSize * nBins) } - /** - * Indicator for each feature of whether that feature is an unordered feature. - * TODO: Is Array[Boolean] any faster? - */ - def isUnordered(featureIndex: Int): Boolean = metadata.isUnordered(featureIndex) - /** * Total number of elements stored in this aggregator */ @@ -128,21 +122,13 @@ private[tree] class DTStatsAggregator( * Pre-compute feature offset for use with [[featureUpdate]]. * For ordered features only. */ - def getFeatureOffset(featureIndex: Int): Int = { - require(!isUnordered(featureIndex), - s"DTStatsAggregator.getFeatureOffset is for ordered features only, but was called" + - s" for unordered feature $featureIndex.") - featureOffsets(featureIndex) - } + def getFeatureOffset(featureIndex: Int): Int = featureOffsets(featureIndex) /** * Pre-compute feature offset for use with [[featureUpdate]]. * For unordered features only. */ def getLeftRightFeatureOffsets(featureIndex: Int): (Int, Int) = { - require(isUnordered(featureIndex), - s"DTStatsAggregator.getLeftRightFeatureOffsets is for unordered features only," + - s" but was called for ordered feature $featureIndex.") val baseOffset = featureOffsets(featureIndex) (baseOffset, baseOffset + (numBins(featureIndex) >> 1) * statsSize) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala index 49adc7e6b7e56..5bc0f2635c6b1 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala @@ -19,6 +19,7 @@ package org.apache.spark.mllib.tree.impl import scala.collection.mutable +import org.apache.spark.Logging import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.configuration.Algo._ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._ @@ -93,7 +94,7 @@ private[tree] class DecisionTreeMetadata( } -private[tree] object DecisionTreeMetadata { +private[tree] object DecisionTreeMetadata extends Logging { /** * Construct a [[DecisionTreeMetadata]] instance for this dataset and parameters. 
@@ -114,6 +115,10 @@ private[tree] object DecisionTreeMetadata { } val maxPossibleBins = math.min(strategy.maxBins, numExamples).toInt + if (maxPossibleBins < strategy.maxBins) { + logWarning(s"DecisionTree reducing maxBins from ${strategy.maxBins} to $maxPossibleBins" + + s" (= number of training instances)") + } // We check the number of bins here against maxPossibleBins. // This needs to be checked here instead of in Strategy since maxPossibleBins can be modified diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Predict.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Predict.scala index d8476b5cd7bc7..004838ee5ba0e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Predict.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Predict.scala @@ -17,12 +17,15 @@ package org.apache.spark.mllib.tree.model +import org.apache.spark.annotation.DeveloperApi + /** * Predicted value for a node * @param predict predicted value * @param prob probability of the label (classification only) */ -private[tree] class Predict( +@DeveloperApi +class Predict( val predict: Double, val prob: Double = 0.0) extends Serializable { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/RandomForestModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/RandomForestModel.scala index 4d66d6d81caa5..6a22e2abe59bd 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/RandomForestModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/RandomForestModel.scala @@ -82,9 +82,9 @@ class RandomForestModel(val trees: Array[DecisionTreeModel], val algo: Algo) ext */ override def toString: String = algo match { case Classification => - s"RandomForestModel classifier with $numTrees trees" + s"RandomForestModel classifier with $numTrees trees and $totalNumNodes total nodes" case Regression => - s"RandomForestModel regressor with $numTrees trees" + s"RandomForestModel regressor with $numTrees trees and $totalNumNodes total nodes" case _ => throw new IllegalArgumentException( s"RandomForestModel given unknown algo parameter: $algo.") } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala index a06153f22d2c6..7e88e01ace00d 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala @@ -174,6 +174,22 @@ class RandomForestSuite extends FunSuite with LocalSparkContext { checkFeatureSubsetStrategy(numTrees = 2, "onethird", (numFeatures / 3.0).ceil.toInt) } + test("alternating categorical and continuous features with multiclass labels to test indexing") { + val arr = new Array[LabeledPoint](4) + arr(0) = new LabeledPoint(0.0, Vectors.dense(1.0, 0.0, 0.0, 3.0, 1.0)) + arr(1) = new LabeledPoint(1.0, Vectors.dense(0.0, 1.0, 1.0, 1.0, 2.0)) + arr(2) = new LabeledPoint(0.0, Vectors.dense(2.0, 0.0, 0.0, 6.0, 3.0)) + arr(3) = new LabeledPoint(2.0, Vectors.dense(0.0, 2.0, 1.0, 3.0, 2.0)) + val categoricalFeaturesInfo = Map(0 -> 3, 2 -> 2, 4 -> 4) + val input = sc.parallelize(arr) + + val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 5, + numClassesForClassification = 3, categoricalFeaturesInfo = categoricalFeaturesInfo) + val model = RandomForest.trainClassifier(input, strategy, numTrees = 2, + featureSubsetStrategy = "sqrt", seed = 12345) + RandomForestSuite.validateClassifier(model, arr, 
1.0) + } + } object RandomForestSuite { diff --git a/project/MimaBuild.scala b/project/MimaBuild.scala index 39f8ba4745737..d919b18e09855 100644 --- a/project/MimaBuild.scala +++ b/project/MimaBuild.scala @@ -32,7 +32,7 @@ object MimaBuild { ProblemFilters.exclude[MissingMethodProblem](fullName), // Sometimes excluded methods have default arguments and // they are translated into public methods/fields($default$) in generated - // bytecode. It is not possible to exhustively list everything. + // bytecode. It is not possible to exhaustively list everything. // But this should be okay. ProblemFilters.exclude[MissingMethodProblem](fullName+"$default$2"), ProblemFilters.exclude[MissingMethodProblem](fullName+"$default$1"), diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index d499302124461..350aad47735e4 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -50,7 +50,11 @@ object MimaExcludes { "org.apache.spark.mllib.stat.MultivariateStatisticalSummary.normL2"), // MapStatus should be private[spark] ProblemFilters.exclude[IncompatibleTemplateDefProblem]( - "org.apache.spark.scheduler.MapStatus") + "org.apache.spark.scheduler.MapStatus"), + // TaskContext was promoted to Abstract class + ProblemFilters.exclude[AbstractClassProblem]( + "org.apache.spark.TaskContext") + ) case v if v.startsWith("1.1") => diff --git a/project/plugins.sbt b/project/plugins.sbt index 8096c61414660..678f5ed1ba610 100644 --- a/project/plugins.sbt +++ b/project/plugins.sbt @@ -17,7 +17,7 @@ addSbtPlugin("com.github.mpeltonen" % "sbt-idea" % "1.6.0") addSbtPlugin("net.virtual-void" % "sbt-dependency-graph" % "0.7.4") -addSbtPlugin("org.scalastyle" %% "scalastyle-sbt-plugin" % "0.4.0") +addSbtPlugin("org.scalastyle" %% "scalastyle-sbt-plugin" % "0.5.0") addSbtPlugin("com.typesafe" % "sbt-mima-plugin" % "0.1.6") diff --git a/project/spark-style/src/main/scala/org/apache/spark/scalastyle/SparkSpaceAfterCommentStartChecker.scala b/project/spark-style/src/main/scala/org/apache/spark/scalastyle/SparkSpaceAfterCommentStartChecker.scala deleted file mode 100644 index 80d3faa3fe749..0000000000000 --- a/project/spark-style/src/main/scala/org/apache/spark/scalastyle/SparkSpaceAfterCommentStartChecker.scala +++ /dev/null @@ -1,58 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
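
A quick note on the RowMatrix hunks a few files up: computeGramianMatrix() and computeCovariance() now share checkNumColumns(), and the triangular size n*(n+1)/2 is computed so the intermediate product stays inside a signed 32-bit Int for n <= 65535. Below is a small standalone Python sketch of that arithmetic, not part of the patch; Python integers cannot overflow, so it only documents the intent, and the function names are illustrative.

def check_num_columns(cols):
    # A dense n x n matrix of doubles needs 8 * n * n bytes.
    if cols > 65535:
        raise ValueError("Argument with more than 65535 cols: %d" % cols)
    if cols > 10000:
        mem = cols * cols * 8
        print("warning: %d columns will require at least %d bytes of memory!" % (cols, mem))

def upper_triangular_size(n):
    # n * (n + 1) / 2 with the even factor divided first, mirroring the Scala
    # expression that keeps the product within a signed 32-bit Int.
    return (n // 2) * (n + 1) if n % 2 == 0 else n * ((n + 1) // 2)

check_num_columns(20000)             # warns: at least 3,200,000,000 bytes
print(upper_triangular_size(65535))  # 2147450880, still below 2**31 - 1
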
- */ - - -package org.apache.spark.scalastyle - -import java.util.regex.Pattern - -import org.scalastyle.{PositionError, ScalariformChecker, ScalastyleError} -import scalariform.lexer.{MultiLineComment, ScalaDocComment, SingleLineComment, Token} -import scalariform.parser.CompilationUnit - -class SparkSpaceAfterCommentStartChecker extends ScalariformChecker { - val errorKey: String = "insert.a.single.space.after.comment.start.and.before.end" - - private def multiLineCommentRegex(comment: Token) = - Pattern.compile( """/\*\S+.*""", Pattern.DOTALL).matcher(comment.text.trim).matches() || - Pattern.compile( """/\*.*\S\*/""", Pattern.DOTALL).matcher(comment.text.trim).matches() - - private def scalaDocPatternRegex(comment: Token) = - Pattern.compile( """/\*\*\S+.*""", Pattern.DOTALL).matcher(comment.text.trim).matches() || - Pattern.compile( """/\*\*.*\S\*/""", Pattern.DOTALL).matcher(comment.text.trim).matches() - - private def singleLineCommentRegex(comment: Token): Boolean = - comment.text.trim.matches( """//\S+.*""") && !comment.text.trim.matches( """///+""") - - override def verify(ast: CompilationUnit): List[ScalastyleError] = { - ast.tokens - .filter(hasComment) - .map { - _.associatedWhitespaceAndComments.comments.map { - case x: SingleLineComment if singleLineCommentRegex(x.token) => Some(x.token.offset) - case x: MultiLineComment if multiLineCommentRegex(x.token) => Some(x.token.offset) - case x: ScalaDocComment if scalaDocPatternRegex(x.token) => Some(x.token.offset) - case _ => None - }.flatten - }.flatten.map(PositionError(_)) - } - - - private def hasComment(x: Token) = - x.associatedWhitespaceAndComments != null && !x.associatedWhitespaceAndComments.comments.isEmpty - -} diff --git a/python/.gitignore b/python/.gitignore index 80b361ffbd51c..52128cf844a79 100644 --- a/python/.gitignore +++ b/python/.gitignore @@ -1,5 +1,5 @@ *.pyc -docs/ +docs/_build/ pyspark.egg-info build/ dist/ diff --git a/python/docs/conf.py b/python/docs/conf.py index 8e6324f058251..e58d97ae6a746 100644 --- a/python/docs/conf.py +++ b/python/docs/conf.py @@ -131,7 +131,7 @@ # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, # so a file named "default.css" will overwrite the builtin "default.css". -html_static_path = ['_static'] +#html_static_path = ['_static'] # Add any extra paths that contain custom files (such as robots.txt or # .htaccess) here, relative to this directory. These files are copied diff --git a/python/docs/epytext.py b/python/docs/epytext.py index 61d731bff570d..19fefbfc057a4 100644 --- a/python/docs/epytext.py +++ b/python/docs/epytext.py @@ -5,7 +5,7 @@ (r"L{([\w.()]+)}", r":class:`\1`"), (r"[LC]{(\w+\.\w+)\(\)}", r":func:`\1`"), (r"C{([\w.()]+)}", r":class:`\1`"), - (r"[IBCM]{(.+)}", r"`\1`"), + (r"[IBCM]{([^}]+)}", r"`\1`"), ('pyspark.rdd.RDD', 'RDD'), ) diff --git a/python/docs/index.rst b/python/docs/index.rst index d66e051b15371..703bef644de28 100644 --- a/python/docs/index.rst +++ b/python/docs/index.rst @@ -13,6 +13,7 @@ Contents: pyspark pyspark.sql + pyspark.streaming pyspark.mllib diff --git a/python/docs/make.bat b/python/docs/make.bat index adad44fd7536a..c011e82b4a35a 100644 --- a/python/docs/make.bat +++ b/python/docs/make.bat @@ -1,242 +1,6 @@ @ECHO OFF -REM Command file for Sphinx documentation +rem This is the entry point for running Sphinx documentation. To avoid polluting the +rem environment, it just launches a new cmd to do the real work. 
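
Regarding the one-character epytext.py change above ([IBCM]{(.+)} becoming [IBCM]{([^}]+)}): the greedy .+ used to swallow everything up to the last closing brace on a line, merging adjacent markers into one span. A standalone Python illustration of just that fallback pattern (the sample sentence is made up; the real file applies several more specific patterns first):

import re

line = "Uses C{SparkContext.parallelize()} and C{RDD.collect()} internally."

# Old, greedy pattern: a single match spanning both markers.
print(re.sub(r"[IBCM]{(.+)}", r"`\1`", line))
# Uses `SparkContext.parallelize()} and C{RDD.collect()` internally.

# New pattern: [^}]+ stops at the first closing brace, so each marker is
# rewritten separately.
print(re.sub(r"[IBCM]{([^}]+)}", r"`\1`", line))
# Uses `SparkContext.parallelize()` and `RDD.collect()` internally.
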
-if "%SPHINXBUILD%" == "" ( - set SPHINXBUILD=sphinx-build -) -set BUILDDIR=_build -set ALLSPHINXOPTS=-d %BUILDDIR%/doctrees %SPHINXOPTS% . -set I18NSPHINXOPTS=%SPHINXOPTS% . -if NOT "%PAPER%" == "" ( - set ALLSPHINXOPTS=-D latex_paper_size=%PAPER% %ALLSPHINXOPTS% - set I18NSPHINXOPTS=-D latex_paper_size=%PAPER% %I18NSPHINXOPTS% -) - -if "%1" == "" goto help - -if "%1" == "help" ( - :help - echo.Please use `make ^` where ^ is one of - echo. html to make standalone HTML files - echo. dirhtml to make HTML files named index.html in directories - echo. singlehtml to make a single large HTML file - echo. pickle to make pickle files - echo. json to make JSON files - echo. htmlhelp to make HTML files and a HTML help project - echo. qthelp to make HTML files and a qthelp project - echo. devhelp to make HTML files and a Devhelp project - echo. epub to make an epub - echo. latex to make LaTeX files, you can set PAPER=a4 or PAPER=letter - echo. text to make text files - echo. man to make manual pages - echo. texinfo to make Texinfo files - echo. gettext to make PO message catalogs - echo. changes to make an overview over all changed/added/deprecated items - echo. xml to make Docutils-native XML files - echo. pseudoxml to make pseudoxml-XML files for display purposes - echo. linkcheck to check all external links for integrity - echo. doctest to run all doctests embedded in the documentation if enabled - goto end -) - -if "%1" == "clean" ( - for /d %%i in (%BUILDDIR%\*) do rmdir /q /s %%i - del /q /s %BUILDDIR%\* - goto end -) - - -%SPHINXBUILD% 2> nul -if errorlevel 9009 ( - echo. - echo.The 'sphinx-build' command was not found. Make sure you have Sphinx - echo.installed, then set the SPHINXBUILD environment variable to point - echo.to the full path of the 'sphinx-build' executable. Alternatively you - echo.may add the Sphinx directory to PATH. - echo. - echo.If you don't have Sphinx installed, grab it from - echo.http://sphinx-doc.org/ - exit /b 1 -) - -if "%1" == "html" ( - %SPHINXBUILD% -b html %ALLSPHINXOPTS% %BUILDDIR%/html - if errorlevel 1 exit /b 1 - echo. - echo.Build finished. The HTML pages are in %BUILDDIR%/html. - goto end -) - -if "%1" == "dirhtml" ( - %SPHINXBUILD% -b dirhtml %ALLSPHINXOPTS% %BUILDDIR%/dirhtml - if errorlevel 1 exit /b 1 - echo. - echo.Build finished. The HTML pages are in %BUILDDIR%/dirhtml. - goto end -) - -if "%1" == "singlehtml" ( - %SPHINXBUILD% -b singlehtml %ALLSPHINXOPTS% %BUILDDIR%/singlehtml - if errorlevel 1 exit /b 1 - echo. - echo.Build finished. The HTML pages are in %BUILDDIR%/singlehtml. - goto end -) - -if "%1" == "pickle" ( - %SPHINXBUILD% -b pickle %ALLSPHINXOPTS% %BUILDDIR%/pickle - if errorlevel 1 exit /b 1 - echo. - echo.Build finished; now you can process the pickle files. - goto end -) - -if "%1" == "json" ( - %SPHINXBUILD% -b json %ALLSPHINXOPTS% %BUILDDIR%/json - if errorlevel 1 exit /b 1 - echo. - echo.Build finished; now you can process the JSON files. - goto end -) - -if "%1" == "htmlhelp" ( - %SPHINXBUILD% -b htmlhelp %ALLSPHINXOPTS% %BUILDDIR%/htmlhelp - if errorlevel 1 exit /b 1 - echo. - echo.Build finished; now you can run HTML Help Workshop with the ^ -.hhp project file in %BUILDDIR%/htmlhelp. - goto end -) - -if "%1" == "qthelp" ( - %SPHINXBUILD% -b qthelp %ALLSPHINXOPTS% %BUILDDIR%/qthelp - if errorlevel 1 exit /b 1 - echo. 
- echo.Build finished; now you can run "qcollectiongenerator" with the ^ -.qhcp project file in %BUILDDIR%/qthelp, like this: - echo.^> qcollectiongenerator %BUILDDIR%\qthelp\pyspark.qhcp - echo.To view the help file: - echo.^> assistant -collectionFile %BUILDDIR%\qthelp\pyspark.ghc - goto end -) - -if "%1" == "devhelp" ( - %SPHINXBUILD% -b devhelp %ALLSPHINXOPTS% %BUILDDIR%/devhelp - if errorlevel 1 exit /b 1 - echo. - echo.Build finished. - goto end -) - -if "%1" == "epub" ( - %SPHINXBUILD% -b epub %ALLSPHINXOPTS% %BUILDDIR%/epub - if errorlevel 1 exit /b 1 - echo. - echo.Build finished. The epub file is in %BUILDDIR%/epub. - goto end -) - -if "%1" == "latex" ( - %SPHINXBUILD% -b latex %ALLSPHINXOPTS% %BUILDDIR%/latex - if errorlevel 1 exit /b 1 - echo. - echo.Build finished; the LaTeX files are in %BUILDDIR%/latex. - goto end -) - -if "%1" == "latexpdf" ( - %SPHINXBUILD% -b latex %ALLSPHINXOPTS% %BUILDDIR%/latex - cd %BUILDDIR%/latex - make all-pdf - cd %BUILDDIR%/.. - echo. - echo.Build finished; the PDF files are in %BUILDDIR%/latex. - goto end -) - -if "%1" == "latexpdfja" ( - %SPHINXBUILD% -b latex %ALLSPHINXOPTS% %BUILDDIR%/latex - cd %BUILDDIR%/latex - make all-pdf-ja - cd %BUILDDIR%/.. - echo. - echo.Build finished; the PDF files are in %BUILDDIR%/latex. - goto end -) - -if "%1" == "text" ( - %SPHINXBUILD% -b text %ALLSPHINXOPTS% %BUILDDIR%/text - if errorlevel 1 exit /b 1 - echo. - echo.Build finished. The text files are in %BUILDDIR%/text. - goto end -) - -if "%1" == "man" ( - %SPHINXBUILD% -b man %ALLSPHINXOPTS% %BUILDDIR%/man - if errorlevel 1 exit /b 1 - echo. - echo.Build finished. The manual pages are in %BUILDDIR%/man. - goto end -) - -if "%1" == "texinfo" ( - %SPHINXBUILD% -b texinfo %ALLSPHINXOPTS% %BUILDDIR%/texinfo - if errorlevel 1 exit /b 1 - echo. - echo.Build finished. The Texinfo files are in %BUILDDIR%/texinfo. - goto end -) - -if "%1" == "gettext" ( - %SPHINXBUILD% -b gettext %I18NSPHINXOPTS% %BUILDDIR%/locale - if errorlevel 1 exit /b 1 - echo. - echo.Build finished. The message catalogs are in %BUILDDIR%/locale. - goto end -) - -if "%1" == "changes" ( - %SPHINXBUILD% -b changes %ALLSPHINXOPTS% %BUILDDIR%/changes - if errorlevel 1 exit /b 1 - echo. - echo.The overview file is in %BUILDDIR%/changes. - goto end -) - -if "%1" == "linkcheck" ( - %SPHINXBUILD% -b linkcheck %ALLSPHINXOPTS% %BUILDDIR%/linkcheck - if errorlevel 1 exit /b 1 - echo. - echo.Link check complete; look for any errors in the above output ^ -or in %BUILDDIR%/linkcheck/output.txt. - goto end -) - -if "%1" == "doctest" ( - %SPHINXBUILD% -b doctest %ALLSPHINXOPTS% %BUILDDIR%/doctest - if errorlevel 1 exit /b 1 - echo. - echo.Testing of doctests in the sources finished, look at the ^ -results in %BUILDDIR%/doctest/output.txt. - goto end -) - -if "%1" == "xml" ( - %SPHINXBUILD% -b xml %ALLSPHINXOPTS% %BUILDDIR%/xml - if errorlevel 1 exit /b 1 - echo. - echo.Build finished. The XML files are in %BUILDDIR%/xml. - goto end -) - -if "%1" == "pseudoxml" ( - %SPHINXBUILD% -b pseudoxml %ALLSPHINXOPTS% %BUILDDIR%/pseudoxml - if errorlevel 1 exit /b 1 - echo. - echo.Build finished. The pseudo-XML files are in %BUILDDIR%/pseudoxml. 
- goto end -) - -:end +cmd /V /E /C %~dp0make2.bat %* diff --git a/python/docs/make2.bat b/python/docs/make2.bat new file mode 100644 index 0000000000000..7bcaeafad13d7 --- /dev/null +++ b/python/docs/make2.bat @@ -0,0 +1,243 @@ +@ECHO OFF + +REM Command file for Sphinx documentation + + +if "%SPHINXBUILD%" == "" ( + set SPHINXBUILD=sphinx-build +) +set BUILDDIR=_build +set ALLSPHINXOPTS=-d %BUILDDIR%/doctrees %SPHINXOPTS% . +set I18NSPHINXOPTS=%SPHINXOPTS% . +if NOT "%PAPER%" == "" ( + set ALLSPHINXOPTS=-D latex_paper_size=%PAPER% %ALLSPHINXOPTS% + set I18NSPHINXOPTS=-D latex_paper_size=%PAPER% %I18NSPHINXOPTS% +) + +if "%1" == "" goto help + +if "%1" == "help" ( + :help + echo.Please use `make ^` where ^ is one of + echo. html to make standalone HTML files + echo. dirhtml to make HTML files named index.html in directories + echo. singlehtml to make a single large HTML file + echo. pickle to make pickle files + echo. json to make JSON files + echo. htmlhelp to make HTML files and a HTML help project + echo. qthelp to make HTML files and a qthelp project + echo. devhelp to make HTML files and a Devhelp project + echo. epub to make an epub + echo. latex to make LaTeX files, you can set PAPER=a4 or PAPER=letter + echo. text to make text files + echo. man to make manual pages + echo. texinfo to make Texinfo files + echo. gettext to make PO message catalogs + echo. changes to make an overview over all changed/added/deprecated items + echo. xml to make Docutils-native XML files + echo. pseudoxml to make pseudoxml-XML files for display purposes + echo. linkcheck to check all external links for integrity + echo. doctest to run all doctests embedded in the documentation if enabled + goto end +) + +if "%1" == "clean" ( + for /d %%i in (%BUILDDIR%\*) do rmdir /q /s %%i + del /q /s %BUILDDIR%\* + goto end +) + + +%SPHINXBUILD% 2> nul +if errorlevel 9009 ( + echo. + echo.The 'sphinx-build' command was not found. Make sure you have Sphinx + echo.installed, then set the SPHINXBUILD environment variable to point + echo.to the full path of the 'sphinx-build' executable. Alternatively you + echo.may add the Sphinx directory to PATH. + echo. + echo.If you don't have Sphinx installed, grab it from + echo.http://sphinx-doc.org/ + exit /b 1 +) + +if "%1" == "html" ( + %SPHINXBUILD% -b html %ALLSPHINXOPTS% %BUILDDIR%/html + if errorlevel 1 exit /b 1 + echo. + echo.Build finished. The HTML pages are in %BUILDDIR%/html. + goto end +) + +if "%1" == "dirhtml" ( + %SPHINXBUILD% -b dirhtml %ALLSPHINXOPTS% %BUILDDIR%/dirhtml + if errorlevel 1 exit /b 1 + echo. + echo.Build finished. The HTML pages are in %BUILDDIR%/dirhtml. + goto end +) + +if "%1" == "singlehtml" ( + %SPHINXBUILD% -b singlehtml %ALLSPHINXOPTS% %BUILDDIR%/singlehtml + if errorlevel 1 exit /b 1 + echo. + echo.Build finished. The HTML pages are in %BUILDDIR%/singlehtml. + goto end +) + +if "%1" == "pickle" ( + %SPHINXBUILD% -b pickle %ALLSPHINXOPTS% %BUILDDIR%/pickle + if errorlevel 1 exit /b 1 + echo. + echo.Build finished; now you can process the pickle files. + goto end +) + +if "%1" == "json" ( + %SPHINXBUILD% -b json %ALLSPHINXOPTS% %BUILDDIR%/json + if errorlevel 1 exit /b 1 + echo. + echo.Build finished; now you can process the JSON files. + goto end +) + +if "%1" == "htmlhelp" ( + %SPHINXBUILD% -b htmlhelp %ALLSPHINXOPTS% %BUILDDIR%/htmlhelp + if errorlevel 1 exit /b 1 + echo. + echo.Build finished; now you can run HTML Help Workshop with the ^ +.hhp project file in %BUILDDIR%/htmlhelp. 
+ goto end +) + +if "%1" == "qthelp" ( + %SPHINXBUILD% -b qthelp %ALLSPHINXOPTS% %BUILDDIR%/qthelp + if errorlevel 1 exit /b 1 + echo. + echo.Build finished; now you can run "qcollectiongenerator" with the ^ +.qhcp project file in %BUILDDIR%/qthelp, like this: + echo.^> qcollectiongenerator %BUILDDIR%\qthelp\pyspark.qhcp + echo.To view the help file: + echo.^> assistant -collectionFile %BUILDDIR%\qthelp\pyspark.ghc + goto end +) + +if "%1" == "devhelp" ( + %SPHINXBUILD% -b devhelp %ALLSPHINXOPTS% %BUILDDIR%/devhelp + if errorlevel 1 exit /b 1 + echo. + echo.Build finished. + goto end +) + +if "%1" == "epub" ( + %SPHINXBUILD% -b epub %ALLSPHINXOPTS% %BUILDDIR%/epub + if errorlevel 1 exit /b 1 + echo. + echo.Build finished. The epub file is in %BUILDDIR%/epub. + goto end +) + +if "%1" == "latex" ( + %SPHINXBUILD% -b latex %ALLSPHINXOPTS% %BUILDDIR%/latex + if errorlevel 1 exit /b 1 + echo. + echo.Build finished; the LaTeX files are in %BUILDDIR%/latex. + goto end +) + +if "%1" == "latexpdf" ( + %SPHINXBUILD% -b latex %ALLSPHINXOPTS% %BUILDDIR%/latex + cd %BUILDDIR%/latex + make all-pdf + cd %BUILDDIR%/.. + echo. + echo.Build finished; the PDF files are in %BUILDDIR%/latex. + goto end +) + +if "%1" == "latexpdfja" ( + %SPHINXBUILD% -b latex %ALLSPHINXOPTS% %BUILDDIR%/latex + cd %BUILDDIR%/latex + make all-pdf-ja + cd %BUILDDIR%/.. + echo. + echo.Build finished; the PDF files are in %BUILDDIR%/latex. + goto end +) + +if "%1" == "text" ( + %SPHINXBUILD% -b text %ALLSPHINXOPTS% %BUILDDIR%/text + if errorlevel 1 exit /b 1 + echo. + echo.Build finished. The text files are in %BUILDDIR%/text. + goto end +) + +if "%1" == "man" ( + %SPHINXBUILD% -b man %ALLSPHINXOPTS% %BUILDDIR%/man + if errorlevel 1 exit /b 1 + echo. + echo.Build finished. The manual pages are in %BUILDDIR%/man. + goto end +) + +if "%1" == "texinfo" ( + %SPHINXBUILD% -b texinfo %ALLSPHINXOPTS% %BUILDDIR%/texinfo + if errorlevel 1 exit /b 1 + echo. + echo.Build finished. The Texinfo files are in %BUILDDIR%/texinfo. + goto end +) + +if "%1" == "gettext" ( + %SPHINXBUILD% -b gettext %I18NSPHINXOPTS% %BUILDDIR%/locale + if errorlevel 1 exit /b 1 + echo. + echo.Build finished. The message catalogs are in %BUILDDIR%/locale. + goto end +) + +if "%1" == "changes" ( + %SPHINXBUILD% -b changes %ALLSPHINXOPTS% %BUILDDIR%/changes + if errorlevel 1 exit /b 1 + echo. + echo.The overview file is in %BUILDDIR%/changes. + goto end +) + +if "%1" == "linkcheck" ( + %SPHINXBUILD% -b linkcheck %ALLSPHINXOPTS% %BUILDDIR%/linkcheck + if errorlevel 1 exit /b 1 + echo. + echo.Link check complete; look for any errors in the above output ^ +or in %BUILDDIR%/linkcheck/output.txt. + goto end +) + +if "%1" == "doctest" ( + %SPHINXBUILD% -b doctest %ALLSPHINXOPTS% %BUILDDIR%/doctest + if errorlevel 1 exit /b 1 + echo. + echo.Testing of doctests in the sources finished, look at the ^ +results in %BUILDDIR%/doctest/output.txt. + goto end +) + +if "%1" == "xml" ( + %SPHINXBUILD% -b xml %ALLSPHINXOPTS% %BUILDDIR%/xml + if errorlevel 1 exit /b 1 + echo. + echo.Build finished. The XML files are in %BUILDDIR%/xml. + goto end +) + +if "%1" == "pseudoxml" ( + %SPHINXBUILD% -b pseudoxml %ALLSPHINXOPTS% %BUILDDIR%/pseudoxml + if errorlevel 1 exit /b 1 + echo. + echo.Build finished. The pseudo-XML files are in %BUILDDIR%/pseudoxml. + goto end +) + +:end diff --git a/python/docs/pyspark.rst b/python/docs/pyspark.rst index a68bd62433085..e81be3b6cb796 100644 --- a/python/docs/pyspark.rst +++ b/python/docs/pyspark.rst @@ -7,8 +7,9 @@ Subpackages .. 
toctree:: :maxdepth: 1 - pyspark.mllib pyspark.sql + pyspark.streaming + pyspark.mllib Contents -------- diff --git a/python/pyspark/context.py b/python/pyspark/context.py index 85c04624da4a6..8d27ccb95f82c 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -68,7 +68,7 @@ class SparkContext(object): def __init__(self, master=None, appName=None, sparkHome=None, pyFiles=None, environment=None, batchSize=0, serializer=PickleSerializer(), conf=None, - gateway=None): + gateway=None, jsc=None): """ Create a new SparkContext. At least the master and app name should be set, either through the named parameters here or through C{conf}. @@ -104,14 +104,14 @@ def __init__(self, master=None, appName=None, sparkHome=None, pyFiles=None, SparkContext._ensure_initialized(self, gateway=gateway) try: self._do_init(master, appName, sparkHome, pyFiles, environment, batchSize, serializer, - conf) + conf, jsc) except: # If an error occurs, clean up in order to allow future SparkContext creation: self.stop() raise def _do_init(self, master, appName, sparkHome, pyFiles, environment, batchSize, serializer, - conf): + conf, jsc): self.environment = environment or {} self._conf = conf or SparkConf(_jvm=self._jvm) self._batchSize = batchSize # -1 represents an unlimited batch size @@ -154,7 +154,7 @@ def _do_init(self, master, appName, sparkHome, pyFiles, environment, batchSize, self.environment[varName] = v # Create the Java SparkContext through Py4J - self._jsc = self._initialize_context(self._conf._jconf) + self._jsc = jsc or self._initialize_context(self._conf._jconf) # Create a single Accumulator in Java that we'll send all our updates through; # they will be passed back to us through a TCP server @@ -215,8 +215,6 @@ def _ensure_initialized(cls, instance=None, gateway=None): SparkContext._gateway = gateway or launch_gateway() SparkContext._jvm = SparkContext._gateway.jvm SparkContext._writeToFile = SparkContext._jvm.PythonRDD.writeToFile - SparkContext._jvm.SerDeUtil.initialize() - SparkContext._jvm.SerDe.initialize() if instance: if (SparkContext._active_spark_context and diff --git a/python/pyspark/mllib/classification.py b/python/pyspark/mllib/classification.py index cd43982191702..e295c9d0954d9 100644 --- a/python/pyspark/mllib/classification.py +++ b/python/pyspark/mllib/classification.py @@ -21,7 +21,7 @@ from numpy import array from pyspark import SparkContext, PickleSerializer -from pyspark.mllib.linalg import SparseVector, _convert_to_vector +from pyspark.mllib.linalg import SparseVector, _convert_to_vector, _to_java_object_rdd from pyspark.mllib.regression import LabeledPoint, LinearModel, _regression_train_wrapper @@ -244,7 +244,7 @@ def train(cls, data, lambda_=1.0): :param lambda_: The smoothing parameter """ sc = data.context - jlist = sc._jvm.PythonMLLibAPI().trainNaiveBayes(data._to_java_object_rdd(), lambda_) + jlist = sc._jvm.PythonMLLibAPI().trainNaiveBayes(_to_java_object_rdd(data), lambda_) labels, pi, theta = PickleSerializer().loads(str(sc._jvm.SerDe.dumps(jlist))) return NaiveBayesModel(labels.toArray(), pi.toArray(), numpy.array(theta)) diff --git a/python/pyspark/mllib/clustering.py b/python/pyspark/mllib/clustering.py index 12c56022717a5..5ee7997104d21 100644 --- a/python/pyspark/mllib/clustering.py +++ b/python/pyspark/mllib/clustering.py @@ -17,7 +17,7 @@ from pyspark import SparkContext from pyspark.serializers import PickleSerializer, AutoBatchedSerializer -from pyspark.mllib.linalg import SparseVector, _convert_to_vector +from pyspark.mllib.linalg import 
SparseVector, _convert_to_vector, _to_java_object_rdd __all__ = ['KMeansModel', 'KMeans'] @@ -85,7 +85,7 @@ def train(cls, rdd, k, maxIterations=100, runs=1, initializationMode="k-means||" # cache serialized data to avoid objects over head in JVM cached = rdd.map(_convert_to_vector)._reserialize(AutoBatchedSerializer(ser)).cache() model = sc._jvm.PythonMLLibAPI().trainKMeansModel( - cached._to_java_object_rdd(), k, maxIterations, runs, initializationMode) + _to_java_object_rdd(cached), k, maxIterations, runs, initializationMode) bytes = sc._jvm.SerDe.dumps(model.clusterCenters()) centers = ser.loads(str(bytes)) return KMeansModel([c.toArray() for c in centers]) diff --git a/python/pyspark/mllib/feature.py b/python/pyspark/mllib/feature.py index a44a27fd3b6a6..b5a3f22c6907e 100644 --- a/python/pyspark/mllib/feature.py +++ b/python/pyspark/mllib/feature.py @@ -19,8 +19,7 @@ Python package for feature in MLlib. """ from pyspark.serializers import PickleSerializer, AutoBatchedSerializer - -from pyspark.mllib.linalg import _convert_to_vector +from pyspark.mllib.linalg import _convert_to_vector, _to_java_object_rdd __all__ = ['Word2Vec', 'Word2VecModel'] @@ -44,6 +43,7 @@ def transform(self, word): """ :param word: a word :return: vector representation of word + Transforms a word to its vector representation Note: local use only @@ -57,6 +57,7 @@ def findSynonyms(self, x, num): :param x: a word or a vector representation of word :param num: number of synonyms to find :return: array of (word, cosineSimilarity) + Find synonyms of a word Note: local use only @@ -174,7 +175,7 @@ def fit(self, data): seed = self.seed model = sc._jvm.PythonMLLibAPI().trainWord2Vec( - data._to_java_object_rdd(), vectorSize, + _to_java_object_rdd(data), vectorSize, learningRate, numPartitions, numIterations, seed) return Word2VecModel(sc, model) diff --git a/python/pyspark/mllib/linalg.py b/python/pyspark/mllib/linalg.py index 24c5480b2f753..773d8d393805d 100644 --- a/python/pyspark/mllib/linalg.py +++ b/python/pyspark/mllib/linalg.py @@ -29,6 +29,8 @@ import numpy as np +from pyspark.serializers import AutoBatchedSerializer, PickleSerializer + __all__ = ['Vector', 'DenseVector', 'SparseVector', 'Vectors'] @@ -50,6 +52,17 @@ def fast_pickle_array(ar): _have_scipy = False +# this will call the MLlib version of pythonToJava() +def _to_java_object_rdd(rdd): + """ Return an JavaRDD of Object by unpickling + + It will convert each Python object into Java object by Pyrolite, whenever the + RDD is serialized in batch or not. 
+ """ + rdd = rdd._reserialize(AutoBatchedSerializer(PickleSerializer())) + return rdd.ctx._jvm.SerDe.pythonToJava(rdd._jrdd, True) + + def _convert_to_vector(l): if isinstance(l, Vector): return l diff --git a/python/pyspark/mllib/random.py b/python/pyspark/mllib/random.py index a787e4dea2c55..73baba4ace5f6 100644 --- a/python/pyspark/mllib/random.py +++ b/python/pyspark/mllib/random.py @@ -32,7 +32,7 @@ def serialize(f): @wraps(f) def func(sc, *a, **kw): jrdd = f(sc, *a, **kw) - return RDD(sc._jvm.PythonRDD.javaToPython(jrdd), sc, + return RDD(sc._jvm.SerDe.javaToPython(jrdd), sc, BatchedSerializer(PickleSerializer(), 1024)) return func diff --git a/python/pyspark/mllib/recommendation.py b/python/pyspark/mllib/recommendation.py index 59c1c5ff0ced0..17f96b8700bd7 100644 --- a/python/pyspark/mllib/recommendation.py +++ b/python/pyspark/mllib/recommendation.py @@ -18,6 +18,7 @@ from pyspark import SparkContext from pyspark.serializers import PickleSerializer, AutoBatchedSerializer from pyspark.rdd import RDD +from pyspark.mllib.linalg import _to_java_object_rdd __all__ = ['MatrixFactorizationModel', 'ALS'] @@ -77,9 +78,9 @@ def predictAll(self, user_product): first = tuple(map(int, first)) assert all(type(x) is int for x in first), "user and product in user_product shoul be int" sc = self._context - tuplerdd = sc._jvm.SerDe.asTupleRDD(user_product._to_java_object_rdd().rdd()) + tuplerdd = sc._jvm.SerDe.asTupleRDD(_to_java_object_rdd(user_product).rdd()) jresult = self._java_model.predict(tuplerdd).toJavaRDD() - return RDD(sc._jvm.PythonRDD.javaToPython(jresult), sc, + return RDD(sc._jvm.SerDe.javaToPython(jresult), sc, AutoBatchedSerializer(PickleSerializer())) @@ -97,7 +98,7 @@ def _prepare(cls, ratings): # serialize them by AutoBatchedSerializer before cache to reduce the # objects overhead in JVM cached = ratings._reserialize(AutoBatchedSerializer(PickleSerializer())).cache() - return cached._to_java_object_rdd() + return _to_java_object_rdd(cached) @classmethod def train(cls, ratings, rank, iterations=5, lambda_=0.01, blocks=-1): diff --git a/python/pyspark/mllib/regression.py b/python/pyspark/mllib/regression.py index 12b322aaae796..93e17faf5cd51 100644 --- a/python/pyspark/mllib/regression.py +++ b/python/pyspark/mllib/regression.py @@ -19,8 +19,8 @@ from numpy import array from pyspark import SparkContext -from pyspark.mllib.linalg import SparseVector, _convert_to_vector from pyspark.serializers import PickleSerializer, AutoBatchedSerializer +from pyspark.mllib.linalg import SparseVector, _convert_to_vector, _to_java_object_rdd __all__ = ['LabeledPoint', 'LinearModel', 'LinearRegressionModel', 'RidgeRegressionModel', 'LinearRegressionWithSGD', 'LassoWithSGD', 'RidgeRegressionWithSGD'] @@ -131,7 +131,7 @@ def _regression_train_wrapper(sc, train_func, modelClass, data, initial_weights) # use AutoBatchedSerializer before cache to reduce the memory # overhead in JVM cached = data._reserialize(AutoBatchedSerializer(ser)).cache() - ans = train_func(cached._to_java_object_rdd(), initial_bytes) + ans = train_func(_to_java_object_rdd(cached), initial_bytes) assert len(ans) == 2, "JVM call result had unexpected length" weights = ser.loads(str(ans[0])) return modelClass(weights, ans[1]) diff --git a/python/pyspark/mllib/stat.py b/python/pyspark/mllib/stat.py index b9de0909a6fb1..a6019dadf781c 100644 --- a/python/pyspark/mllib/stat.py +++ b/python/pyspark/mllib/stat.py @@ -22,6 +22,7 @@ from functools import wraps from pyspark import PickleSerializer +from pyspark.mllib.linalg import 
_to_java_object_rdd __all__ = ['MultivariateStatisticalSummary', 'Statistics'] @@ -106,7 +107,7 @@ def colStats(rdd): array([ 2., 0., 0., -2.]) """ sc = rdd.ctx - jrdd = rdd._to_java_object_rdd() + jrdd = _to_java_object_rdd(rdd) cStats = sc._jvm.PythonMLLibAPI().colStats(jrdd) return MultivariateStatisticalSummary(sc, cStats) @@ -162,14 +163,14 @@ def corr(x, y=None, method=None): if type(y) == str: raise TypeError("Use 'method=' to specify method name.") - jx = x._to_java_object_rdd() + jx = _to_java_object_rdd(x) if not y: resultMat = sc._jvm.PythonMLLibAPI().corr(jx, method) bytes = sc._jvm.SerDe.dumps(resultMat) ser = PickleSerializer() return ser.loads(str(bytes)).toArray() else: - jy = y._to_java_object_rdd() + jy = _to_java_object_rdd(y) return sc._jvm.PythonMLLibAPI().corr(jx, jy, method) diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py index 5c20e100e144f..463faf7b6f520 100644 --- a/python/pyspark/mllib/tests.py +++ b/python/pyspark/mllib/tests.py @@ -25,7 +25,11 @@ from numpy import array, array_equal if sys.version_info[:2] <= (2, 6): - import unittest2 as unittest + try: + import unittest2 as unittest + except ImportError: + sys.stderr.write('Please install unittest2 to test with Python 2.6 or earlier') + sys.exit(1) else: import unittest diff --git a/python/pyspark/mllib/tree.py b/python/pyspark/mllib/tree.py index 2f42dc104b327..64ee79d83e849 100644 --- a/python/pyspark/mllib/tree.py +++ b/python/pyspark/mllib/tree.py @@ -19,7 +19,7 @@ from pyspark import SparkContext, RDD from pyspark.serializers import BatchedSerializer, PickleSerializer -from pyspark.mllib.linalg import Vector, _convert_to_vector +from pyspark.mllib.linalg import Vector, _convert_to_vector, _to_java_object_rdd from pyspark.mllib.regression import LabeledPoint __all__ = ['DecisionTreeModel', 'DecisionTree'] @@ -61,8 +61,8 @@ def predict(self, x): return self._sc.parallelize([]) if not isinstance(first[0], Vector): x = x.map(_convert_to_vector) - jPred = self._java_model.predict(x._to_java_object_rdd()).toJavaRDD() - jpyrdd = self._sc._jvm.PythonRDD.javaToPython(jPred) + jPred = self._java_model.predict(_to_java_object_rdd(x)).toJavaRDD() + jpyrdd = self._sc._jvm.SerDe.javaToPython(jPred) return RDD(jpyrdd, self._sc, BatchedSerializer(ser, 1024)) else: @@ -104,7 +104,7 @@ def _train(data, type, numClasses, categoricalFeaturesInfo, first = data.first() assert isinstance(first, LabeledPoint), "the data should be RDD of LabeledPoint" sc = data.context - jrdd = data._to_java_object_rdd() + jrdd = _to_java_object_rdd(data) cfiMap = MapConverter().convert(categoricalFeaturesInfo, sc._gateway._gateway_client) model = sc._jvm.PythonMLLibAPI().trainDecisionTreeModel( diff --git a/python/pyspark/mllib/util.py b/python/pyspark/mllib/util.py index 1357fd4fbc8aa..84b39a48619d2 100644 --- a/python/pyspark/mllib/util.py +++ b/python/pyspark/mllib/util.py @@ -19,7 +19,7 @@ import warnings from pyspark.rdd import RDD -from pyspark.serializers import BatchedSerializer, PickleSerializer +from pyspark.serializers import AutoBatchedSerializer, PickleSerializer from pyspark.mllib.linalg import Vectors, SparseVector, _convert_to_vector from pyspark.mllib.regression import LabeledPoint @@ -174,8 +174,8 @@ def loadLabeledPoints(sc, path, minPartitions=None): """ minPartitions = minPartitions or min(sc.defaultParallelism, 2) jrdd = sc._jvm.PythonMLLibAPI().loadLabeledPoints(sc._jsc, path, minPartitions) - jpyrdd = sc._jvm.PythonRDD.javaToPython(jrdd) - return RDD(jpyrdd, sc, 
BatchedSerializer(PickleSerializer())) + jpyrdd = sc._jvm.SerDe.javaToPython(jrdd) + return RDD(jpyrdd, sc, AutoBatchedSerializer(PickleSerializer())) def _test(): diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 6797d50659a92..15be4bfec92f9 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -1070,10 +1070,13 @@ def take(self, num): # If we didn't find any rows after the previous iteration, # quadruple and retry. Otherwise, interpolate the number of # partitions we need to try, but overestimate it by 50%. + # We also cap the estimation in the end. if len(items) == 0: numPartsToTry = partsScanned * 4 else: - numPartsToTry = int(1.5 * num * partsScanned / len(items)) + # the first paramter of max is >=1 whenever partsScanned >= 2 + numPartsToTry = int(1.5 * num * partsScanned / len(items)) - partsScanned + numPartsToTry = min(max(numPartsToTry, 1), partsScanned * 4) left = num - len(items) @@ -2009,7 +2012,7 @@ def countApproxDistinct(self, relativeSD=0.05): of The Art Cardinality Estimation Algorithm", available here. - :param relativeSD Relative accuracy. Smaller values create + :param relativeSD: Relative accuracy. Smaller values create counters that require more space. It must be greater than 0.000017. diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index 3d1a34b281acc..08a0f0d8ffb3e 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -114,6 +114,9 @@ def __ne__(self, other): def __repr__(self): return "<%s object>" % self.__class__.__name__ + def __hash__(self): + return hash(str(self)) + class FramedSerializer(Serializer): diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py index d3d36eb995ab6..b31a82f9b19ac 100644 --- a/python/pyspark/sql.py +++ b/python/pyspark/sql.py @@ -19,14 +19,14 @@ public classes of Spark SQL: - L{SQLContext} - Main entry point for SQL functionality. + Main entry point for SQL functionality. - L{SchemaRDD} - A Resilient Distributed Dataset (RDD) with Schema information for the data contained. In - addition to normal RDD operations, SchemaRDDs also support SQL. + A Resilient Distributed Dataset (RDD) with Schema information for the data contained. In + addition to normal RDD operations, SchemaRDDs also support SQL. - L{Row} - A Row of data returned by a Spark SQL query. + A Row of data returned by a Spark SQL query. - L{HiveContext} - Main entry point for accessing data stored in Apache Hive.. + Main entry point for accessing data stored in Apache Hive.. """ import itertools diff --git a/python/pyspark/streaming/__init__.py b/python/pyspark/streaming/__init__.py new file mode 100644 index 0000000000000..d2644a1d4ffab --- /dev/null +++ b/python/pyspark/streaming/__init__.py @@ -0,0 +1,21 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
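
Back to the rdd.py take() change above: the new estimate subtracts the partitions already scanned and clamps the result, so each round asks for at least one extra partition and never grows the scan by more than 4x. A small sketch of just that arithmetic (standalone Python; the helper name is made up, the formula follows the patch):

def next_parts_to_try(num, parts_scanned, items_found):
    """How many additional partitions take(num) should scan in the next round."""
    if items_found == 0:
        # Nothing found yet: quadruple the scan range.
        return parts_scanned * 4
    # Interpolate the total number of partitions needed, overestimate by 50%,
    # subtract what has already been scanned, then clamp to [1, 4 * parts_scanned].
    estimate = int(1.5 * num * parts_scanned / items_found) - parts_scanned
    return min(max(estimate, 1), parts_scanned * 4)

print(next_parts_to_try(100, 4, 10))   # 16: interpolation suggests 56 more, capped at 4 * parts_scanned
print(next_parts_to_try(100, 4, 600))  # 1: plenty found already, still scans at least one more partition
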
+# + +from pyspark.streaming.context import StreamingContext +from pyspark.streaming.dstream import DStream + +__all__ = ['StreamingContext', 'DStream'] diff --git a/python/pyspark/streaming/context.py b/python/pyspark/streaming/context.py new file mode 100644 index 0000000000000..dc9dc41121935 --- /dev/null +++ b/python/pyspark/streaming/context.py @@ -0,0 +1,325 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import os +import sys + +from py4j.java_collections import ListConverter +from py4j.java_gateway import java_import, JavaObject + +from pyspark import RDD, SparkConf +from pyspark.serializers import UTF8Deserializer, CloudPickleSerializer +from pyspark.context import SparkContext +from pyspark.storagelevel import StorageLevel +from pyspark.streaming.dstream import DStream +from pyspark.streaming.util import TransformFunction, TransformFunctionSerializer + +__all__ = ["StreamingContext"] + + +def _daemonize_callback_server(): + """ + Hack Py4J to daemonize callback server + + The thread of callback server has daemon=False, it will block the driver + from exiting if it's not shutdown. The following code replace `start()` + of CallbackServer with a new version, which set daemon=True for this + thread. + + Also, it will update the port number (0) with real port + """ + # TODO: create a patch for Py4J + import socket + import py4j.java_gateway + logger = py4j.java_gateway.logger + from py4j.java_gateway import Py4JNetworkError + from threading import Thread + + def start(self): + """Starts the CallbackServer. This method should be called by the + client instead of run().""" + self.server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + self.server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, + 1) + try: + self.server_socket.bind((self.address, self.port)) + if not self.port: + # update port with real port + self.port = self.server_socket.getsockname()[1] + except Exception as e: + msg = 'An error occurred while trying to start the callback server: %s' % e + logger.exception(msg) + raise Py4JNetworkError(msg) + + # Maybe thread needs to be cleanup up? + self.thread = Thread(target=self.run) + self.thread.daemon = True + self.thread.start() + + py4j.java_gateway.CallbackServer.start = start + + +class StreamingContext(object): + """ + Main entry point for Spark Streaming functionality. A StreamingContext + represents the connection to a Spark cluster, and can be used to create + L{DStream} various input sources. It can be from an existing L{SparkContext}. + After creating and transforming DStreams, the streaming computation can + be started and stopped using `context.start()` and `context.stop()`, + respectively. 
`context.awaitTransformation()` allows the current thread + to wait for the termination of the context by `stop()` or by an exception. + """ + _transformerSerializer = None + + def __init__(self, sparkContext, batchDuration=None, jssc=None): + """ + Create a new StreamingContext. + + @param sparkContext: L{SparkContext} object. + @param batchDuration: the time interval (in seconds) at which streaming + data will be divided into batches + """ + + self._sc = sparkContext + self._jvm = self._sc._jvm + self._jssc = jssc or self._initialize_context(self._sc, batchDuration) + + def _initialize_context(self, sc, duration): + self._ensure_initialized() + return self._jvm.JavaStreamingContext(sc._jsc, self._jduration(duration)) + + def _jduration(self, seconds): + """ + Create Duration object given number of seconds + """ + return self._jvm.Duration(int(seconds * 1000)) + + @classmethod + def _ensure_initialized(cls): + SparkContext._ensure_initialized() + gw = SparkContext._gateway + + java_import(gw.jvm, "org.apache.spark.streaming.*") + java_import(gw.jvm, "org.apache.spark.streaming.api.java.*") + java_import(gw.jvm, "org.apache.spark.streaming.api.python.*") + + # start callback server + # getattr will fallback to JVM, so we cannot test by hasattr() + if "_callback_server" not in gw.__dict__: + _daemonize_callback_server() + # use random port + gw._start_callback_server(0) + # gateway with real port + gw._python_proxy_port = gw._callback_server.port + # get the GatewayServer object in JVM by ID + jgws = JavaObject("GATEWAY_SERVER", gw._gateway_client) + # update the port of CallbackClient with real port + gw.jvm.PythonDStream.updatePythonGatewayPort(jgws, gw._python_proxy_port) + + # register serializer for TransformFunction + # it happens before creating SparkContext when loading from checkpointing + cls._transformerSerializer = TransformFunctionSerializer( + SparkContext._active_spark_context, CloudPickleSerializer(), gw) + + @classmethod + def getOrCreate(cls, checkpointPath, setupFunc): + """ + Either recreate a StreamingContext from checkpoint data or create a new StreamingContext. + If checkpoint data exists in the provided `checkpointPath`, then StreamingContext will be + recreated from the checkpoint data. If the data does not exist, then the provided setupFunc + will be used to create a JavaStreamingContext. + + @param checkpointPath Checkpoint directory used in an earlier JavaStreamingContext program + @param setupFunc Function to create a new JavaStreamingContext and setup DStreams + """ + # TODO: support checkpoint in HDFS + if not os.path.exists(checkpointPath) or not os.listdir(checkpointPath): + ssc = setupFunc() + ssc.checkpoint(checkpointPath) + return ssc + + cls._ensure_initialized() + gw = SparkContext._gateway + + try: + jssc = gw.jvm.JavaStreamingContext(checkpointPath) + except Exception: + print >>sys.stderr, "failed to load StreamingContext from checkpoint" + raise + + jsc = jssc.sparkContext() + conf = SparkConf(_jconf=jsc.getConf()) + sc = SparkContext(conf=conf, gateway=gw, jsc=jsc) + # update ctx in serializer + SparkContext._active_spark_context = sc + cls._transformerSerializer.ctx = sc + return StreamingContext(sc, None, jssc) + + @property + def sparkContext(self): + """ + Return SparkContext which is associated with this StreamingContext. + """ + return self._sc + + def start(self): + """ + Start the execution of the streams. + """ + self._jssc.start() + + def awaitTermination(self, timeout=None): + """ + Wait for the execution to stop. 
+ @param timeout: time to wait in seconds + """ + if timeout is None: + self._jssc.awaitTermination() + else: + self._jssc.awaitTermination(int(timeout * 1000)) + + def stop(self, stopSparkContext=True, stopGraceFully=False): + """ + Stop the execution of the streams, with option of ensuring all + received data has been processed. + + @param stopSparkContext: Stop the associated SparkContext or not + @param stopGracefully: Stop gracefully by waiting for the processing + of all received data to be completed + """ + self._jssc.stop(stopSparkContext, stopGraceFully) + if stopSparkContext: + self._sc.stop() + + def remember(self, duration): + """ + Set each DStreams in this context to remember RDDs it generated + in the last given duration. DStreams remember RDDs only for a + limited duration of time and releases them for garbage collection. + This method allows the developer to specify how to long to remember + the RDDs (if the developer wishes to query old data outside the + DStream computation). + + @param duration: Minimum duration (in seconds) that each DStream + should remember its RDDs + """ + self._jssc.remember(self._jduration(duration)) + + def checkpoint(self, directory): + """ + Sets the context to periodically checkpoint the DStream operations for master + fault-tolerance. The graph will be checkpointed every batch interval. + + @param directory: HDFS-compatible directory where the checkpoint data + will be reliably stored + """ + self._jssc.checkpoint(directory) + + def socketTextStream(self, hostname, port, storageLevel=StorageLevel.MEMORY_AND_DISK_SER_2): + """ + Create an input from TCP source hostname:port. Data is received using + a TCP socket and receive byte is interpreted as UTF8 encoded ``\\n`` delimited + lines. + + @param hostname: Hostname to connect to for receiving data + @param port: Port to connect to for receiving data + @param storageLevel: Storage level to use for storing the received objects + """ + jlevel = self._sc._getJavaStorageLevel(storageLevel) + return DStream(self._jssc.socketTextStream(hostname, port, jlevel), self, + UTF8Deserializer()) + + def textFileStream(self, directory): + """ + Create an input stream that monitors a Hadoop-compatible file system + for new files and reads them as text files. Files must be wrriten to the + monitored directory by "moving" them from another location within the same + file system. File names starting with . are ignored. + """ + return DStream(self._jssc.textFileStream(directory), self, UTF8Deserializer()) + + def _check_serializers(self, rdds): + # make sure they have same serializer + if len(set(rdd._jrdd_deserializer for rdd in rdds)) > 1: + for i in range(len(rdds)): + # reset them to sc.serializer + rdds[i] = rdds[i]._reserialize() + + def queueStream(self, rdds, oneAtATime=True, default=None): + """ + Create an input stream from an queue of RDDs or list. In each batch, + it will process either one or all of the RDDs returned by the queue. + + NOTE: changes to the queue after the stream is created will not be recognized. + + @param rdds: Queue of RDDs + @param oneAtATime: pick one rdd each time or pick all of them once. 
+ @param default: The default rdd if no more in rdds + """ + if default and not isinstance(default, RDD): + default = self._sc.parallelize(default) + + if not rdds and default: + rdds = [rdds] + + if rdds and not isinstance(rdds[0], RDD): + rdds = [self._sc.parallelize(input) for input in rdds] + self._check_serializers(rdds) + + jrdds = ListConverter().convert([r._jrdd for r in rdds], + SparkContext._gateway._gateway_client) + queue = self._jvm.PythonDStream.toRDDQueue(jrdds) + if default: + default = default._reserialize(rdds[0]._jrdd_deserializer) + jdstream = self._jssc.queueStream(queue, oneAtATime, default._jrdd) + else: + jdstream = self._jssc.queueStream(queue, oneAtATime) + return DStream(jdstream, self, rdds[0]._jrdd_deserializer) + + def transform(self, dstreams, transformFunc): + """ + Create a new DStream in which each RDD is generated by applying + a function on RDDs of the DStreams. The order of the JavaRDDs in + the transform function parameter will be the same as the order + of corresponding DStreams in the list. + """ + jdstreams = ListConverter().convert([d._jdstream for d in dstreams], + SparkContext._gateway._gateway_client) + # change the final serializer to sc.serializer + func = TransformFunction(self._sc, + lambda t, *rdds: transformFunc(rdds).map(lambda x: x), + *[d._jrdd_deserializer for d in dstreams]) + jfunc = self._jvm.TransformFunction(func) + jdstream = self._jssc.transform(jdstreams, jfunc) + return DStream(jdstream, self, self._sc.serializer) + + def union(self, *dstreams): + """ + Create a unified DStream from multiple DStreams of the same + type and same slide duration. + """ + if not dstreams: + raise ValueError("should have at least one DStream to union") + if len(dstreams) == 1: + return dstreams[0] + if len(set(s._jrdd_deserializer for s in dstreams)) > 1: + raise ValueError("All DStreams should have same serializer") + if len(set(s._slideDuration for s in dstreams)) > 1: + raise ValueError("All DStreams should have same slide duration") + first = dstreams[0] + jrest = ListConverter().convert([d._jdstream for d in dstreams[1:]], + SparkContext._gateway._gateway_client) + return DStream(self._jssc.union(first._jdstream, jrest), self, first._jrdd_deserializer) diff --git a/python/pyspark/streaming/dstream.py b/python/pyspark/streaming/dstream.py new file mode 100644 index 0000000000000..5ae5cf07f0137 --- /dev/null +++ b/python/pyspark/streaming/dstream.py @@ -0,0 +1,621 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
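
Putting the new StreamingContext above and the DStream class below together, the intended end-to-end flow looks roughly like the following usage sketch. It is not part of the patch: the master, host, port, durations, and checkpoint directory are placeholders (run e.g. `nc -lk 9999` to feed the socket), and the keyword names follow the signatures added in this file set.

from pyspark import SparkContext
from pyspark.streaming import StreamingContext

sc = SparkContext("local[2]", "NetworkWordCount")
ssc = StreamingContext(sc, 1)                # 1-second batches
ssc.checkpoint("/tmp/streaming-checkpoint")  # periodic DStream checkpointing

# Placeholder TCP text source on localhost:9999.
lines = ssc.socketTextStream("localhost", 9999)

counts = (lines.flatMap(lambda line: line.split(" "))
               .map(lambda word: (word, 1))
               .reduceByKey(lambda a, b: a + b))
counts.pprint()                              # prints the first ten records of every batch

ssc.start()
ssc.awaitTermination(30)                     # wait up to ~30 seconds
# keyword spelled as in the patch's stop() signature
ssc.stop(stopSparkContext=True, stopGraceFully=False)
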
+# + +from itertools import chain, ifilter, imap +import operator +import time +from datetime import datetime + +from py4j.protocol import Py4JJavaError + +from pyspark import RDD +from pyspark.storagelevel import StorageLevel +from pyspark.streaming.util import rddToFileName, TransformFunction +from pyspark.rdd import portable_hash +from pyspark.resultiterable import ResultIterable + +__all__ = ["DStream"] + + +class DStream(object): + """ + A Discretized Stream (DStream), the basic abstraction in Spark Streaming, + is a continuous sequence of RDDs (of the same type) representing a + continuous stream of data (see L{RDD} in the Spark core documentation + for more details on RDDs). + + DStreams can either be created from live data (such as, data from TCP + sockets, Kafka, Flume, etc.) using a L{StreamingContext} or it can be + generated by transforming existing DStreams using operations such as + `map`, `window` and `reduceByKeyAndWindow`. While a Spark Streaming + program is running, each DStream periodically generates a RDD, either + from live data or by transforming the RDD generated by a parent DStream. + + DStreams internally is characterized by a few basic properties: + - A list of other DStreams that the DStream depends on + - A time interval at which the DStream generates an RDD + - A function that is used to generate an RDD after each time interval + """ + def __init__(self, jdstream, ssc, jrdd_deserializer): + self._jdstream = jdstream + self._ssc = ssc + self._sc = ssc._sc + self._jrdd_deserializer = jrdd_deserializer + self.is_cached = False + self.is_checkpointed = False + + def context(self): + """ + Return the StreamingContext associated with this DStream + """ + return self._ssc + + def count(self): + """ + Return a new DStream in which each RDD has a single element + generated by counting each RDD of this DStream. + """ + return self.mapPartitions(lambda i: [sum(1 for _ in i)]).reduce(operator.add) + + def filter(self, f): + """ + Return a new DStream containing only the elements that satisfy predicate. + """ + def func(iterator): + return ifilter(f, iterator) + return self.mapPartitions(func, True) + + def flatMap(self, f, preservesPartitioning=False): + """ + Return a new DStream by applying a function to all elements of + this DStream, and then flattening the results + """ + def func(s, iterator): + return chain.from_iterable(imap(f, iterator)) + return self.mapPartitionsWithIndex(func, preservesPartitioning) + + def map(self, f, preservesPartitioning=False): + """ + Return a new DStream by applying a function to each element of DStream. + """ + def func(iterator): + return imap(f, iterator) + return self.mapPartitions(func, preservesPartitioning) + + def mapPartitions(self, f, preservesPartitioning=False): + """ + Return a new DStream in which each RDD is generated by applying + mapPartitions() to each RDDs of this DStream. + """ + def func(s, iterator): + return f(iterator) + return self.mapPartitionsWithIndex(func, preservesPartitioning) + + def mapPartitionsWithIndex(self, f, preservesPartitioning=False): + """ + Return a new DStream in which each RDD is generated by applying + mapPartitionsWithIndex() to each RDDs of this DStream. + """ + return self.transform(lambda rdd: rdd.mapPartitionsWithIndex(f, preservesPartitioning)) + + def reduce(self, func): + """ + Return a new DStream in which each RDD has a single element + generated by reducing each RDD of this DStream. 
+ """ + return self.map(lambda x: (None, x)).reduceByKey(func, 1).map(lambda x: x[1]) + + def reduceByKey(self, func, numPartitions=None): + """ + Return a new DStream by applying reduceByKey to each RDD. + """ + if numPartitions is None: + numPartitions = self._sc.defaultParallelism + return self.combineByKey(lambda x: x, func, func, numPartitions) + + def combineByKey(self, createCombiner, mergeValue, mergeCombiners, + numPartitions=None): + """ + Return a new DStream by applying combineByKey to each RDD. + """ + if numPartitions is None: + numPartitions = self._sc.defaultParallelism + + def func(rdd): + return rdd.combineByKey(createCombiner, mergeValue, mergeCombiners, numPartitions) + return self.transform(func) + + def partitionBy(self, numPartitions, partitionFunc=portable_hash): + """ + Return a copy of the DStream in which each RDD are partitioned + using the specified partitioner. + """ + return self.transform(lambda rdd: rdd.partitionBy(numPartitions, partitionFunc)) + + def foreachRDD(self, func): + """ + Apply a function to each RDD in this DStream. + """ + if func.func_code.co_argcount == 1: + old_func = func + func = lambda t, rdd: old_func(rdd) + jfunc = TransformFunction(self._sc, func, self._jrdd_deserializer) + api = self._ssc._jvm.PythonDStream + api.callForeachRDD(self._jdstream, jfunc) + + def pprint(self): + """ + Print the first ten elements of each RDD generated in this DStream. + """ + def takeAndPrint(time, rdd): + taken = rdd.take(11) + print "-------------------------------------------" + print "Time: %s" % time + print "-------------------------------------------" + for record in taken[:10]: + print record + if len(taken) > 10: + print "..." + print + + self.foreachRDD(takeAndPrint) + + def mapValues(self, f): + """ + Return a new DStream by applying a map function to the value of + each key-value pairs in this DStream without changing the key. + """ + map_values_fn = lambda (k, v): (k, f(v)) + return self.map(map_values_fn, preservesPartitioning=True) + + def flatMapValues(self, f): + """ + Return a new DStream by applying a flatmap function to the value + of each key-value pairs in this DStream without changing the key. + """ + flat_map_fn = lambda (k, v): ((k, x) for x in f(v)) + return self.flatMap(flat_map_fn, preservesPartitioning=True) + + def glom(self): + """ + Return a new DStream in which RDD is generated by applying glom() + to RDD of this DStream. + """ + def func(iterator): + yield list(iterator) + return self.mapPartitions(func) + + def cache(self): + """ + Persist the RDDs of this DStream with the default storage level + (C{MEMORY_ONLY_SER}). + """ + self.is_cached = True + self.persist(StorageLevel.MEMORY_ONLY_SER) + return self + + def persist(self, storageLevel): + """ + Persist the RDDs of this DStream with the given storage level + """ + self.is_cached = True + javaStorageLevel = self._sc._getJavaStorageLevel(storageLevel) + self._jdstream.persist(javaStorageLevel) + return self + + def checkpoint(self, interval): + """ + Enable periodic checkpointing of RDDs of this DStream + + @param interval: time in seconds, after each period of that, generated + RDD will be checkpointed + """ + self.is_checkpointed = True + self._jdstream.checkpoint(self._ssc._jduration(interval)) + return self + + def groupByKey(self, numPartitions=None): + """ + Return a new DStream by applying groupByKey on each RDD. 
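# --- Editor's sketch (not part of the patch): a word count built from the
# keyed operations above, plus the persistence helpers. `ssc` is assumed to
# be an existing StreamingContext.
lines = ssc.queueStream([["apache spark", "spark streaming"],
                         ["spark python api"]])
words = lines.flatMap(lambda line: line.split(" "))
counts = words.map(lambda w: (w, 1)).reduceByKey(lambda a, b: a + b)

counts.cache()          # MEMORY_ONLY_SER, as documented above
counts.checkpoint(10)   # checkpoint the generated RDDs every 10 seconds
counts.pprint()         # prints up to ten records per batch

# foreachRDD gives access to each batch's RDD, with or without the batch time.
counts.foreachRDD(lambda t, rdd: rdd.foreach(lambda kv: None))
# --- end of sketch ---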
+ """ + if numPartitions is None: + numPartitions = self._sc.defaultParallelism + return self.transform(lambda rdd: rdd.groupByKey(numPartitions)) + + def countByValue(self): + """ + Return a new DStream in which each RDD contains the counts of each + distinct value in each RDD of this DStream. + """ + return self.map(lambda x: (x, None)).reduceByKey(lambda x, y: None).count() + + def saveAsTextFiles(self, prefix, suffix=None): + """ + Save each RDD in this DStream as at text file, using string + representation of elements. + """ + def saveAsTextFile(t, rdd): + path = rddToFileName(prefix, suffix, t) + try: + rdd.saveAsTextFile(path) + except Py4JJavaError as e: + # after recovered from checkpointing, the foreachRDD may + # be called twice + if 'FileAlreadyExistsException' not in str(e): + raise + return self.foreachRDD(saveAsTextFile) + + # TODO: uncomment this until we have ssc.pickleFileStream() + # def saveAsPickleFiles(self, prefix, suffix=None): + # """ + # Save each RDD in this DStream as at binary file, the elements are + # serialized by pickle. + # """ + # def saveAsPickleFile(t, rdd): + # path = rddToFileName(prefix, suffix, t) + # try: + # rdd.saveAsPickleFile(path) + # except Py4JJavaError as e: + # # after recovered from checkpointing, the foreachRDD may + # # be called twice + # if 'FileAlreadyExistsException' not in str(e): + # raise + # return self.foreachRDD(saveAsPickleFile) + + def transform(self, func): + """ + Return a new DStream in which each RDD is generated by applying a function + on each RDD of this DStream. + + `func` can have one argument of `rdd`, or have two arguments of + (`time`, `rdd`) + """ + if func.func_code.co_argcount == 1: + oldfunc = func + func = lambda t, rdd: oldfunc(rdd) + assert func.func_code.co_argcount == 2, "func should take one or two arguments" + return TransformedDStream(self, func) + + def transformWith(self, func, other, keepSerializer=False): + """ + Return a new DStream in which each RDD is generated by applying a function + on each RDD of this DStream and 'other' DStream. + + `func` can have two arguments of (`rdd_a`, `rdd_b`) or have three + arguments of (`time`, `rdd_a`, `rdd_b`) + """ + if func.func_code.co_argcount == 2: + oldfunc = func + func = lambda t, a, b: oldfunc(a, b) + assert func.func_code.co_argcount == 3, "func should take two or three arguments" + jfunc = TransformFunction(self._sc, func, self._jrdd_deserializer, other._jrdd_deserializer) + dstream = self._sc._jvm.PythonTransformed2DStream(self._jdstream.dstream(), + other._jdstream.dstream(), jfunc) + jrdd_serializer = self._jrdd_deserializer if keepSerializer else self._sc.serializer + return DStream(dstream.asJavaDStream(), self._ssc, jrdd_serializer) + + def repartition(self, numPartitions): + """ + Return a new DStream with an increased or decreased level of parallelism. + """ + return self.transform(lambda rdd: rdd.repartition(numPartitions)) + + @property + def _slideDuration(self): + """ + Return the slideDuration in seconds of this DStream + """ + return self._jdstream.dstream().slideDuration().milliseconds() / 1000.0 + + def union(self, other): + """ + Return a new DStream by unifying data of another DStream with this DStream. + + @param other: Another DStream having the same interval (i.e., slideDuration) + as this DStream. 
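# --- Editor's sketch (not part of the patch): transform()/transformWith() as
# defined above. transform() accepts f(rdd) or f(time, rdd); transformWith()
# combines this stream with another one. The output path is a placeholder and
# `ssc` is assumed to exist.
pairs1 = ssc.queueStream([[("a", 1), ("b", 2)]])
pairs2 = ssc.queueStream([[("a", 10), ("c", 30)]])

sorted_pairs = pairs1.transform(lambda t, rdd: rdd.sortByKey())
joined = pairs1.transformWith(lambda a, b: a.join(b), pairs2)
grouped = pairs1.groupByKey().mapValues(list)

sorted_pairs.map(str).saveAsTextFiles("/tmp/stream-out/part", "txt")
joined.pprint()
grouped.pprint()
# --- end of sketch ---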
+ """ + if self._slideDuration != other._slideDuration: + raise ValueError("the two DStream should have same slide duration") + return self.transformWith(lambda a, b: a.union(b), other, True) + + def cogroup(self, other, numPartitions=None): + """ + Return a new DStream by applying 'cogroup' between RDDs of this + DStream and `other` DStream. + + Hash partitioning is used to generate the RDDs with `numPartitions` partitions. + """ + if numPartitions is None: + numPartitions = self._sc.defaultParallelism + return self.transformWith(lambda a, b: a.cogroup(b, numPartitions), other) + + def join(self, other, numPartitions=None): + """ + Return a new DStream by applying 'join' between RDDs of this DStream and + `other` DStream. + + Hash partitioning is used to generate the RDDs with `numPartitions` + partitions. + """ + if numPartitions is None: + numPartitions = self._sc.defaultParallelism + return self.transformWith(lambda a, b: a.join(b, numPartitions), other) + + def leftOuterJoin(self, other, numPartitions=None): + """ + Return a new DStream by applying 'left outer join' between RDDs of this DStream and + `other` DStream. + + Hash partitioning is used to generate the RDDs with `numPartitions` + partitions. + """ + if numPartitions is None: + numPartitions = self._sc.defaultParallelism + return self.transformWith(lambda a, b: a.leftOuterJoin(b, numPartitions), other) + + def rightOuterJoin(self, other, numPartitions=None): + """ + Return a new DStream by applying 'right outer join' between RDDs of this DStream and + `other` DStream. + + Hash partitioning is used to generate the RDDs with `numPartitions` + partitions. + """ + if numPartitions is None: + numPartitions = self._sc.defaultParallelism + return self.transformWith(lambda a, b: a.rightOuterJoin(b, numPartitions), other) + + def fullOuterJoin(self, other, numPartitions=None): + """ + Return a new DStream by applying 'full outer join' between RDDs of this DStream and + `other` DStream. + + Hash partitioning is used to generate the RDDs with `numPartitions` + partitions. + """ + if numPartitions is None: + numPartitions = self._sc.defaultParallelism + return self.transformWith(lambda a, b: a.fullOuterJoin(b, numPartitions), other) + + def _jtime(self, timestamp): + """ Convert datetime or unix_timestamp into Time + """ + if isinstance(timestamp, datetime): + timestamp = time.mktime(timestamp.timetuple()) + return self._sc._jvm.Time(long(timestamp * 1000)) + + def slice(self, begin, end): + """ + Return all the RDDs between 'begin' to 'end' (both included) + + `begin`, `end` could be datetime.datetime() or unix_timestamp + """ + jrdds = self._jdstream.slice(self._jtime(begin), self._jtime(end)) + return [RDD(jrdd, self._sc, self._jrdd_deserializer) for jrdd in jrdds] + + def _validate_window_param(self, window, slide): + duration = self._jdstream.dstream().slideDuration().milliseconds() + if int(window * 1000) % duration != 0: + raise ValueError("windowDuration must be multiple of the slide duration (%d ms)" + % duration) + if slide and int(slide * 1000) % duration != 0: + raise ValueError("slideDuration must be multiple of the slide duration (%d ms)" + % duration) + + def window(self, windowDuration, slideDuration=None): + """ + Return a new DStream in which each RDD contains all the elements in seen in a + sliding window of time over this DStream. 
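# --- Editor's sketch (not part of the patch): the pairwise join family above,
# assuming two keyed queue streams with the same slide duration.
left = ssc.queueStream([[("a", 1), ("b", 2)]])
right = ssc.queueStream([[("b", 3), ("c", 4)]])

left.join(right).pprint()            # -> ("b", (2, 3))
left.leftOuterJoin(right).pprint()   # unmatched right side becomes None
left.fullOuterJoin(right).pprint()   # keeps unmatched keys on both sides
left.cogroup(right).mapValues(lambda vs: tuple(map(list, vs))).pprint()
# --- end of sketch ---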
+ + @param windowDuration: width of the window; must be a multiple of this DStream's + batching interval + @param slideDuration: sliding interval of the window (i.e., the interval after which + the new DStream will generate RDDs); must be a multiple of this + DStream's batching interval + """ + self._validate_window_param(windowDuration, slideDuration) + d = self._ssc._jduration(windowDuration) + if slideDuration is None: + return DStream(self._jdstream.window(d), self._ssc, self._jrdd_deserializer) + s = self._ssc._jduration(slideDuration) + return DStream(self._jdstream.window(d, s), self._ssc, self._jrdd_deserializer) + + def reduceByWindow(self, reduceFunc, invReduceFunc, windowDuration, slideDuration): + """ + Return a new DStream in which each RDD has a single element generated by reducing all + elements in a sliding window over this DStream. + + if `invReduceFunc` is not None, the reduction is done incrementally + using the old window's reduced value : + 1. reduce the new values that entered the window (e.g., adding new counts) + 2. "inverse reduce" the old values that left the window (e.g., subtracting old counts) + This is more efficient than `invReduceFunc` is None. + + @param reduceFunc: associative reduce function + @param invReduceFunc: inverse reduce function of `reduceFunc` + @param windowDuration: width of the window; must be a multiple of this DStream's + batching interval + @param slideDuration: sliding interval of the window (i.e., the interval after which + the new DStream will generate RDDs); must be a multiple of this + DStream's batching interval + """ + keyed = self.map(lambda x: (1, x)) + reduced = keyed.reduceByKeyAndWindow(reduceFunc, invReduceFunc, + windowDuration, slideDuration, 1) + return reduced.map(lambda (k, v): v) + + def countByWindow(self, windowDuration, slideDuration): + """ + Return a new DStream in which each RDD has a single element generated + by counting the number of elements in a window over this DStream. + windowDuration and slideDuration are as defined in the window() operation. + + This is equivalent to window(windowDuration, slideDuration).count(), + but will be more efficient if window is large. + """ + return self.map(lambda x: 1).reduceByWindow(operator.add, operator.sub, + windowDuration, slideDuration) + + def countByValueAndWindow(self, windowDuration, slideDuration, numPartitions=None): + """ + Return a new DStream in which each RDD contains the count of distinct elements in + RDDs in a sliding window over this DStream. + + @param windowDuration: width of the window; must be a multiple of this DStream's + batching interval + @param slideDuration: sliding interval of the window (i.e., the interval after which + the new DStream will generate RDDs); must be a multiple of this + DStream's batching interval + @param numPartitions: number of partitions of each RDD in the new DStream. + """ + keyed = self.map(lambda x: (x, 1)) + counted = keyed.reduceByKeyAndWindow(operator.add, operator.sub, + windowDuration, slideDuration, numPartitions) + return counted.filter(lambda (k, v): v > 0).count() + + def groupByKeyAndWindow(self, windowDuration, slideDuration, numPartitions=None): + """ + Return a new DStream by applying `groupByKey` over a sliding window. + Similar to `DStream.groupByKey()`, but applies it over a sliding window. 
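# --- Editor's sketch (not part of the patch): windowed counting with the
# operations above. With a 1-second batch interval, window(3, 1) covers the
# last three batches and slides once per batch. A checkpoint directory may be
# needed when an inverse reduce function is involved (the tests set one up).
nums = ssc.queueStream([range(1), range(2), range(3), range(4)])

nums.window(3, 1).count().pprint()
nums.countByWindow(3, 1).pprint()
nums.countByValueAndWindow(3, 1).pprint()
nums.reduceByWindow(lambda a, b: a + b,     # fold in values entering the window
                    lambda a, b: a - b,     # fold out values leaving the window
                    3, 1).pprint()
# --- end of sketch ---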
+ + @param windowDuration: width of the window; must be a multiple of this DStream's + batching interval + @param slideDuration: sliding interval of the window (i.e., the interval after which + the new DStream will generate RDDs); must be a multiple of this + DStream's batching interval + @param numPartitions: Number of partitions of each RDD in the new DStream. + """ + ls = self.mapValues(lambda x: [x]) + grouped = ls.reduceByKeyAndWindow(lambda a, b: a.extend(b) or a, lambda a, b: a[len(b):], + windowDuration, slideDuration, numPartitions) + return grouped.mapValues(ResultIterable) + + def reduceByKeyAndWindow(self, func, invFunc, windowDuration, slideDuration=None, + numPartitions=None, filterFunc=None): + """ + Return a new DStream by applying incremental `reduceByKey` over a sliding window. + + The reduced value of over a new window is calculated using the old window's reduce value : + 1. reduce the new values that entered the window (e.g., adding new counts) + 2. "inverse reduce" the old values that left the window (e.g., subtracting old counts) + + `invFunc` can be None, then it will reduce all the RDDs in window, could be slower + than having `invFunc`. + + @param reduceFunc: associative reduce function + @param invReduceFunc: inverse function of `reduceFunc` + @param windowDuration: width of the window; must be a multiple of this DStream's + batching interval + @param slideDuration: sliding interval of the window (i.e., the interval after which + the new DStream will generate RDDs); must be a multiple of this + DStream's batching interval + @param numPartitions: number of partitions of each RDD in the new DStream. + @param filterFunc: function to filter expired key-value pairs; + only pairs that satisfy the function are retained + set this to null if you do not want to filter + """ + self._validate_window_param(windowDuration, slideDuration) + if numPartitions is None: + numPartitions = self._sc.defaultParallelism + + reduced = self.reduceByKey(func, numPartitions) + + def reduceFunc(t, a, b): + b = b.reduceByKey(func, numPartitions) + r = a.union(b).reduceByKey(func, numPartitions) if a else b + if filterFunc: + r = r.filter(filterFunc) + return r + + def invReduceFunc(t, a, b): + b = b.reduceByKey(func, numPartitions) + joined = a.leftOuterJoin(b, numPartitions) + return joined.mapValues(lambda (v1, v2): invFunc(v1, v2) if v2 is not None else v1) + + jreduceFunc = TransformFunction(self._sc, reduceFunc, reduced._jrdd_deserializer) + if invReduceFunc: + jinvReduceFunc = TransformFunction(self._sc, invReduceFunc, reduced._jrdd_deserializer) + else: + jinvReduceFunc = None + if slideDuration is None: + slideDuration = self._slideDuration + dstream = self._sc._jvm.PythonReducedWindowedDStream(reduced._jdstream.dstream(), + jreduceFunc, jinvReduceFunc, + self._ssc._jduration(windowDuration), + self._ssc._jduration(slideDuration)) + return DStream(dstream.asJavaDStream(), self._ssc, self._sc.serializer) + + def updateStateByKey(self, updateFunc, numPartitions=None): + """ + Return a new "state" DStream where the state for each key is updated by applying + the given function on the previous state of the key and the new values of the key. + + @param updateFunc: State update function. If this function returns None, then + corresponding state key-value pair will be eliminated. 
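# --- Editor's sketch (not part of the patch): incremental reduceByKeyAndWindow
# as described above; only values entering or leaving the window are
# re-reduced when an inverse function is supplied. Window/slide durations are
# examples and must be multiples of the batch interval; a checkpoint directory
# may be required, as in the tests' setUp.
words = ssc.queueStream([["spark", "spark", "streaming"], ["spark"]])
pairs = words.map(lambda w: (w, 1))

windowed = pairs.reduceByKeyAndWindow(lambda a, b: a + b,   # entering values
                                      lambda a, b: a - b,   # leaving values
                                      6, 2,
                                      filterFunc=lambda kv: kv[1] > 0)
windowed.pprint()
# --- end of sketch ---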
+ """ + if numPartitions is None: + numPartitions = self._sc.defaultParallelism + + def reduceFunc(t, a, b): + if a is None: + g = b.groupByKey(numPartitions).mapValues(lambda vs: (list(vs), None)) + else: + g = a.cogroup(b, numPartitions) + g = g.mapValues(lambda (va, vb): (list(vb), list(va)[0] if len(va) else None)) + state = g.mapValues(lambda (vs, s): updateFunc(vs, s)) + return state.filter(lambda (k, v): v is not None) + + jreduceFunc = TransformFunction(self._sc, reduceFunc, + self._sc.serializer, self._jrdd_deserializer) + dstream = self._sc._jvm.PythonStateDStream(self._jdstream.dstream(), jreduceFunc) + return DStream(dstream.asJavaDStream(), self._ssc, self._sc.serializer) + + +class TransformedDStream(DStream): + """ + TransformedDStream is an DStream generated by an Python function + transforming each RDD of an DStream to another RDDs. + + Multiple continuous transformations of DStream can be combined into + one transformation. + """ + def __init__(self, prev, func): + self._ssc = prev._ssc + self._sc = self._ssc._sc + self._jrdd_deserializer = self._sc.serializer + self.is_cached = False + self.is_checkpointed = False + self._jdstream_val = None + + if (isinstance(prev, TransformedDStream) and + not prev.is_cached and not prev.is_checkpointed): + prev_func = prev.func + self.func = lambda t, rdd: func(t, prev_func(t, rdd)) + self.prev = prev.prev + else: + self.prev = prev + self.func = func + + @property + def _jdstream(self): + if self._jdstream_val is not None: + return self._jdstream_val + + jfunc = TransformFunction(self._sc, self.func, self.prev._jrdd_deserializer) + dstream = self._sc._jvm.PythonTransformedDStream(self.prev._jdstream.dstream(), jfunc) + self._jdstream_val = dstream.asJavaDStream() + return self._jdstream_val diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py new file mode 100644 index 0000000000000..a8d876d0fa3b3 --- /dev/null +++ b/python/pyspark/streaming/tests.py @@ -0,0 +1,545 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
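# --- Editor's sketch (not part of the patch): a running count with
# updateStateByKey as defined above. The updater receives this batch's new
# values plus the previous state and returns the new state (returning None
# drops the key). The checkpoint directory is a placeholder; `sc` and `ssc`
# are assumed to exist.
sc.setCheckpointDir("/tmp/spark-checkpoint")

def updater(new_values, last_sum):
    return sum(new_values) + (last_sum or 0)

events = ssc.queueStream([[("a", 1)], [("a", 2)], [("b", 5)]])
running = events.updateStateByKey(updater)
running.pprint()    # "a" accumulates across batches; "b" appears later
# --- end of sketch ---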
+# + +import os +from itertools import chain +import time +import operator +import unittest +import tempfile + +from pyspark.context import SparkConf, SparkContext, RDD +from pyspark.streaming.context import StreamingContext + + +class PySparkStreamingTestCase(unittest.TestCase): + + timeout = 10 # seconds + duration = 1 + + def setUp(self): + class_name = self.__class__.__name__ + conf = SparkConf().set("spark.default.parallelism", 1) + self.sc = SparkContext(appName=class_name, conf=conf) + self.sc.setCheckpointDir("/tmp") + # TODO: decrease duration to speed up tests + self.ssc = StreamingContext(self.sc, self.duration) + + def tearDown(self): + self.ssc.stop() + + def wait_for(self, result, n): + start_time = time.time() + while len(result) < n and time.time() - start_time < self.timeout: + time.sleep(0.01) + if len(result) < n: + print "timeout after", self.timeout + + def _take(self, dstream, n): + """ + Return the first `n` elements in the stream (will start and stop). + """ + results = [] + + def take(_, rdd): + if rdd and len(results) < n: + results.extend(rdd.take(n - len(results))) + + dstream.foreachRDD(take) + + self.ssc.start() + self.wait_for(results, n) + return results + + def _collect(self, dstream, n, block=True): + """ + Collect each RDDs into the returned list. + + :return: list, which will have the collected items. + """ + result = [] + + def get_output(_, rdd): + if rdd and len(result) < n: + r = rdd.collect() + if r: + result.append(r) + + dstream.foreachRDD(get_output) + + if not block: + return result + + self.ssc.start() + self.wait_for(result, n) + return result + + def _test_func(self, input, func, expected, sort=False, input2=None): + """ + @param input: dataset for the test. This should be list of lists. + @param func: wrapped function. This function should return PythonDStream object. + @param expected: expected output for this testcase. + """ + if not isinstance(input[0], RDD): + input = [self.sc.parallelize(d, 1) for d in input] + input_stream = self.ssc.queueStream(input) + if input2 and not isinstance(input2[0], RDD): + input2 = [self.sc.parallelize(d, 1) for d in input2] + input_stream2 = self.ssc.queueStream(input2) if input2 is not None else None + + # Apply test function to stream. 
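# --- Editor's sketch (not part of the patch): how the helpers above are meant
# to be used by a concrete test. _test_func() feeds `input` as one RDD per
# batch, builds the stream under test with `func`, and compares the collected
# batches against `expected`. Class and method names are illustrative only.
class UpperCaseTest(PySparkStreamingTestCase):

    def test_upper(self):
        input = [["a", "b"], ["c"]]

        def func(dstream):
            return dstream.map(lambda s: s.upper())

        expected = [["A", "B"], ["C"]]
        self._test_func(input, func, expected)
# --- end of sketch ---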
+ if input2: + stream = func(input_stream, input_stream2) + else: + stream = func(input_stream) + + result = self._collect(stream, len(expected)) + if sort: + self._sort_result_based_on_key(result) + self._sort_result_based_on_key(expected) + self.assertEqual(expected, result) + + def _sort_result_based_on_key(self, outputs): + """Sort the list based on first value.""" + for output in outputs: + output.sort(key=lambda x: x[0]) + + +class BasicOperationTests(PySparkStreamingTestCase): + + def test_map(self): + """Basic operation test for DStream.map.""" + input = [range(1, 5), range(5, 9), range(9, 13)] + + def func(dstream): + return dstream.map(str) + expected = map(lambda x: map(str, x), input) + self._test_func(input, func, expected) + + def test_flatMap(self): + """Basic operation test for DStream.faltMap.""" + input = [range(1, 5), range(5, 9), range(9, 13)] + + def func(dstream): + return dstream.flatMap(lambda x: (x, x * 2)) + expected = map(lambda x: list(chain.from_iterable((map(lambda y: [y, y * 2], x)))), + input) + self._test_func(input, func, expected) + + def test_filter(self): + """Basic operation test for DStream.filter.""" + input = [range(1, 5), range(5, 9), range(9, 13)] + + def func(dstream): + return dstream.filter(lambda x: x % 2 == 0) + expected = map(lambda x: filter(lambda y: y % 2 == 0, x), input) + self._test_func(input, func, expected) + + def test_count(self): + """Basic operation test for DStream.count.""" + input = [range(5), range(10), range(20)] + + def func(dstream): + return dstream.count() + expected = map(lambda x: [len(x)], input) + self._test_func(input, func, expected) + + def test_reduce(self): + """Basic operation test for DStream.reduce.""" + input = [range(1, 5), range(5, 9), range(9, 13)] + + def func(dstream): + return dstream.reduce(operator.add) + expected = map(lambda x: [reduce(operator.add, x)], input) + self._test_func(input, func, expected) + + def test_reduceByKey(self): + """Basic operation test for DStream.reduceByKey.""" + input = [[("a", 1), ("a", 1), ("b", 1), ("b", 1)], + [("", 1), ("", 1), ("", 1), ("", 1)], + [(1, 1), (1, 1), (2, 1), (2, 1), (3, 1)]] + + def func(dstream): + return dstream.reduceByKey(operator.add) + expected = [[("a", 2), ("b", 2)], [("", 4)], [(1, 2), (2, 2), (3, 1)]] + self._test_func(input, func, expected, sort=True) + + def test_mapValues(self): + """Basic operation test for DStream.mapValues.""" + input = [[("a", 2), ("b", 2), ("c", 1), ("d", 1)], + [("", 4), (1, 1), (2, 2), (3, 3)], + [(1, 1), (2, 1), (3, 1), (4, 1)]] + + def func(dstream): + return dstream.mapValues(lambda x: x + 10) + expected = [[("a", 12), ("b", 12), ("c", 11), ("d", 11)], + [("", 14), (1, 11), (2, 12), (3, 13)], + [(1, 11), (2, 11), (3, 11), (4, 11)]] + self._test_func(input, func, expected, sort=True) + + def test_flatMapValues(self): + """Basic operation test for DStream.flatMapValues.""" + input = [[("a", 2), ("b", 2), ("c", 1), ("d", 1)], + [("", 4), (1, 1), (2, 1), (3, 1)], + [(1, 1), (2, 1), (3, 1), (4, 1)]] + + def func(dstream): + return dstream.flatMapValues(lambda x: (x, x + 10)) + expected = [[("a", 2), ("a", 12), ("b", 2), ("b", 12), + ("c", 1), ("c", 11), ("d", 1), ("d", 11)], + [("", 4), ("", 14), (1, 1), (1, 11), (2, 1), (2, 11), (3, 1), (3, 11)], + [(1, 1), (1, 11), (2, 1), (2, 11), (3, 1), (3, 11), (4, 1), (4, 11)]] + self._test_func(input, func, expected) + + def test_glom(self): + """Basic operation test for DStream.glom.""" + input = [range(1, 5), range(5, 9), range(9, 13)] + rdds = [self.sc.parallelize(r, 2) for 
r in input] + + def func(dstream): + return dstream.glom() + expected = [[[1, 2], [3, 4]], [[5, 6], [7, 8]], [[9, 10], [11, 12]]] + self._test_func(rdds, func, expected) + + def test_mapPartitions(self): + """Basic operation test for DStream.mapPartitions.""" + input = [range(1, 5), range(5, 9), range(9, 13)] + rdds = [self.sc.parallelize(r, 2) for r in input] + + def func(dstream): + def f(iterator): + yield sum(iterator) + return dstream.mapPartitions(f) + expected = [[3, 7], [11, 15], [19, 23]] + self._test_func(rdds, func, expected) + + def test_countByValue(self): + """Basic operation test for DStream.countByValue.""" + input = [range(1, 5) * 2, range(5, 7) + range(5, 9), ["a", "a", "b", ""]] + + def func(dstream): + return dstream.countByValue() + expected = [[4], [4], [3]] + self._test_func(input, func, expected) + + def test_groupByKey(self): + """Basic operation test for DStream.groupByKey.""" + input = [[(1, 1), (2, 1), (3, 1), (4, 1)], + [(1, 1), (1, 1), (1, 1), (2, 1), (2, 1), (3, 1)], + [("a", 1), ("a", 1), ("b", 1), ("", 1), ("", 1), ("", 1)]] + + def func(dstream): + return dstream.groupByKey().mapValues(list) + + expected = [[(1, [1]), (2, [1]), (3, [1]), (4, [1])], + [(1, [1, 1, 1]), (2, [1, 1]), (3, [1])], + [("a", [1, 1]), ("b", [1]), ("", [1, 1, 1])]] + self._test_func(input, func, expected, sort=True) + + def test_combineByKey(self): + """Basic operation test for DStream.combineByKey.""" + input = [[(1, 1), (2, 1), (3, 1), (4, 1)], + [(1, 1), (1, 1), (1, 1), (2, 1), (2, 1), (3, 1)], + [("a", 1), ("a", 1), ("b", 1), ("", 1), ("", 1), ("", 1)]] + + def func(dstream): + def add(a, b): + return a + str(b) + return dstream.combineByKey(str, add, add) + expected = [[(1, "1"), (2, "1"), (3, "1"), (4, "1")], + [(1, "111"), (2, "11"), (3, "1")], + [("a", "11"), ("b", "1"), ("", "111")]] + self._test_func(input, func, expected, sort=True) + + def test_repartition(self): + input = [range(1, 5), range(5, 9)] + rdds = [self.sc.parallelize(r, 2) for r in input] + + def func(dstream): + return dstream.repartition(1).glom() + expected = [[[1, 2, 3, 4]], [[5, 6, 7, 8]]] + self._test_func(rdds, func, expected) + + def test_union(self): + input1 = [range(3), range(5), range(6)] + input2 = [range(3, 6), range(5, 6)] + + def func(d1, d2): + return d1.union(d2) + + expected = [range(6), range(6), range(6)] + self._test_func(input1, func, expected, input2=input2) + + def test_cogroup(self): + input = [[(1, 1), (2, 1), (3, 1)], + [(1, 1), (1, 1), (1, 1), (2, 1)], + [("a", 1), ("a", 1), ("b", 1), ("", 1), ("", 1)]] + input2 = [[(1, 2)], + [(4, 1)], + [("a", 1), ("a", 1), ("b", 1), ("", 1), ("", 2)]] + + def func(d1, d2): + return d1.cogroup(d2).mapValues(lambda vs: tuple(map(list, vs))) + + expected = [[(1, ([1], [2])), (2, ([1], [])), (3, ([1], []))], + [(1, ([1, 1, 1], [])), (2, ([1], [])), (4, ([], [1]))], + [("a", ([1, 1], [1, 1])), ("b", ([1], [1])), ("", ([1, 1], [1, 2]))]] + self._test_func(input, func, expected, sort=True, input2=input2) + + def test_join(self): + input = [[('a', 1), ('b', 2)]] + input2 = [[('b', 3), ('c', 4)]] + + def func(a, b): + return a.join(b) + + expected = [[('b', (2, 3))]] + self._test_func(input, func, expected, True, input2) + + def test_left_outer_join(self): + input = [[('a', 1), ('b', 2)]] + input2 = [[('b', 3), ('c', 4)]] + + def func(a, b): + return a.leftOuterJoin(b) + + expected = [[('a', (1, None)), ('b', (2, 3))]] + self._test_func(input, func, expected, True, input2) + + def test_right_outer_join(self): + input = [[('a', 1), ('b', 2)]] + input2 = 
[[('b', 3), ('c', 4)]] + + def func(a, b): + return a.rightOuterJoin(b) + + expected = [[('b', (2, 3)), ('c', (None, 4))]] + self._test_func(input, func, expected, True, input2) + + def test_full_outer_join(self): + input = [[('a', 1), ('b', 2)]] + input2 = [[('b', 3), ('c', 4)]] + + def func(a, b): + return a.fullOuterJoin(b) + + expected = [[('a', (1, None)), ('b', (2, 3)), ('c', (None, 4))]] + self._test_func(input, func, expected, True, input2) + + def test_update_state_by_key(self): + + def updater(vs, s): + if not s: + s = [] + s.extend(vs) + return s + + input = [[('k', i)] for i in range(5)] + + def func(dstream): + return dstream.updateStateByKey(updater) + + expected = [[0], [0, 1], [0, 1, 2], [0, 1, 2, 3], [0, 1, 2, 3, 4]] + expected = [[('k', v)] for v in expected] + self._test_func(input, func, expected) + + +class WindowFunctionTests(PySparkStreamingTestCase): + + timeout = 20 + + def test_window(self): + input = [range(1), range(2), range(3), range(4), range(5)] + + def func(dstream): + return dstream.window(3, 1).count() + + expected = [[1], [3], [6], [9], [12], [9], [5]] + self._test_func(input, func, expected) + + def test_count_by_window(self): + input = [range(1), range(2), range(3), range(4), range(5)] + + def func(dstream): + return dstream.countByWindow(3, 1) + + expected = [[1], [3], [6], [9], [12], [9], [5]] + self._test_func(input, func, expected) + + def test_count_by_window_large(self): + input = [range(1), range(2), range(3), range(4), range(5), range(6)] + + def func(dstream): + return dstream.countByWindow(5, 1) + + expected = [[1], [3], [6], [10], [15], [20], [18], [15], [11], [6]] + self._test_func(input, func, expected) + + def test_count_by_value_and_window(self): + input = [range(1), range(2), range(3), range(4), range(5), range(6)] + + def func(dstream): + return dstream.countByValueAndWindow(5, 1) + + expected = [[1], [2], [3], [4], [5], [6], [6], [6], [6], [6]] + self._test_func(input, func, expected) + + def test_group_by_key_and_window(self): + input = [[('a', i)] for i in range(5)] + + def func(dstream): + return dstream.groupByKeyAndWindow(3, 1).mapValues(list) + + expected = [[('a', [0])], [('a', [0, 1])], [('a', [0, 1, 2])], [('a', [1, 2, 3])], + [('a', [2, 3, 4])], [('a', [3, 4])], [('a', [4])]] + self._test_func(input, func, expected) + + def test_reduce_by_invalid_window(self): + input1 = [range(3), range(5), range(1), range(6)] + d1 = self.ssc.queueStream(input1) + self.assertRaises(ValueError, lambda: d1.reduceByKeyAndWindow(None, None, 0.1, 0.1)) + self.assertRaises(ValueError, lambda: d1.reduceByKeyAndWindow(None, None, 1, 0.1)) + + +class StreamingContextTests(PySparkStreamingTestCase): + + duration = 0.1 + + def _add_input_stream(self): + inputs = map(lambda x: range(1, x), range(101)) + stream = self.ssc.queueStream(inputs) + self._collect(stream, 1, block=False) + + def test_stop_only_streaming_context(self): + self._add_input_stream() + self.ssc.start() + self.ssc.stop(False) + self.assertEqual(len(self.sc.parallelize(range(5), 5).glom().collect()), 5) + + def test_stop_multiple_times(self): + self._add_input_stream() + self.ssc.start() + self.ssc.stop() + self.ssc.stop() + + def test_queue_stream(self): + input = [range(i + 1) for i in range(3)] + dstream = self.ssc.queueStream(input) + result = self._collect(dstream, 3) + self.assertEqual(input, result) + + def test_text_file_stream(self): + d = tempfile.mkdtemp() + self.ssc = StreamingContext(self.sc, self.duration) + dstream2 = self.ssc.textFileStream(d).map(int) + result = 
self._collect(dstream2, 2, block=False) + self.ssc.start() + for name in ('a', 'b'): + time.sleep(1) + with open(os.path.join(d, name), "w") as f: + f.writelines(["%d\n" % i for i in range(10)]) + self.wait_for(result, 2) + self.assertEqual([range(10), range(10)], result) + + def test_union(self): + input = [range(i + 1) for i in range(3)] + dstream = self.ssc.queueStream(input) + dstream2 = self.ssc.queueStream(input) + dstream3 = self.ssc.union(dstream, dstream2) + result = self._collect(dstream3, 3) + expected = [i * 2 for i in input] + self.assertEqual(expected, result) + + def test_transform(self): + dstream1 = self.ssc.queueStream([[1]]) + dstream2 = self.ssc.queueStream([[2]]) + dstream3 = self.ssc.queueStream([[3]]) + + def func(rdds): + rdd1, rdd2, rdd3 = rdds + return rdd2.union(rdd3).union(rdd1) + + dstream = self.ssc.transform([dstream1, dstream2, dstream3], func) + + self.assertEqual([2, 3, 1], self._take(dstream, 3)) + + +class CheckpointTests(PySparkStreamingTestCase): + + def setUp(self): + pass + + def test_get_or_create(self): + inputd = tempfile.mkdtemp() + outputd = tempfile.mkdtemp() + "/" + + def updater(vs, s): + return sum(vs, s or 0) + + def setup(): + conf = SparkConf().set("spark.default.parallelism", 1) + sc = SparkContext(conf=conf) + ssc = StreamingContext(sc, 0.5) + dstream = ssc.textFileStream(inputd).map(lambda x: (x, 1)) + wc = dstream.updateStateByKey(updater) + wc.map(lambda x: "%s,%d" % x).saveAsTextFiles(outputd + "test") + wc.checkpoint(.5) + return ssc + + cpd = tempfile.mkdtemp("test_streaming_cps") + self.ssc = ssc = StreamingContext.getOrCreate(cpd, setup) + ssc.start() + + def check_output(n): + while not os.listdir(outputd): + time.sleep(0.1) + time.sleep(1) # make sure mtime is larger than the previous one + with open(os.path.join(inputd, str(n)), 'w') as f: + f.writelines(["%d\n" % i for i in range(10)]) + + while True: + p = os.path.join(outputd, max(os.listdir(outputd))) + if '_SUCCESS' not in os.listdir(p): + # not finished + time.sleep(0.01) + continue + ordd = ssc.sparkContext.textFile(p).map(lambda line: line.split(",")) + d = ordd.values().map(int).collect() + if not d: + time.sleep(0.01) + continue + self.assertEqual(10, len(d)) + s = set(d) + self.assertEqual(1, len(s)) + m = s.pop() + if n > m: + continue + self.assertEqual(n, m) + break + + check_output(1) + check_output(2) + ssc.stop(True, True) + + time.sleep(1) + self.ssc = ssc = StreamingContext.getOrCreate(cpd, setup) + ssc.start() + check_output(3) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/pyspark/streaming/util.py b/python/pyspark/streaming/util.py new file mode 100644 index 0000000000000..86ee5aa04f252 --- /dev/null +++ b/python/pyspark/streaming/util.py @@ -0,0 +1,128 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
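# --- Editor's sketch (not part of the patch): the checkpoint-recovery pattern
# exercised by test_get_or_create above. getOrCreate() rebuilds the context
# from the checkpoint directory when one exists, otherwise it runs the setup
# function. All paths are placeholders.
def setup():
    sc = SparkContext("local[2]", "checkpoint_sketch")
    ssc = StreamingContext(sc, 1)
    counts = (ssc.textFileStream("/tmp/stream-in")
                 .map(lambda line: (line, 1))
                 .updateStateByKey(lambda vs, s: sum(vs, s or 0)))
    counts.checkpoint(1)
    counts.pprint()
    return ssc

ssc = StreamingContext.getOrCreate("/tmp/stream-cp", setup)
ssc.start()
# --- end of sketch ---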
+# See the License for the specific language governing permissions and +# limitations under the License. +# + +import time +from datetime import datetime +import traceback + +from pyspark import SparkContext, RDD + + +class TransformFunction(object): + """ + This class wraps a function RDD[X] -> RDD[Y] that was passed to + DStream.transform(), allowing it to be called from Java via Py4J's + callback server. + + Java calls this function with a sequence of JavaRDDs and this function + returns a single JavaRDD pointer back to Java. + """ + _emptyRDD = None + + def __init__(self, ctx, func, *deserializers): + self.ctx = ctx + self.func = func + self.deserializers = deserializers + + def call(self, milliseconds, jrdds): + try: + if self.ctx is None: + self.ctx = SparkContext._active_spark_context + if not self.ctx or not self.ctx._jsc: + # stopped + return + + # extend deserializers with the first one + sers = self.deserializers + if len(sers) < len(jrdds): + sers += (sers[0],) * (len(jrdds) - len(sers)) + + rdds = [RDD(jrdd, self.ctx, ser) if jrdd else None + for jrdd, ser in zip(jrdds, sers)] + t = datetime.fromtimestamp(milliseconds / 1000.0) + r = self.func(t, *rdds) + if r: + return r._jrdd + except Exception: + traceback.print_exc() + + def __repr__(self): + return "TransformFunction(%s)" % self.func + + class Java: + implements = ['org.apache.spark.streaming.api.python.PythonTransformFunction'] + + +class TransformFunctionSerializer(object): + """ + This class implements a serializer for PythonTransformFunction Java + objects. + + This is necessary because the Java PythonTransformFunction objects are + actually Py4J references to Python objects and thus are not directly + serializable. When Java needs to serialize a PythonTransformFunction, + it uses this class to invoke Python, which returns the serialized function + as a byte array. + """ + def __init__(self, ctx, serializer, gateway=None): + self.ctx = ctx + self.serializer = serializer + self.gateway = gateway or self.ctx._gateway + self.gateway.jvm.PythonDStream.registerSerializer(self) + + def dumps(self, id): + try: + func = self.gateway.gateway_property.pool[id] + return bytearray(self.serializer.dumps((func.func, func.deserializers))) + except Exception: + traceback.print_exc() + + def loads(self, bytes): + try: + f, deserializers = self.serializer.loads(str(bytes)) + return TransformFunction(self.ctx, f, *deserializers) + except Exception: + traceback.print_exc() + + def __repr__(self): + return "TransformFunctionSerializer(%s)" % self.serializer + + class Java: + implements = ['org.apache.spark.streaming.api.python.PythonTransformFunctionSerializer'] + + +def rddToFileName(prefix, suffix, timestamp): + """ + Return string prefix-time(.suffix) + + >>> rddToFileName("spark", None, 12345678910) + 'spark-12345678910' + >>> rddToFileName("spark", "tmp", 12345678910) + 'spark-12345678910.tmp' + """ + if isinstance(timestamp, datetime): + seconds = time.mktime(timestamp.timetuple()) + timestamp = long(seconds * 1000) + timestamp.microsecond / 1000 + if suffix is None: + return prefix + "-" + str(timestamp) + else: + return prefix + "-" + str(timestamp) + "." 
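# --- Editor's sketch (not part of the patch): the callback contract that
# TransformFunction above implements. User functions passed to
# DStream.transform() are normalized to take (batch time, rdd); the time
# arrives as a datetime built from the batch's millisecond timestamp.
def stamp(t, rdd):
    # t is a datetime for the batch, rdd is that batch's RDD
    return rdd.map(lambda x: (t, x))

ssc.queueStream([[1, 2], [3]]).transform(stamp).pprint()
# --- end of sketch ---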
+ suffix + + +if __name__ == "__main__": + import doctest + doctest.testmod() diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 7f05d48ade2b3..f5ccf31abb3fa 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -34,7 +34,11 @@ from platform import python_implementation if sys.version_info[:2] <= (2, 6): - import unittest2 as unittest + try: + import unittest2 as unittest + except ImportError: + sys.stderr.write('Please install unittest2 to test with Python 2.6 or earlier') + sys.exit(1) else: import unittest @@ -679,6 +683,12 @@ def test_udf(self): [row] = self.sqlCtx.sql("SELECT twoArgs('test', 1)").collect() self.assertEqual(row[0], 5) + def test_udf2(self): + self.sqlCtx.registerFunction("strlen", lambda string: len(string)) + self.sqlCtx.inferSchema(self.sc.parallelize([Row(a="test")])).registerTempTable("test") + [res] = self.sqlCtx.sql("SELECT strlen(a) FROM test WHERE strlen(a) > 1").collect() + self.assertEqual(u"4", res[0]) + def test_broadcast_in_udf(self): bar = {"a": "aa", "b": "bb", "c": "abc"} foo = self.sc.broadcast(bar) diff --git a/python/run-tests b/python/run-tests index f6a96841175e8..80acd002ab7eb 100755 --- a/python/run-tests +++ b/python/run-tests @@ -81,6 +81,12 @@ function run_mllib_tests() { run_test "pyspark/mllib/tests.py" } +function run_streaming_tests() { + echo "Run streaming tests ..." + run_test "pyspark/streaming/util.py" + run_test "pyspark/streaming/tests.py" +} + echo "Running PySpark tests. Output is in python/unit-tests.log." export PYSPARK_PYTHON="python" @@ -96,6 +102,7 @@ $PYSPARK_PYTHON --version run_core_tests run_sql_tests run_mllib_tests +run_streaming_tests # Try to test with PyPy if [ $(which pypy) ]; then @@ -105,6 +112,7 @@ if [ $(which pypy) ]; then run_core_tests run_sql_tests + run_streaming_tests fi if [[ $FAILED == 0 ]]; then diff --git a/scalastyle-config.xml b/scalastyle-config.xml index c54f8b72ebf42..0ff521706c71a 100644 --- a/scalastyle-config.xml +++ b/scalastyle-config.xml @@ -141,5 +141,5 @@ - + diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index b3ae8e6779700..3d4296f9d7068 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst -import java.sql.Timestamp +import java.sql.{Date, Timestamp} import org.apache.spark.sql.catalyst.expressions.{GenericRow, Attribute, AttributeReference} import org.apache.spark.sql.catalyst.plans.logical.LocalRelation @@ -77,8 +77,9 @@ object ScalaReflection { val Schema(valueDataType, valueNullable) = schemaFor(valueType) Schema(MapType(schemaFor(keyType).dataType, valueDataType, valueContainsNull = valueNullable), nullable = true) - case t if t <:< typeOf[String] => Schema(StringType, nullable = true) + case t if t <:< typeOf[String] => Schema(StringType, nullable = true) case t if t <:< typeOf[Timestamp] => Schema(TimestampType, nullable = true) + case t if t <:< typeOf[Date] => Schema(DateType, nullable = true) case t if t <:< typeOf[BigDecimal] => Schema(DecimalType, nullable = true) case t if t <:< typeOf[java.lang.Integer] => Schema(IntegerType, nullable = true) case t if t <:< typeOf[java.lang.Long] => Schema(LongType, nullable = true) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala 
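# --- Editor's sketch (not part of the patch): the Python UDF pattern exercised
# by test_udf2 above, assuming an existing SQLContext named sqlCtx; table and
# column names are illustrative.
from pyspark.sql import Row

sqlCtx.registerFunction("strlen", lambda s: len(s))
rdd = sc.parallelize([Row(a="test"), Row(a="hi")])
sqlCtx.inferSchema(rdd).registerTempTable("words")
rows = sqlCtx.sql("SELECT a, strlen(a) FROM words WHERE strlen(a) > 2").collect()
# --- end of sketch ---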
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index fe83eb12502dc..82553063145b8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -63,7 +63,8 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool typeCoercionRules ++ extendedRules : _*), Batch("Check Analysis", Once, - CheckResolution), + CheckResolution, + CheckAggregation), Batch("AnalysisOperators", fixedPoint, EliminateAnalysisOperators) ) @@ -88,6 +89,32 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool } } + /** + * Checks for non-aggregated attributes with aggregation + */ + object CheckAggregation extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = { + plan.transform { + case aggregatePlan @ Aggregate(groupingExprs, aggregateExprs, child) => + def isValidAggregateExpression(expr: Expression): Boolean = expr match { + case _: AggregateExpression => true + case e: Attribute => groupingExprs.contains(e) + case e if groupingExprs.contains(e) => true + case e if e.references.isEmpty => true + case e => e.children.forall(isValidAggregateExpression) + } + + aggregateExprs.foreach { e => + if (!isValidAggregateExpression(e)) { + throw new TreeNodeException(plan, s"Expression not in GROUP BY: $e") + } + } + + aggregatePlan + } + } + } + /** * Replaces [[UnresolvedRelation]]s with concrete relations from the catalog. */ @@ -204,18 +231,17 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool */ object UnresolvedHavingClauseAttributes extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { - case filter @ Filter(havingCondition, aggregate @ Aggregate(_, originalAggExprs, _)) + case filter @ Filter(havingCondition, aggregate @ Aggregate(_, originalAggExprs, _)) if aggregate.resolved && containsAggregate(havingCondition) => { val evaluatedCondition = Alias(havingCondition, "havingCondition")() val aggExprsWithHaving = evaluatedCondition +: originalAggExprs - + Project(aggregate.output, Filter(evaluatedCondition.toAttribute, aggregate.copy(aggregateExpressions = aggExprsWithHaving))) } - } - + protected def containsAggregate(condition: Expression): Boolean = condition .collect { case ae: AggregateExpression => ae } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index 64881854df7a5..7c480de107e7f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -220,20 +220,39 @@ trait HiveTypeCoercion { case a: BinaryArithmetic if a.right.dataType == StringType => a.makeCopy(Array(a.left, Cast(a.right, DoubleType))) + // we should cast all timestamp/date/string compare into string compare + case p: BinaryPredicate if p.left.dataType == StringType + && p.right.dataType == DateType => + p.makeCopy(Array(p.left, Cast(p.right, StringType))) + case p: BinaryPredicate if p.left.dataType == DateType + && p.right.dataType == StringType => + p.makeCopy(Array(Cast(p.left, StringType), p.right)) case p: BinaryPredicate if p.left.dataType == StringType && p.right.dataType == TimestampType => - p.makeCopy(Array(Cast(p.left, TimestampType), p.right)) + 
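# --- Editor's sketch (not part of the patch): what the CheckAggregation rule
# added to Analyzer.scala above enforces, expressed as SQL run through an
# assumed SQLContext with a registered table t(key, value). A selected
# attribute must either be aggregated or appear in GROUP BY.
ok = sqlCtx.sql("SELECT key, SUM(value) FROM t GROUP BY key")
# The next query would now fail analysis with "Expression not in GROUP BY":
# bad = sqlCtx.sql("SELECT key, value FROM t GROUP BY key")
# --- end of sketch ---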
p.makeCopy(Array(p.left, Cast(p.right, StringType))) case p: BinaryPredicate if p.left.dataType == TimestampType && p.right.dataType == StringType => - p.makeCopy(Array(p.left, Cast(p.right, TimestampType))) + p.makeCopy(Array(Cast(p.left, StringType), p.right)) + case p: BinaryPredicate if p.left.dataType == TimestampType + && p.right.dataType == DateType => + p.makeCopy(Array(Cast(p.left, StringType), Cast(p.right, StringType))) + case p: BinaryPredicate if p.left.dataType == DateType + && p.right.dataType == TimestampType => + p.makeCopy(Array(Cast(p.left, StringType), Cast(p.right, StringType))) case p: BinaryPredicate if p.left.dataType == StringType && p.right.dataType != StringType => p.makeCopy(Array(Cast(p.left, DoubleType), p.right)) case p: BinaryPredicate if p.left.dataType != StringType && p.right.dataType == StringType => p.makeCopy(Array(p.left, Cast(p.right, DoubleType))) - case i @ In(a,b) if a.dataType == TimestampType && b.forall(_.dataType == StringType) => - i.makeCopy(Array(a,b.map(Cast(_,TimestampType)))) + case i @ In(a, b) if a.dataType == DateType && b.forall(_.dataType == StringType) => + i.makeCopy(Array(Cast(a, StringType), b)) + case i @ In(a, b) if a.dataType == TimestampType && b.forall(_.dataType == StringType) => + i.makeCopy(Array(Cast(a, StringType), b)) + case i @ In(a, b) if a.dataType == DateType && b.forall(_.dataType == TimestampType) => + i.makeCopy(Array(Cast(a, StringType), b.map(Cast(_, StringType)))) + case i @ In(a, b) if a.dataType == TimestampType && b.forall(_.dataType == DateType) => + i.makeCopy(Array(Cast(a, StringType), b.map(Cast(_, StringType)))) case Sum(e) if e.dataType == StringType => Sum(Cast(e, DoubleType)) @@ -283,6 +302,8 @@ trait HiveTypeCoercion { // Skip if the type is boolean type already. Note that this extra cast should be removed // by optimizer.SimplifyCasts. case Cast(e, BooleanType) if e.dataType == BooleanType => e + // DateType should be null if be cast to boolean. + case Cast(e, BooleanType) if e.dataType == DateType => Cast(e, BooleanType) // If the data type is not boolean and is being cast boolean, turn it into a comparison // with the numeric value, i.e. x != 0. This will coerce the type into numeric type. 
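# --- Editor's sketch (not part of the patch): the effect of the coercion rules
# added above, expressed as SQL against an assumed table events with a DATE
# column d. Date/timestamp vs. string comparisons are now rewritten as string
# comparisons rather than casting the string side to a timestamp.
rows = sqlCtx.sql("SELECT * FROM events WHERE d = '2014-10-01'").collect()
# --- end of sketch ---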
case Cast(e, BooleanType) if e.dataType != BooleanType => Not(EqualTo(e, Literal(0))) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala index 67570a6f73c36..77d84e1687e1b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala @@ -88,7 +88,7 @@ case class Star( mapFunction: Attribute => Expression = identity[Attribute]) extends Attribute with trees.LeafNode[Expression] { - override def name = throw new UnresolvedException(this, "exprId") + override def name = throw new UnresolvedException(this, "name") override def exprId = throw new UnresolvedException(this, "exprId") override def dataType = throw new UnresolvedException(this, "dataType") override def nullable = throw new UnresolvedException(this, "nullable") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index deb622c39faf5..75b6e37c2a1f9 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst -import java.sql.Timestamp +import java.sql.{Date, Timestamp} import scala.language.implicitConversions @@ -119,6 +119,7 @@ package object dsl { implicit def floatToLiteral(f: Float) = Literal(f) implicit def doubleToLiteral(d: Double) = Literal(d) implicit def stringToLiteral(s: String) = Literal(s) + implicit def dateToLiteral(d: Date) = Literal(d) implicit def decimalToLiteral(d: BigDecimal) = Literal(d) implicit def timestampToLiteral(t: Timestamp) = Literal(t) implicit def binaryToLiteral(a: Array[Byte]) = Literal(a) @@ -174,6 +175,9 @@ package object dsl { /** Creates a new AttributeReference of type string */ def string = AttributeReference(s, StringType, nullable = true)() + /** Creates a new AttributeReference of type date */ + def date = AttributeReference(s, DateType, nullable = true)() + /** Creates a new AttributeReference of type decimal */ def decimal = AttributeReference(s, DecimalType, nullable = true)() diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala index c3a08bbdb6bc7..2b4969b7cfec0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala @@ -17,19 +17,26 @@ package org.apache.spark.sql.catalyst.expressions +import org.apache.spark.sql.catalyst.analysis.Star + protected class AttributeEquals(val a: Attribute) { override def hashCode() = a.exprId.hashCode() - override def equals(other: Any) = other match { - case otherReference: AttributeEquals => a.exprId == otherReference.a.exprId - case otherAttribute => false + override def equals(other: Any) = (a, other.asInstanceOf[AttributeEquals].a) match { + case (a1: AttributeReference, a2: AttributeReference) => a1.exprId == a2.exprId + case (a1, a2) => a1 == a2 } } object AttributeSet { - /** Constructs a new [[AttributeSet]] given a sequence of [[Attribute Attributes]]. 
*/ - def apply(baseSet: Seq[Attribute]) = { - new AttributeSet(baseSet.map(new AttributeEquals(_)).toSet) - } + def apply(a: Attribute) = + new AttributeSet(Set(new AttributeEquals(a))) + + /** Constructs a new [[AttributeSet]] given a sequence of [[Expression Expressions]]. */ + def apply(baseSet: Seq[Expression]) = + new AttributeSet( + baseSet + .flatMap(_.references) + .map(new AttributeEquals(_)).toSet) } /** @@ -103,4 +110,6 @@ class AttributeSet private (val baseSet: Set[AttributeEquals]) // We must force toSeq to not be strict otherwise we end up with a [[Stream]] that captures all // sorts of things in its closure. override def toSeq: Seq[Attribute] = baseSet.map(_.a).toArray.toSeq + + override def toString = "{" + baseSet.map(_.a).mkString(", ") + "}" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index f626d09f037bc..8e5ee12e314bf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -17,18 +17,21 @@ package org.apache.spark.sql.catalyst.expressions -import java.sql.Timestamp +import java.sql.{Date, Timestamp} import java.text.{DateFormat, SimpleDateFormat} +import org.apache.spark.Logging +import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.catalyst.types._ /** Cast the child expression to the target data type. */ -case class Cast(child: Expression, dataType: DataType) extends UnaryExpression { +case class Cast(child: Expression, dataType: DataType) extends UnaryExpression with Logging { override def foldable = child.foldable override def nullable = (child.dataType, dataType) match { case (StringType, _: NumericType) => true case (StringType, TimestampType) => true + case (StringType, DateType) => true case _ => child.nullable } @@ -42,6 +45,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression { // UDFToString private[this] def castToString: Any => Any = child.dataType match { case BinaryType => buildCast[Array[Byte]](_, new String(_, "UTF-8")) + case DateType => buildCast[Date](_, dateToString) case TimestampType => buildCast[Timestamp](_, timestampToString) case _ => buildCast[Any](_, _.toString) } @@ -56,7 +60,10 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression { case StringType => buildCast[String](_, _.length() != 0) case TimestampType => - buildCast[Timestamp](_, b => b.getTime() != 0 || b.getNanos() != 0) + buildCast[Timestamp](_, t => t.getTime() != 0 || t.getNanos() != 0) + case DateType => + // Hive would return null when cast from date to boolean + buildCast[Date](_, d => null) case LongType => buildCast[Long](_, _ != 0) case IntegerType => @@ -95,6 +102,8 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression { buildCast[Short](_, s => new Timestamp(s)) case ByteType => buildCast[Byte](_, b => new Timestamp(b)) + case DateType => + buildCast[Date](_, d => new Timestamp(d.getTime)) // TimestampWritable.decimalToTimestamp case DecimalType => buildCast[BigDecimal](_, d => decimalToTimestamp(d)) @@ -130,7 +139,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression { // Converts Timestamp to string according to Hive TimestampWritable convention private[this] def timestampToString(ts: Timestamp): String = { val timestampString = ts.toString - val formatted = 
Cast.threadLocalDateFormat.get.format(ts) + val formatted = Cast.threadLocalTimestampFormat.get.format(ts) if (timestampString.length > 19 && timestampString.substring(19) != ".0") { formatted + timestampString.substring(19) @@ -139,6 +148,39 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression { } } + // Converts Timestamp to string according to Hive TimestampWritable convention + private[this] def timestampToDateString(ts: Timestamp): String = { + Cast.threadLocalDateFormat.get.format(ts) + } + + // DateConverter + private[this] def castToDate: Any => Any = child.dataType match { + case StringType => + buildCast[String](_, s => + try Date.valueOf(s) catch { case _: java.lang.IllegalArgumentException => null } + ) + case TimestampType => + // throw valid precision more than seconds, according to Hive. + // Timestamp.nanos is in 0 to 999,999,999, no more than a second. + buildCast[Timestamp](_, t => new Date(Math.floor(t.getTime / 1000.0).toLong * 1000)) + // Hive throws this exception as a Semantic Exception + // It is never possible to compare result when hive return with exception, so we can return null + // NULL is more reasonable here, since the query itself obeys the grammar. + case _ => _ => null + } + + // Date cannot be cast to long, according to hive + private[this] def dateToLong(d: Date) = null + + // Date cannot be cast to double, according to hive + private[this] def dateToDouble(d: Date) = null + + // Converts Date to string according to Hive DateWritable convention + private[this] def dateToString(d: Date): String = { + Cast.threadLocalDateFormat.get.format(d) + } + + // LongConverter private[this] def castToLong: Any => Any = child.dataType match { case StringType => buildCast[String](_, s => try s.toLong catch { @@ -146,6 +188,8 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression { }) case BooleanType => buildCast[Boolean](_, b => if (b) 1L else 0L) + case DateType => + buildCast[Date](_, d => dateToLong(d)) case TimestampType => buildCast[Timestamp](_, t => timestampToLong(t)) case DecimalType => @@ -154,6 +198,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression { b => x.numeric.asInstanceOf[Numeric[Any]].toLong(b) } + // IntConverter private[this] def castToInt: Any => Any = child.dataType match { case StringType => buildCast[String](_, s => try s.toInt catch { @@ -161,6 +206,8 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression { }) case BooleanType => buildCast[Boolean](_, b => if (b) 1 else 0) + case DateType => + buildCast[Date](_, d => dateToLong(d)) case TimestampType => buildCast[Timestamp](_, t => timestampToLong(t).toInt) case DecimalType => @@ -169,6 +216,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression { b => x.numeric.asInstanceOf[Numeric[Any]].toInt(b) } + // ShortConverter private[this] def castToShort: Any => Any = child.dataType match { case StringType => buildCast[String](_, s => try s.toShort catch { @@ -176,6 +224,8 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression { }) case BooleanType => buildCast[Boolean](_, b => if (b) 1.toShort else 0.toShort) + case DateType => + buildCast[Date](_, d => dateToLong(d)) case TimestampType => buildCast[Timestamp](_, t => timestampToLong(t).toShort) case DecimalType => @@ -184,6 +234,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression { b => x.numeric.asInstanceOf[Numeric[Any]].toInt(b).toShort } + 
// ByteConverter private[this] def castToByte: Any => Any = child.dataType match { case StringType => buildCast[String](_, s => try s.toByte catch { @@ -191,6 +242,8 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression { }) case BooleanType => buildCast[Boolean](_, b => if (b) 1.toByte else 0.toByte) + case DateType => + buildCast[Date](_, d => dateToLong(d)) case TimestampType => buildCast[Timestamp](_, t => timestampToLong(t).toByte) case DecimalType => @@ -199,6 +252,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression { b => x.numeric.asInstanceOf[Numeric[Any]].toInt(b).toByte } + // DecimalConverter private[this] def castToDecimal: Any => Any = child.dataType match { case StringType => buildCast[String](_, s => try BigDecimal(s.toDouble) catch { @@ -206,6 +260,8 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression { }) case BooleanType => buildCast[Boolean](_, b => if (b) BigDecimal(1) else BigDecimal(0)) + case DateType => + buildCast[Date](_, d => dateToDouble(d)) case TimestampType => // Note that we lose precision here. buildCast[Timestamp](_, t => BigDecimal(timestampToDouble(t))) @@ -213,6 +269,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression { b => BigDecimal(x.numeric.asInstanceOf[Numeric[Any]].toDouble(b)) } + // DoubleConverter private[this] def castToDouble: Any => Any = child.dataType match { case StringType => buildCast[String](_, s => try s.toDouble catch { @@ -220,6 +277,8 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression { }) case BooleanType => buildCast[Boolean](_, b => if (b) 1d else 0d) + case DateType => + buildCast[Date](_, d => dateToDouble(d)) case TimestampType => buildCast[Timestamp](_, t => timestampToDouble(t)) case DecimalType => @@ -228,6 +287,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression { b => x.numeric.asInstanceOf[Numeric[Any]].toDouble(b) } + // FloatConverter private[this] def castToFloat: Any => Any = child.dataType match { case StringType => buildCast[String](_, s => try s.toFloat catch { @@ -235,6 +295,8 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression { }) case BooleanType => buildCast[Boolean](_, b => if (b) 1f else 0f) + case DateType => + buildCast[Date](_, d => dateToDouble(d)) case TimestampType => buildCast[Timestamp](_, t => timestampToDouble(t).toFloat) case DecimalType => @@ -245,17 +307,18 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression { private[this] lazy val cast: Any => Any = dataType match { case dt if dt == child.dataType => identity[Any] - case StringType => castToString - case BinaryType => castToBinary - case DecimalType => castToDecimal + case StringType => castToString + case BinaryType => castToBinary + case DecimalType => castToDecimal + case DateType => castToDate case TimestampType => castToTimestamp - case BooleanType => castToBoolean - case ByteType => castToByte - case ShortType => castToShort - case IntegerType => castToInt - case FloatType => castToFloat - case LongType => castToLong - case DoubleType => castToDouble + case BooleanType => castToBoolean + case ByteType => castToByte + case ShortType => castToShort + case IntegerType => castToInt + case FloatType => castToFloat + case LongType => castToLong + case DoubleType => castToDouble } override def eval(input: Row): Any = { @@ -267,6 +330,13 @@ case class Cast(child: Expression, dataType: DataType) 
extends UnaryExpression { object Cast { // `SimpleDateFormat` is not thread-safe. private[sql] val threadLocalDateFormat = new ThreadLocal[DateFormat] { + override def initialValue() = { + new SimpleDateFormat("yyyy-MM-dd") + } + } + + // `SimpleDateFormat` is not thread-safe. + private[sql] val threadLocalTimestampFormat = new ThreadLocal[DateFormat] { override def initialValue() = { new SimpleDateFormat("yyyy-MM-dd HH:mm:ss") } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala index 204904ecf04db..e7e81a21fdf03 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala @@ -39,6 +39,8 @@ class InterpretedProjection(expressions: Seq[Expression]) extends Projection { } new GenericRow(outputArray) } + + override def toString = s"Row => [${exprArray.mkString(",")}]" } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala index 78a0c55e4bbe5..ba240233cae61 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.expressions -import java.sql.Timestamp +import java.sql.{Date, Timestamp} import org.apache.spark.sql.catalyst.types._ @@ -33,6 +33,7 @@ object Literal { case b: Boolean => Literal(b, BooleanType) case d: BigDecimal => Literal(d, DecimalType) case t: Timestamp => Literal(t, TimestampType) + case d: Date => Literal(d, DateType) case a: Array[Byte] => Literal(a, BinaryType) case null => Literal(null, NullType) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala index e5a958d599393..d023db44d8543 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala @@ -57,6 +57,8 @@ abstract class NamedExpression extends Expression { abstract class Attribute extends NamedExpression { self: Product => + override def references = AttributeSet(this) + def withNullability(newNullability: Boolean): Attribute def withQualifiers(newQualifiers: Seq[String]): Attribute def withName(newName: String): Attribute @@ -116,8 +118,6 @@ case class AttributeReference(name: String, dataType: DataType, nullable: Boolea (val exprId: ExprId = NamedExpression.newExprId, val qualifiers: Seq[String] = Nil) extends Attribute with trees.LeafNode[Expression] { - override def references = AttributeSet(this :: Nil) - override def equals(other: Any) = other match { case ar: AttributeReference => exprId == ar.exprId && dataType == ar.dataType case _ => false diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index af9e4d86e995a..dcbbb62c0aca4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -31,6 +31,25 
@@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy */ def outputSet: AttributeSet = AttributeSet(output) + /** + * All Attributes that appear in expressions from this operator. Note that this set does not + * include attributes that are implicitly referenced by being passed through to the output tuple. + */ + def references: AttributeSet = AttributeSet(expressions.flatMap(_.references)) + + /** + * The set of all attributes that are input to this operator by its children. + */ + def inputSet: AttributeSet = + AttributeSet(children.flatMap(_.asInstanceOf[QueryPlan[PlanType]].output)) + + /** + * Attributes that are referenced by expressions but not provided by this nodes children. + * Subclasses should override this method if they produce attributes internally as it is used by + * assertions designed to prevent the construction of invalid plans. + */ + def missingInput: AttributeSet = references -- inputSet + /** * Runs [[transform]] with `rule` on all expressions present in this query operator. * Users should not expect a specific directionality. If a specific directionality is needed, @@ -132,4 +151,8 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy /** Prints out the schema in the tree format */ def printSchema(): Unit = println(schemaString) + + protected def statePrefix = if (missingInput.nonEmpty && children.nonEmpty) "!" else "" + + override def simpleString = statePrefix + super.simpleString } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index 4f8ad8a7e0223..882e9c6110089 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -53,12 +53,6 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging { sizeInBytes = children.map(_.statistics).map(_.sizeInBytes).product) } - /** - * Returns the set of attributes that this node takes as - * input from its children. - */ - lazy val inputSet: AttributeSet = AttributeSet(children.flatMap(_.output)) - /** * Returns true if this expression and all its children have been resolved to a specific schema * and false if it still contains any unresolved placeholders. Implementations of LogicalPlan @@ -68,6 +62,8 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging { */ lazy val resolved: Boolean = !expressions.exists(!_.resolved) && childrenResolved + override protected def statePrefix = if (!resolved) "'" else super.statePrefix + /** * Returns true if all its children of this query plan have been resolved. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index f8e9930ac270d..14b03c7445c13 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -138,11 +138,6 @@ case class Aggregate( child: LogicalPlan) extends UnaryNode { - /** The set of all AttributeReferences required for this aggregation. 
*/ - def references = - AttributeSet( - groupingExpressions.flatMap(_.references) ++ aggregateExpressions.flatMap(_.references)) - override def output = aggregateExpressions.map(_.toAttribute) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala index 1d375b8754182..0cf139ebde417 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.types -import java.sql.Timestamp +import java.sql.{Date, Timestamp} import scala.math.Numeric.{BigDecimalAsIfIntegral, DoubleAsIfIntegral, FloatAsIfIntegral} import scala.reflect.ClassTag @@ -250,6 +250,16 @@ case object TimestampType extends NativeType { } } +case object DateType extends NativeType { + private[sql] type JvmType = Date + + @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] } + + private[sql] val ordering = new Ordering[JvmType] { + def compare(x: Date, y: Date) = x.compareTo(y) + } +} + abstract class NumericType extends NativeType with PrimitiveType { // Unfortunately we can't get this implicitly as that breaks Spark Serialization. In order for // implicitly[Numeric[JvmType]] to be valid, we have to change JvmType from a type variable to a @@ -349,7 +359,6 @@ case object FloatType extends FractionalType { object ArrayType { /** Construct a [[ArrayType]] object with the given element type. The `containsNull` is true. */ def apply(elementType: DataType): ArrayType = ArrayType(elementType, true) - def typeName: String = "array" } /** @@ -395,8 +404,6 @@ case class StructField(name: String, dataType: DataType, nullable: Boolean) { object StructType { protected[sql] def fromAttributes(attributes: Seq[Attribute]): StructType = StructType(attributes.map(a => StructField(a.name, a.dataType, a.nullable))) - - def typeName = "struct" } case class StructType(fields: Seq[StructField]) extends DataType { @@ -459,8 +466,6 @@ object MapType { */ def apply(keyType: DataType, valueType: DataType): MapType = MapType(keyType: DataType, valueType: DataType, true) - - def simpleName = "map" } /** diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala index 692ed78a7292c..6dc5942023f9e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.expressions -import java.sql.Timestamp +import java.sql.{Date, Timestamp} import scala.collection.immutable.HashSet @@ -252,8 +252,11 @@ class ExpressionEvaluationSuite extends FunSuite { test("data type casting") { - val sts = "1970-01-01 00:00:01.1" - val ts = Timestamp.valueOf(sts) + val sd = "1970-01-01" + val d = Date.valueOf(sd) + val sts = sd + " 00:00:02" + val nts = sts + ".1" + val ts = Timestamp.valueOf(nts) checkEvaluation("abdef" cast StringType, "abdef") checkEvaluation("abdef" cast DecimalType, null) @@ -266,8 +269,15 @@ class ExpressionEvaluationSuite extends FunSuite { checkEvaluation(Cast(Literal(1.toDouble) cast TimestampType, DoubleType), 1.toDouble) checkEvaluation(Cast(Literal(1.toDouble) 
cast TimestampType, DoubleType), 1.toDouble) - checkEvaluation(Cast(Literal(sts) cast TimestampType, StringType), sts) + checkEvaluation(Cast(Literal(sd) cast DateType, StringType), sd) + checkEvaluation(Cast(Literal(d) cast StringType, DateType), d) + checkEvaluation(Cast(Literal(nts) cast TimestampType, StringType), nts) checkEvaluation(Cast(Literal(ts) cast StringType, TimestampType), ts) + // all convert to string type to check + checkEvaluation( + Cast(Cast(Literal(nts) cast TimestampType, DateType), StringType), sd) + checkEvaluation( + Cast(Cast(Literal(ts) cast DateType, TimestampType), StringType), sts) checkEvaluation(Cast("abdef" cast BinaryType, StringType), "abdef") @@ -316,6 +326,12 @@ class ExpressionEvaluationSuite extends FunSuite { checkEvaluation(Cast(Literal(null, IntegerType), ShortType), null) } + test("date") { + val d1 = Date.valueOf("1970-01-01") + val d2 = Date.valueOf("1970-01-02") + checkEvaluation(Literal(d1) < Literal(d2), true) + } + test("timestamp") { val ts1 = new Timestamp(12) val ts2 = new Timestamp(123) @@ -323,6 +339,17 @@ class ExpressionEvaluationSuite extends FunSuite { checkEvaluation(Literal(ts1) < Literal(ts2), true) } + test("date casting") { + val d = Date.valueOf("1970-01-01") + checkEvaluation(Cast(d, ShortType), null) + checkEvaluation(Cast(d, IntegerType), null) + checkEvaluation(Cast(d, LongType), null) + checkEvaluation(Cast(d, FloatType), null) + checkEvaluation(Cast(d, DoubleType), null) + checkEvaluation(Cast(d, StringType), "1970-01-01") + checkEvaluation(Cast(Cast(d, TimestampType), StringType), "1970-01-01 00:00:00") + } + test("timestamp casting") { val millis = 15 * 1000 + 2 val seconds = millis * 1000 + 2 diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/DataType.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/DataType.java index 37b4c8ffcba0b..37e88d72b9172 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/DataType.java +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/DataType.java @@ -44,6 +44,11 @@ public abstract class DataType { */ public static final BooleanType BooleanType = new BooleanType(); + /** + * Gets the DateType object. + */ + public static final DateType DateType = new DateType(); + /** * Gets the TimestampType object. */ diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/DateType.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/DateType.java new file mode 100644 index 0000000000000..6677793baa365 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/DateType.java @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.api.java; + +/** + * The data type representing java.sql.Date values. 
+ * + * {@code DateType} is represented by the singleton object {@link DataType#DateType}. + */ +public class DateType extends DataType { + protected DateType() {} +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnAccessor.scala index c9faf0852142a..538dd5b734664 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnAccessor.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnAccessor.scala @@ -92,6 +92,9 @@ private[sql] class FloatColumnAccessor(buffer: ByteBuffer) private[sql] class StringColumnAccessor(buffer: ByteBuffer) extends NativeColumnAccessor(buffer, STRING) +private[sql] class DateColumnAccessor(buffer: ByteBuffer) + extends NativeColumnAccessor(buffer, DATE) + private[sql] class TimestampColumnAccessor(buffer: ByteBuffer) extends NativeColumnAccessor(buffer, TIMESTAMP) @@ -118,6 +121,7 @@ private[sql] object ColumnAccessor { case BYTE.typeId => new ByteColumnAccessor(dup) case SHORT.typeId => new ShortColumnAccessor(dup) case STRING.typeId => new StringColumnAccessor(dup) + case DATE.typeId => new DateColumnAccessor(dup) case TIMESTAMP.typeId => new TimestampColumnAccessor(dup) case BINARY.typeId => new BinaryColumnAccessor(dup) case GENERIC.typeId => new GenericColumnAccessor(dup) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala index 2e61a981375aa..300cef15bf8a4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala @@ -107,6 +107,8 @@ private[sql] class FloatColumnBuilder extends NativeColumnBuilder(new FloatColum private[sql] class StringColumnBuilder extends NativeColumnBuilder(new StringColumnStats, STRING) +private[sql] class DateColumnBuilder extends NativeColumnBuilder(new DateColumnStats, DATE) + private[sql] class TimestampColumnBuilder extends NativeColumnBuilder(new TimestampColumnStats, TIMESTAMP) @@ -151,6 +153,7 @@ private[sql] object ColumnBuilder { case STRING.typeId => new StringColumnBuilder case BINARY.typeId => new BinaryColumnBuilder case GENERIC.typeId => new GenericColumnBuilder + case DATE.typeId => new DateColumnBuilder case TIMESTAMP.typeId => new TimestampColumnBuilder }).asInstanceOf[ColumnBuilder] diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala index 203a714e03c97..b34ab255d084a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.columnar -import java.sql.Timestamp +import java.sql.{Date, Timestamp} import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.expressions.{AttributeMap, Attribute, AttributeReference} @@ -190,6 +190,24 @@ private[sql] class StringColumnStats extends ColumnStats { def collectedStatistics = Row(lower, upper, nullCount) } +private[sql] class DateColumnStats extends ColumnStats { + var upper: Date = null + var lower: Date = null + var nullCount = 0 + + override def gatherStats(row: Row, ordinal: Int) { + if (!row.isNullAt(ordinal)) { + val value = row(ordinal).asInstanceOf[Date] + if (upper == null || value.compareTo(upper) > 0) upper = value + if (lower == null || 
value.compareTo(lower) < 0) lower = value + } else { + nullCount += 1 + } + } + + def collectedStatistics = Row(lower, upper, nullCount) +} + private[sql] class TimestampColumnStats extends ColumnStats { var upper: Timestamp = null var lower: Timestamp = null diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala index 198b5756676aa..ab66c85c4f242 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.columnar import java.nio.ByteBuffer -import java.sql.Timestamp +import java.sql.{Date, Timestamp} import scala.reflect.runtime.universe.TypeTag @@ -335,7 +335,26 @@ private[sql] object STRING extends NativeColumnType(StringType, 7, 8) { } } -private[sql] object TIMESTAMP extends NativeColumnType(TimestampType, 8, 12) { +private[sql] object DATE extends NativeColumnType(DateType, 8, 8) { + override def extract(buffer: ByteBuffer) = { + val date = new Date(buffer.getLong()) + date + } + + override def append(v: Date, buffer: ByteBuffer): Unit = { + buffer.putLong(v.getTime) + } + + override def getField(row: Row, ordinal: Int) = { + row(ordinal).asInstanceOf[Date] + } + + override def setField(row: MutableRow, ordinal: Int, value: Date): Unit = { + row(ordinal) = value + } +} + +private[sql] object TIMESTAMP extends NativeColumnType(TimestampType, 9, 12) { override def extract(buffer: ByteBuffer) = { val timestamp = new Timestamp(buffer.getLong()) timestamp.setNanos(buffer.getInt()) @@ -376,7 +395,7 @@ private[sql] sealed abstract class ByteArrayColumnType[T <: DataType]( } } -private[sql] object BINARY extends ByteArrayColumnType[BinaryType.type](9, 16) { +private[sql] object BINARY extends ByteArrayColumnType[BinaryType.type](10, 16) { override def setField(row: MutableRow, ordinal: Int, value: Array[Byte]): Unit = { row(ordinal) = value } @@ -387,7 +406,7 @@ private[sql] object BINARY extends ByteArrayColumnType[BinaryType.type](9, 16) { // Used to process generic objects (all types other than those listed above). Objects should be // serialized first before appending to the column `ByteBuffer`, and is also extracted as serialized // byte array. 
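// Illustrative sketch, not part of the diff above: a minimal, self-contained round-trip showing
// how the new DATE column type introduced in ColumnType.scala encodes a java.sql.Date — only the
// epoch milliseconds are stored as a Long (8 bytes, matching defaultSize = 8), mirroring
// DATE.append/DATE.extract. All names below are local to this sketch and assume nothing beyond
// java.sql.Date and java.nio.ByteBuffer.
object DateColumnEncodingSketch {
  import java.nio.ByteBuffer
  import java.sql.Date

  def append(d: Date, buffer: ByteBuffer): Unit = buffer.putLong(d.getTime)
  def extract(buffer: ByteBuffer): Date = new Date(buffer.getLong())

  def main(args: Array[String]): Unit = {
    val buf = ByteBuffer.allocate(8)
    append(Date.valueOf("1970-01-02"), buf)
    buf.flip()
    println(extract(buf)) // prints 1970-01-02: the date survives the Long round-trip
  }
}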
-private[sql] object GENERIC extends ByteArrayColumnType[DataType](10, 16) { +private[sql] object GENERIC extends ByteArrayColumnType[DataType](11, 16) { override def setField(row: MutableRow, ordinal: Int, value: Array[Byte]): Unit = { row(ordinal) = SparkSqlSerializer.deserialize[Any](value) } @@ -407,6 +426,7 @@ private[sql] object ColumnType { case ShortType => SHORT case StringType => STRING case BinaryType => BINARY + case DateType => DATE case TimestampType => TIMESTAMP case _ => GENERIC } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 4f1af7234d551..79e4ddb8c4f5d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -295,7 +295,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { execution.PhysicalRDD(Nil, singleRowRdd) :: Nil case logical.Repartition(expressions, child) => execution.Exchange(HashPartitioning(expressions, numPartitions), planLater(child)) :: Nil - case e @ EvaluatePython(udf, child) => + case e @ EvaluatePython(udf, child, _) => BatchPythonEvaluation(udf, e.output, planLater(child)) :: Nil case LogicalRDD(output, rdd) => PhysicalRDD(output, rdd) :: Nil case _ => Nil diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala index a9535a750bcd7..61be5ed2db65c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala @@ -24,6 +24,7 @@ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.SparkContext._ import org.apache.spark.sql.{SchemaRDD, Row} import org.apache.spark.sql.catalyst.trees.TreeNodeRef +import org.apache.spark.sql.catalyst.types._ /** * :: DeveloperApi :: @@ -56,6 +57,23 @@ package object debug { case _ => } } + + def typeCheck(): Unit = { + val plan = query.queryExecution.executedPlan + val visited = new collection.mutable.HashSet[TreeNodeRef]() + val debugPlan = plan transform { + case s: SparkPlan if !visited.contains(new TreeNodeRef(s)) => + visited += new TreeNodeRef(s) + TypeCheck(s) + } + try { + println(s"Results returned: ${debugPlan.execute().count()}") + } catch { + case e: Exception => + def unwrap(e: Throwable): Throwable = if (e.getCause == null) e else unwrap(e.getCause) + println(s"Deepest Error: ${unwrap(e)}") + } + } } private[sql] case class DebugNode(child: SparkPlan) extends UnaryNode { @@ -115,4 +133,71 @@ package object debug { } } } + + /** + * :: DeveloperApi :: + * Helper functions for checking that runtime types match a given schema. 
+ */ + @DeveloperApi + object TypeCheck { + def typeCheck(data: Any, schema: DataType): Unit = (data, schema) match { + case (null, _) => + + case (row: Row, StructType(fields)) => + row.zip(fields.map(_.dataType)).foreach { case(d,t) => typeCheck(d,t) } + case (s: Seq[_], ArrayType(elemType, _)) => + s.foreach(typeCheck(_, elemType)) + case (m: Map[_, _], MapType(keyType, valueType, _)) => + m.keys.foreach(typeCheck(_, keyType)) + m.values.foreach(typeCheck(_, valueType)) + + case (_: Long, LongType) => + case (_: Int, IntegerType) => + case (_: String, StringType) => + case (_: Float, FloatType) => + case (_: Byte, ByteType) => + case (_: Short, ShortType) => + case (_: Boolean, BooleanType) => + case (_: Double, DoubleType) => + + case (d, t) => sys.error(s"Invalid data found: got $d (${d.getClass}) expected $t") + } + } + + /** + * :: DeveloperApi :: + * Augments SchemaRDDs with debug methods. + */ + @DeveloperApi + private[sql] case class TypeCheck(child: SparkPlan) extends SparkPlan { + import TypeCheck._ + + override def nodeName = "" + + /* Only required when defining this class in a REPL. + override def makeCopy(args: Array[Object]): this.type = + TypeCheck(args(0).asInstanceOf[SparkPlan]).asInstanceOf[this.type] + */ + + def output = child.output + + def children = child :: Nil + + def execute() = { + child.execute().map { row => + try typeCheck(row, child.schema) catch { + case e: Exception => + sys.error( + s""" + |ERROR WHEN TYPE CHECKING QUERY + |============================== + |$e + |======== BAD TREE ============ + |$child + """.stripMargin) + } + row + } + } + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala index d88ab6367a1b3..8fd35880eedfe 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala @@ -22,7 +22,7 @@ import scala.concurrent.duration._ import scala.concurrent.ExecutionContext.Implicits.global import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.expressions.{Row, Expression} import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, UnspecifiedDistribution} import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} @@ -49,14 +49,16 @@ case class BroadcastHashJoin( @transient private val broadcastFuture = future { - sparkContext.broadcast(buildPlan.executeCollect()) + val input: Array[Row] = buildPlan.executeCollect() + val hashed = HashedRelation(input.iterator, buildSideKeyGenerator, input.length) + sparkContext.broadcast(hashed) } override def execute() = { val broadcastRelation = Await.result(broadcastFuture, 5.minute) streamedPlan.execute().mapPartitions { streamedIter => - joinIterators(broadcastRelation.value.iterator, streamedIter) + hashJoin(streamedIter, broadcastRelation.value) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala index 472b2e6ca6b4a..4012d757d5f9a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.execution.joins -import org.apache.spark.sql.catalyst.expressions.{Expression, 
JoinedRow2, Row} +import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.util.collection.CompactBuffer @@ -43,34 +43,14 @@ trait HashJoin { override def output = left.output ++ right.output - @transient protected lazy val buildSideKeyGenerator = newProjection(buildKeys, buildPlan.output) - @transient protected lazy val streamSideKeyGenerator = + @transient protected lazy val buildSideKeyGenerator: Projection = + newProjection(buildKeys, buildPlan.output) + + @transient protected lazy val streamSideKeyGenerator: () => MutableProjection = newMutableProjection(streamedKeys, streamedPlan.output) - protected def joinIterators(buildIter: Iterator[Row], streamIter: Iterator[Row]): Iterator[Row] = + protected def hashJoin(streamIter: Iterator[Row], hashedRelation: HashedRelation): Iterator[Row] = { - // TODO: Use Spark's HashMap implementation. - - val hashTable = new java.util.HashMap[Row, CompactBuffer[Row]]() - var currentRow: Row = null - - // Create a mapping of buildKeys -> rows - while (buildIter.hasNext) { - currentRow = buildIter.next() - val rowKey = buildSideKeyGenerator(currentRow) - if (!rowKey.anyNull) { - val existingMatchList = hashTable.get(rowKey) - val matchList = if (existingMatchList == null) { - val newMatchList = new CompactBuffer[Row]() - hashTable.put(rowKey, newMatchList) - newMatchList - } else { - existingMatchList - } - matchList += currentRow.copy() - } - } - new Iterator[Row] { private[this] var currentStreamedRow: Row = _ private[this] var currentHashMatches: CompactBuffer[Row] = _ @@ -107,7 +87,7 @@ trait HashJoin { while (currentHashMatches == null && streamIter.hasNext) { currentStreamedRow = streamIter.next() if (!joinKeys(currentStreamedRow).anyNull) { - currentHashMatches = hashTable.get(joinKeys.currentValue) + currentHashMatches = hashedRelation.get(joinKeys.currentValue) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala new file mode 100644 index 0000000000000..38b8993b03f82 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala @@ -0,0 +1,109 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.joins + +import java.util.{HashMap => JavaHashMap} + +import org.apache.spark.sql.catalyst.expressions.{Projection, Row} +import org.apache.spark.util.collection.CompactBuffer + + +/** + * Interface for a hashed relation by some key. Use [[HashedRelation.apply]] to create a concrete + * object. 
+ */ +private[joins] sealed trait HashedRelation { + def get(key: Row): CompactBuffer[Row] +} + + +/** + * A general [[HashedRelation]] backed by a hash map that maps the key into a sequence of values. + */ +private[joins] final class GeneralHashedRelation(hashTable: JavaHashMap[Row, CompactBuffer[Row]]) + extends HashedRelation with Serializable { + + override def get(key: Row) = hashTable.get(key) +} + + +/** + * A specialized [[HashedRelation]] that maps key into a single value. This implementation + * assumes the key is unique. + */ +private[joins] final class UniqueKeyHashedRelation(hashTable: JavaHashMap[Row, Row]) + extends HashedRelation with Serializable { + + override def get(key: Row) = { + val v = hashTable.get(key) + if (v eq null) null else CompactBuffer(v) + } + + def getValue(key: Row): Row = hashTable.get(key) +} + + +// TODO(rxin): a version of [[HashedRelation]] backed by arrays for consecutive integer keys. + + +private[joins] object HashedRelation { + + def apply( + input: Iterator[Row], + keyGenerator: Projection, + sizeEstimate: Int = 64): HashedRelation = { + + // TODO: Use Spark's HashMap implementation. + val hashTable = new JavaHashMap[Row, CompactBuffer[Row]](sizeEstimate) + var currentRow: Row = null + + // Whether the join key is unique. If the key is unique, we can convert the underlying + // hash map into one specialized for this. + var keyIsUnique = true + + // Create a mapping of buildKeys -> rows + while (input.hasNext) { + currentRow = input.next() + val rowKey = keyGenerator(currentRow) + if (!rowKey.anyNull) { + val existingMatchList = hashTable.get(rowKey) + val matchList = if (existingMatchList == null) { + val newMatchList = new CompactBuffer[Row]() + hashTable.put(rowKey, newMatchList) + newMatchList + } else { + keyIsUnique = false + existingMatchList + } + matchList += currentRow.copy() + } + } + + if (keyIsUnique) { + val uniqHashTable = new JavaHashMap[Row, Row](hashTable.size) + val iter = hashTable.entrySet().iterator() + while (iter.hasNext) { + val entry = iter.next() + uniqHashTable.put(entry.getKey, entry.getValue()(0)) + } + new UniqueKeyHashedRelation(uniqHashTable) + } else { + new GeneralHashedRelation(hashTable) + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala index 8247304c1dc2c..418c1c23e5546 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala @@ -42,8 +42,9 @@ case class ShuffledHashJoin( ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil override def execute() = { - buildPlan.execute().zipPartitions(streamedPlan.execute()) { - (buildIter, streamIter) => joinIterators(buildIter, streamIter) + buildPlan.execute().zipPartitions(streamedPlan.execute()) { (buildIter, streamIter) => + val hashed = HashedRelation(buildIter, buildSideKeyGenerator) + hashJoin(streamIter, hashed) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala index 0977da3e8577c..be729e5d244b0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala @@ -105,13 +105,21 @@ private[spark] object ExtractPythonUdfs extends Rule[LogicalPlan] { } } +object 
EvaluatePython { + def apply(udf: PythonUDF, child: LogicalPlan) = + new EvaluatePython(udf, child, AttributeReference("pythonUDF", udf.dataType)()) +} + /** * :: DeveloperApi :: * Evaluates a [[PythonUDF]], appending the result to the end of the input tuple. */ @DeveloperApi -case class EvaluatePython(udf: PythonUDF, child: LogicalPlan) extends logical.UnaryNode { - val resultAttribute = AttributeReference("pythonUDF", udf.dataType, nullable=true)() +case class EvaluatePython( + udf: PythonUDF, + child: LogicalPlan, + resultAttribute: AttributeReference) + extends logical.UnaryNode { def output = child.output :+ resultAttribute } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/package.scala index f513eae9c2d13..e98d151286818 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/package.scala @@ -165,6 +165,16 @@ package object sql { @DeveloperApi val TimestampType = catalyst.types.TimestampType + /** + * :: DeveloperApi :: + * + * The data type representing `java.sql.Date` values. + * + * @group dataType + */ + @DeveloperApi + val DateType = catalyst.types.DateType + /** * :: DeveloperApi :: * diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala index ffb732347d30a..5c6fa78ae3895 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala @@ -289,9 +289,9 @@ case class InsertIntoParquetTable( def writeShard(context: TaskContext, iter: Iterator[Row]): Int = { // Hadoop wants a 32-bit task attempt ID, so if ours is bigger than Int.MaxValue, roll it // around by taking a mod. We expect that no task will be attempted 2 billion times. - val attemptNumber = (context.getAttemptId % Int.MaxValue).toInt + val attemptNumber = (context.attemptId % Int.MaxValue).toInt /* "reduce task" */ - val attemptId = newTaskAttemptID(jobtrackerID, stageId, isMap = false, context.getPartitionId, + val attemptId = newTaskAttemptID(jobtrackerID, stageId, isMap = false, context.partitionId, attemptNumber) val hadoopContext = newTaskAttemptContext(wrappedConf.value, attemptId) val format = new AppendingParquetOutputFormat(taskIdOffset) @@ -331,13 +331,21 @@ private[parquet] class AppendingParquetOutputFormat(offset: Int) // override to choose output filename so not overwrite existing ones override def getDefaultWorkFile(context: TaskAttemptContext, extension: String): Path = { - val taskId: TaskID = context.getTaskAttemptID.getTaskID + val taskId: TaskID = getTaskAttemptID(context).getTaskID val partition: Int = taskId.getId val filename = s"part-r-${partition + offset}.parquet" val committer: FileOutputCommitter = getOutputCommitter(context).asInstanceOf[FileOutputCommitter] new Path(committer.getWorkPath, filename) } + + // The TaskAttemptContext is a class in hadoop-1 but is an interface in hadoop-2. + // The signatures of the method TaskAttemptContext.getTaskAttemptID for the both versions + // are the same, so the method calls are source-compatible but NOT binary-compatible because + // the opcode of method call for class is INVOKEVIRTUAL and for interface is INVOKEINTERFACE. 
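// Illustrative sketch, not part of the diff: the same reflective-call pattern applied to a plain
// JDK type. Because the method is looked up at runtime, the compiled call site contains no
// INVOKEVIRTUAL or INVOKEINTERFACE opcode against the target type, so identical bytecode runs
// whether that type is a class (hadoop-1's TaskAttemptContext) or an interface (hadoop-2's) —
// the property the private getTaskAttemptID helper added just below relies on. Names in this
// sketch are illustrative only.
object ReflectiveCallSketch {
  def main(args: Array[String]): Unit = {
    val list: java.util.List[String] = new java.util.ArrayList[String]()
    list.add("spark")
    // Resolve and invoke size() reflectively; no direct opcode against List or ArrayList is emitted.
    val size = list.getClass.getMethod("size").invoke(list).asInstanceOf[java.lang.Integer]
    println(size) // 1
  }
}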
+ private def getTaskAttemptID(context: TaskAttemptContext): TaskAttemptID = { + context.getClass.getMethod("getTaskAttemptID").invoke(context).asInstanceOf[TaskAttemptID] + } } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/types/util/DataTypeConversions.scala b/sql/core/src/main/scala/org/apache/spark/sql/types/util/DataTypeConversions.scala index 77353f4eb0227..e44cb08309523 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/types/util/DataTypeConversions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/types/util/DataTypeConversions.scala @@ -41,6 +41,7 @@ protected[sql] object DataTypeConversions { case StringType => JDataType.StringType case BinaryType => JDataType.BinaryType case BooleanType => JDataType.BooleanType + case DateType => JDataType.DateType case TimestampType => JDataType.TimestampType case DecimalType => JDataType.DecimalType case DoubleType => JDataType.DoubleType @@ -80,6 +81,8 @@ protected[sql] object DataTypeConversions { BinaryType case booleanType: org.apache.spark.sql.api.java.BooleanType => BooleanType + case dateType: org.apache.spark.sql.api.java.DateType => + DateType case timestampType: org.apache.spark.sql.api.java.TimestampType => TimestampType case decimalType: org.apache.spark.sql.api.java.DecimalType => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index a94022c0cf6e3..15f6ba4f72bbd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.joins.BroadcastHashJoin import org.apache.spark.sql.test._ import org.scalatest.BeforeAndAfterAll @@ -694,4 +695,29 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { checkAnswer( sql("SELECT CASE WHEN key = 1 THEN 1 ELSE 2 END FROM testData WHERE key = 1 group by key"), 1) } + + test("throw errors for non-aggregate attributes with aggregation") { + def checkAggregation(query: String, isInvalidQuery: Boolean = true) { + val logicalPlan = sql(query).queryExecution.logical + + if (isInvalidQuery) { + val e = intercept[TreeNodeException[LogicalPlan]](sql(query).queryExecution.analyzed) + assert( + e.getMessage.startsWith("Expression not in GROUP BY"), + "Non-aggregate attribute(s) not detected\n" + logicalPlan) + } else { + // Should not throw + sql(query).queryExecution.analyzed + } + } + + checkAggregation("SELECT key, COUNT(*) FROM testData") + checkAggregation("SELECT COUNT(key), COUNT(*) FROM testData", false) + + checkAggregation("SELECT value, COUNT(*) FROM testData GROUP BY key") + checkAggregation("SELECT COUNT(value), SUM(key) FROM testData GROUP BY key", false) + + checkAggregation("SELECT key + 2, COUNT(*) FROM testData GROUP BY key + 1") + checkAggregation("SELECT key + 1 + 1, COUNT(*) FROM testData GROUP BY key + 1", false) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala index e24c521d24c7a..bfa9ea416266d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala @@ 
-17,7 +17,7 @@ package org.apache.spark.sql -import java.sql.Timestamp +import java.sql.{Date, Timestamp} import org.scalatest.FunSuite @@ -34,6 +34,7 @@ case class ReflectData( byteField: Byte, booleanField: Boolean, decimalField: BigDecimal, + date: Date, timestampField: Timestamp, seqInt: Seq[Int]) @@ -76,7 +77,7 @@ case class ComplexReflectData( class ScalaReflectionRelationSuite extends FunSuite { test("query case class RDD") { val data = ReflectData("a", 1, 1L, 1.toFloat, 1.toDouble, 1.toShort, 1.toByte, true, - BigDecimal(1), new Timestamp(12345), Seq(1,2,3)) + BigDecimal(1), new Date(12345), new Timestamp(12345), Seq(1,2,3)) val rdd = sparkContext.parallelize(data :: Nil) rdd.registerTempTable("reflectData") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala index 0cdbb3167ce36..6bdf741134e2f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala @@ -30,6 +30,7 @@ class ColumnStatsSuite extends FunSuite { testColumnStats(classOf[FloatColumnStats], FLOAT, Row(Float.MaxValue, Float.MinValue, 0)) testColumnStats(classOf[DoubleColumnStats], DOUBLE, Row(Double.MaxValue, Double.MinValue, 0)) testColumnStats(classOf[StringColumnStats], STRING, Row(null, null, 0)) + testColumnStats(classOf[DateColumnStats], DATE, Row(null, null, 0)) testColumnStats(classOf[TimestampColumnStats], TIMESTAMP, Row(null, null, 0)) def testColumnStats[T <: NativeType, U <: ColumnStats]( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala index 4fb1ecf1d532b..3f3f35d50188b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.columnar import java.nio.ByteBuffer -import java.sql.Timestamp +import java.sql.{Date, Timestamp} import org.scalatest.FunSuite @@ -33,8 +33,8 @@ class ColumnTypeSuite extends FunSuite with Logging { test("defaultSize") { val checks = Map( - INT -> 4, SHORT -> 2, LONG -> 8, BYTE -> 1, DOUBLE -> 8, FLOAT -> 4, - BOOLEAN -> 1, STRING -> 8, TIMESTAMP -> 12, BINARY -> 16, GENERIC -> 16) + INT -> 4, SHORT -> 2, LONG -> 8, BYTE -> 1, DOUBLE -> 8, FLOAT -> 4, BOOLEAN -> 1, + STRING -> 8, DATE -> 8, TIMESTAMP -> 12, BINARY -> 16, GENERIC -> 16) checks.foreach { case (columnType, expectedSize) => assertResult(expectedSize, s"Wrong defaultSize for $columnType") { @@ -64,6 +64,7 @@ class ColumnTypeSuite extends FunSuite with Logging { checkActualSize(FLOAT, Float.MaxValue, 4) checkActualSize(BOOLEAN, true, 1) checkActualSize(STRING, "hello", 4 + "hello".getBytes("utf-8").length) + checkActualSize(DATE, new Date(0L), 8) checkActualSize(TIMESTAMP, new Timestamp(0L), 12) val binary = Array.fill[Byte](4)(0: Byte) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala index 38b04dd959f70..a1f21219eaf2f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.columnar import scala.collection.immutable.HashSet import 
scala.util.Random -import java.sql.Timestamp +import java.sql.{Date, Timestamp} import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.expressions.GenericMutableRow @@ -50,6 +50,7 @@ object ColumnarTestUtils { case STRING => Random.nextString(Random.nextInt(32)) case BOOLEAN => Random.nextBoolean() case BINARY => randomBytes(Random.nextInt(32)) + case DATE => new Date(Random.nextLong()) case TIMESTAMP => val timestamp = new Timestamp(Random.nextLong()) timestamp.setNanos(Random.nextInt(999999999)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala index 6c9a9ab6c3418..21906e3fdcc6f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala @@ -41,7 +41,9 @@ object TestNullableColumnAccessor { class NullableColumnAccessorSuite extends FunSuite { import ColumnarTestUtils._ - Seq(INT, LONG, SHORT, BOOLEAN, BYTE, STRING, DOUBLE, FLOAT, BINARY, GENERIC, TIMESTAMP).foreach { + Seq( + INT, LONG, SHORT, BOOLEAN, BYTE, STRING, DOUBLE, FLOAT, BINARY, GENERIC, DATE, TIMESTAMP + ).foreach { testNullableColumnAccessor(_) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala index f54a21eb4fbb1..cb73f3da81e24 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala @@ -37,7 +37,9 @@ object TestNullableColumnBuilder { class NullableColumnBuilderSuite extends FunSuite { import ColumnarTestUtils._ - Seq(INT, LONG, SHORT, BOOLEAN, BYTE, STRING, DOUBLE, FLOAT, BINARY, GENERIC, TIMESTAMP).foreach { + Seq( + INT, LONG, SHORT, BOOLEAN, BYTE, STRING, DOUBLE, FLOAT, BINARY, GENERIC, DATE, TIMESTAMP + ).foreach { testNullableColumnBuilder(_) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala new file mode 100644 index 0000000000000..87c28c334d228 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.debug + +import org.scalatest.FunSuite + +import org.apache.spark.sql.TestData._ +import org.apache.spark.sql.test.TestSQLContext._ + +class DebuggingSuite extends FunSuite { + test("SchemaRDD.debug()") { + testData.debug() + } + + test("SchemaRDD.typeCheck()") { + testData.typeCheck() + } +} \ No newline at end of file diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala new file mode 100644 index 0000000000000..2aad01ded1acf --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.joins + +import org.scalatest.FunSuite + +import org.apache.spark.sql.catalyst.expressions.{Projection, Row} +import org.apache.spark.util.collection.CompactBuffer + + +class HashedRelationSuite extends FunSuite { + + // Key is simply the record itself + private val keyProjection = new Projection { + override def apply(row: Row): Row = row + } + + test("GeneralHashedRelation") { + val data = Array(Row(0), Row(1), Row(2), Row(2)) + val hashed = HashedRelation(data.iterator, keyProjection) + assert(hashed.isInstanceOf[GeneralHashedRelation]) + + assert(hashed.get(data(0)) == CompactBuffer[Row](data(0))) + assert(hashed.get(data(1)) == CompactBuffer[Row](data(1))) + assert(hashed.get(Row(10)) === null) + + val data2 = CompactBuffer[Row](data(2)) + data2 += data(2) + assert(hashed.get(data(2)) == data2) + } + + test("UniqueKeyHashedRelation") { + val data = Array(Row(0), Row(1), Row(2)) + val hashed = HashedRelation(data.iterator, keyProjection) + assert(hashed.isInstanceOf[UniqueKeyHashedRelation]) + + assert(hashed.get(data(0)) == CompactBuffer[Row](data(0))) + assert(hashed.get(data(1)) == CompactBuffer[Row](data(1))) + assert(hashed.get(data(2)) == CompactBuffer[Row](data(2))) + assert(hashed.get(Row(10)) === null) + + val uniqHashed = hashed.asInstanceOf[UniqueKeyHashedRelation] + assert(uniqHashed.getValue(data(0)) == data(0)) + assert(uniqHashed.getValue(data(1)) == data(1)) + assert(uniqHashed.getValue(data(2)) == data(2)) + assert(uniqHashed.getValue(Row(10)) == null) + } +} diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala index d68dd090b5e6c..8a72e9d2aef57 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala @@ -30,7 +30,7 @@ import 
java.util.concurrent.atomic.AtomicInteger import org.apache.hadoop.hive.conf.HiveConf.ConfVars import org.scalatest.{BeforeAndAfterAll, FunSuite} -import org.apache.spark.Logging +import org.apache.spark.{SparkException, Logging} import org.apache.spark.sql.catalyst.util.getTempFilePath class CliSuite extends FunSuite with BeforeAndAfterAll with Logging { @@ -62,8 +62,11 @@ class CliSuite extends FunSuite with BeforeAndAfterAll with Logging { def captureOutput(source: String)(line: String) { buffer += s"$source> $line" + // If we haven't found all expected answers... if (next.get() < expectedAnswers.size) { + // If another expected answer is found... if (line.startsWith(expectedAnswers(next.get()))) { + // If all expected answers have been found... if (next.incrementAndGet() == expectedAnswers.size) { foundAllExpectedAnswers.trySuccess(()) } @@ -75,11 +78,6 @@ class CliSuite extends FunSuite with BeforeAndAfterAll with Logging { val process = (Process(command) #< queryStream).run( ProcessLogger(captureOutput("stdout"), captureOutput("stderr"))) - Future { - val exitValue = process.exitValue() - logInfo(s"Spark SQL CLI process exit value: $exitValue") - } - try { Await.result(foundAllExpectedAnswers.future, timeout) } catch { case cause: Throwable => @@ -98,6 +96,7 @@ class CliSuite extends FunSuite with BeforeAndAfterAll with Logging { |End CliSuite failure output |=========================== """.stripMargin, cause) + throw cause } finally { warehousePath.delete() metastorePath.delete() @@ -109,7 +108,7 @@ class CliSuite extends FunSuite with BeforeAndAfterAll with Logging { val dataFilePath = Thread.currentThread().getContextClassLoader.getResource("data/files/small_kv.txt") - runCliWithin(1.minute)( + runCliWithin(3.minute)( "CREATE TABLE hive_test(key INT, val STRING);" -> "OK", "SHOW TABLES;" @@ -120,7 +119,7 @@ class CliSuite extends FunSuite with BeforeAndAfterAll with Logging { -> "Time taken: ", "SELECT COUNT(*) FROM hive_test;" -> "5", - "DROP TABLE hive_test" + "DROP TABLE hive_test;" -> "Time taken: " ) } diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suite.scala index 38977ff162097..e3b4e45a3d68c 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suite.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suite.scala @@ -17,17 +17,17 @@ package org.apache.spark.sql.hive.thriftserver -import scala.collection.mutable.ArrayBuffer -import scala.concurrent.ExecutionContext.Implicits.global -import scala.concurrent.duration._ -import scala.concurrent.{Await, Future, Promise} -import scala.sys.process.{Process, ProcessLogger} - import java.io.File import java.net.ServerSocket import java.sql.{DriverManager, Statement} import java.util.concurrent.TimeoutException +import scala.collection.mutable.ArrayBuffer +import scala.concurrent.duration._ +import scala.concurrent.{Await, Promise} +import scala.sys.process.{Process, ProcessLogger} +import scala.util.Try + import org.apache.hadoop.hive.conf.HiveConf.ConfVars import org.apache.hive.jdbc.HiveDriver import org.scalatest.FunSuite @@ -41,25 +41,25 @@ import org.apache.spark.sql.catalyst.util.getTempFilePath class HiveThriftServer2Suite extends FunSuite with Logging { Class.forName(classOf[HiveDriver].getCanonicalName) - private val listeningHost = "localhost" - private 
val listeningPort = { - // Let the system to choose a random available port to avoid collision with other parallel - // builds. - val socket = new ServerSocket(0) - val port = socket.getLocalPort - socket.close() - port - } - - private val warehousePath = getTempFilePath("warehouse") - private val metastorePath = getTempFilePath("metastore") - private val metastoreJdbcUri = s"jdbc:derby:;databaseName=$metastorePath;create=true" - - def startThriftServerWithin(timeout: FiniteDuration = 30.seconds)(f: Statement => Unit) { - val serverScript = "../../sbin/start-thriftserver.sh".split("/").mkString(File.separator) + def startThriftServerWithin(timeout: FiniteDuration = 1.minute)(f: Statement => Unit) { + val startScript = "../../sbin/start-thriftserver.sh".split("/").mkString(File.separator) + val stopScript = "../../sbin/stop-thriftserver.sh".split("/").mkString(File.separator) + + val warehousePath = getTempFilePath("warehouse") + val metastorePath = getTempFilePath("metastore") + val metastoreJdbcUri = s"jdbc:derby:;databaseName=$metastorePath;create=true" + val listeningHost = "localhost" + val listeningPort = { + // Let the system to choose a random available port to avoid collision with other parallel + // builds. + val socket = new ServerSocket(0) + val port = socket.getLocalPort + socket.close() + port + } val command = - s"""$serverScript + s"""$startScript | --master local | --hiveconf hive.root.logger=INFO,console | --hiveconf ${ConfVars.METASTORECONNECTURLKEY}=$metastoreJdbcUri @@ -68,29 +68,40 @@ class HiveThriftServer2Suite extends FunSuite with Logging { | --hiveconf ${ConfVars.HIVE_SERVER2_THRIFT_PORT}=$listeningPort """.stripMargin.split("\\s+").toSeq - val serverStarted = Promise[Unit]() + val serverRunning = Promise[Unit]() val buffer = new ArrayBuffer[String]() + val LOGGING_MARK = + s"starting ${HiveThriftServer2.getClass.getCanonicalName.stripSuffix("$")}, logging to " + var logTailingProcess: Process = null + var logFilePath: String = null - def captureOutput(source: String)(line: String) { - buffer += s"$source> $line" + def captureLogOutput(line: String): Unit = { + buffer += line if (line.contains("ThriftBinaryCLIService listening on")) { - serverStarted.success(()) + serverRunning.success(()) } } - val process = Process(command).run( - ProcessLogger(captureOutput("stdout"), captureOutput("stderr"))) - - Future { - val exitValue = process.exitValue() - logInfo(s"Spark SQL Thrift server process exit value: $exitValue") + def captureThriftServerOutput(source: String)(line: String): Unit = { + if (line.startsWith(LOGGING_MARK)) { + logFilePath = line.drop(LOGGING_MARK.length).trim + // Ensure that the log file is created so that the `tail' command won't fail + Try(new File(logFilePath).createNewFile()) + logTailingProcess = Process(s"/usr/bin/env tail -f $logFilePath") + .run(ProcessLogger(captureLogOutput, _ => ())) + } } + // Resets SPARK_TESTING to avoid loading Log4J configurations in testing class paths + Process(command, None, "SPARK_TESTING" -> "0").run(ProcessLogger( + captureThriftServerOutput("stdout"), + captureThriftServerOutput("stderr"))) + val jdbcUri = s"jdbc:hive2://$listeningHost:$listeningPort/" val user = System.getProperty("user.name") try { - Await.result(serverStarted.future, timeout) + Await.result(serverRunning.future, timeout) val connection = DriverManager.getConnection(jdbcUri, user, "") val statement = connection.createStatement() @@ -122,10 +133,15 @@ class HiveThriftServer2Suite extends FunSuite with Logging { |End HiveThriftServer2Suite 
failure output |========================================= """.stripMargin, cause) + throw cause } finally { warehousePath.delete() metastorePath.delete() - process.destroy() + Process(stopScript).run().exitValue() + // The `spark-daemon.sh' script uses kill, which is not synchronous, have to wait for a while. + Thread.sleep(3.seconds.toMillis) + Option(logTailingProcess).map(_.destroy()) + Option(logFilePath).map(new File(_).delete()) } } diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala index 35e9c9939d4b7..463888551a359 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala @@ -343,6 +343,13 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "ct_case_insensitive", "database_location", "database_properties", + "date_2", + "date_3", + "date_4", + "date_comparison", + "date_join1", + "date_serde", + "date_udf", "decimal_1", "decimal_4", "decimal_join", @@ -604,8 +611,10 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "part_inherit_tbl_props", "part_inherit_tbl_props_empty", "part_inherit_tbl_props_with_star", + "partition_date", "partition_schema1", "partition_serde_format", + "partition_type_check", "partition_varchar1", "partition_wise_fileformat4", "partition_wise_fileformat5", @@ -904,6 +913,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "union7", "union8", "union9", + "union_date", "union_lateralview", "union_ppr", "union_remove_11", diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index fad3b39f81413..8b5a90159e1bb 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.hive import java.io.{BufferedReader, File, InputStreamReader, PrintStream} -import java.sql.Timestamp +import java.sql.{Date, Timestamp} import java.util.{ArrayList => JArrayList} import scala.collection.JavaConversions._ @@ -34,6 +34,7 @@ import org.apache.hadoop.hive.ql.processors._ import org.apache.hadoop.hive.ql.session.SessionState import org.apache.hadoop.hive.ql.stats.StatsSetupConst import org.apache.hadoop.hive.serde2.io.TimestampWritable +import org.apache.hadoop.hive.serde2.io.DateWritable import org.apache.spark.SparkContext import org.apache.spark.rdd.RDD @@ -357,7 +358,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { protected val primitiveTypes = Seq(StringType, IntegerType, LongType, DoubleType, FloatType, BooleanType, ByteType, - ShortType, DecimalType, TimestampType, BinaryType) + ShortType, DecimalType, DateType, TimestampType, BinaryType) protected[sql] def toHiveString(a: (Any, DataType)): String = a match { case (struct: Row, StructType(fields)) => @@ -372,6 +373,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { toHiveStructString((key, kType)) + ":" + toHiveStructString((value, vType)) }.toSeq.sorted.mkString("{", ",", "}") case (null, _) => "NULL" + case (d: Date, DateType) => new DateWritable(d).toString case (t: Timestamp, TimestampType) => new TimestampWritable(t).toString case (bin: 
Array[Byte], BinaryType) => new String(bin, "UTF-8") case (other, tpe) if primitiveTypes contains tpe => other.toString diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala index d633c42c6bd67..1977618b4c9f2 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala @@ -39,6 +39,7 @@ private[hive] trait HiveInspectors { case c: Class[_] if c == classOf[hiveIo.HiveDecimalWritable] => DecimalType case c: Class[_] if c == classOf[hiveIo.ByteWritable] => ByteType case c: Class[_] if c == classOf[hiveIo.ShortWritable] => ShortType + case c: Class[_] if c == classOf[hiveIo.DateWritable] => DateType case c: Class[_] if c == classOf[hiveIo.TimestampWritable] => TimestampType case c: Class[_] if c == classOf[hadoopIo.Text] => StringType case c: Class[_] if c == classOf[hadoopIo.IntWritable] => IntegerType @@ -49,6 +50,7 @@ private[hive] trait HiveInspectors { // java class case c: Class[_] if c == classOf[java.lang.String] => StringType + case c: Class[_] if c == classOf[java.sql.Date] => DateType case c: Class[_] if c == classOf[java.sql.Timestamp] => TimestampType case c: Class[_] if c == classOf[HiveDecimal] => DecimalType case c: Class[_] if c == classOf[java.math.BigDecimal] => DecimalType @@ -93,6 +95,7 @@ private[hive] trait HiveInspectors { System.arraycopy(b.getBytes(), 0, bytes, 0, b.getLength) bytes } + case d: hiveIo.DateWritable => d.get case t: hiveIo.TimestampWritable => t.getTimestamp case b: hiveIo.HiveDecimalWritable => BigDecimal(b.getHiveDecimal().bigDecimalValue()) case list: java.util.List[_] => list.map(unwrap) @@ -108,6 +111,7 @@ private[hive] trait HiveInspectors { case str: String => str case p: java.math.BigDecimal => p case p: Array[Byte] => p + case p: java.sql.Date => p case p: java.sql.Timestamp => p } @@ -147,6 +151,7 @@ private[hive] trait HiveInspectors { case l: Byte => l: java.lang.Byte case b: BigDecimal => new HiveDecimal(b.underlying()) case b: Array[Byte] => b + case d: java.sql.Date => d case t: java.sql.Timestamp => t case s: Seq[_] => seqAsJavaList(s.map(wrap)) case m: Map[_,_] => @@ -173,6 +178,7 @@ private[hive] trait HiveInspectors { case ByteType => PrimitiveObjectInspectorFactory.javaByteObjectInspector case NullType => PrimitiveObjectInspectorFactory.javaVoidObjectInspector case BinaryType => PrimitiveObjectInspectorFactory.javaByteArrayObjectInspector + case DateType => PrimitiveObjectInspectorFactory.javaDateObjectInspector case TimestampType => PrimitiveObjectInspectorFactory.javaTimestampObjectInspector case DecimalType => PrimitiveObjectInspectorFactory.javaHiveDecimalObjectInspector case StructType(fields) => @@ -211,6 +217,8 @@ private[hive] trait HiveInspectors { case _: JavaBinaryObjectInspector => BinaryType case _: WritableHiveDecimalObjectInspector => DecimalType case _: JavaHiveDecimalObjectInspector => DecimalType + case _: WritableDateObjectInspector => DateType + case _: JavaDateObjectInspector => DateType case _: WritableTimestampObjectInspector => TimestampType case _: JavaTimestampObjectInspector => TimestampType case _: WritableVoidObjectInspector => NullType @@ -238,6 +246,7 @@ private[hive] trait HiveInspectors { case ShortType => shortTypeInfo case StringType => stringTypeInfo case DecimalType => decimalTypeInfo + case DateType => dateTypeInfo case TimestampType => timestampTypeInfo case NullType => voidTypeInfo } diff --git 
a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index addd5bed8426d..75a19656af110 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -186,6 +186,7 @@ object HiveMetastoreTypes extends RegexParsers { "binary" ^^^ BinaryType | "boolean" ^^^ BooleanType | "decimal" ^^^ DecimalType | + "date" ^^^ DateType | "timestamp" ^^^ TimestampType | "varchar\\((\\d+)\\)".r ^^^ StringType @@ -235,6 +236,7 @@ object HiveMetastoreTypes extends RegexParsers { case LongType => "bigint" case BinaryType => "binary" case BooleanType => "boolean" + case DateType => "date" case DecimalType => "decimal" case TimestampType => "timestamp" case NullType => "void" @@ -303,7 +305,7 @@ private[hive] case class MetastoreRelation val partitionKeys = hiveQlTable.getPartitionKeys.map(_.toAttribute) /** Non-partitionKey attributes */ - val attributes = table.getSd.getCols.map(_.toAttribute) + val attributes = hiveQlTable.getCols.map(_.toAttribute) val output = attributes ++ partitionKeys } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index 7cc14dc7a9c9e..2b599157d15d3 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.hive +import java.sql.Date + import org.apache.hadoop.hive.ql.lib.Node import org.apache.hadoop.hive.ql.parse._ import org.apache.hadoop.hive.ql.plan.PlanUtils @@ -317,6 +319,7 @@ private[hive] object HiveQl { case Token("TOK_STRING", Nil) => StringType case Token("TOK_FLOAT", Nil) => FloatType case Token("TOK_DOUBLE", Nil) => DoubleType + case Token("TOK_DATE", Nil) => DateType case Token("TOK_TIMESTAMP", Nil) => TimestampType case Token("TOK_BINARY", Nil) => BinaryType case Token("TOK_LIST", elementType :: Nil) => ArrayType(nodeToDataType(elementType)) @@ -924,6 +927,8 @@ private[hive] object HiveQl { Cast(nodeToExpr(arg), DecimalType) case Token("TOK_FUNCTION", Token("TOK_TIMESTAMP", Nil) :: arg :: Nil) => Cast(nodeToExpr(arg), TimestampType) + case Token("TOK_FUNCTION", Token("TOK_DATE", Nil) :: arg :: Nil) => + Cast(nodeToExpr(arg), DateType) /* Arithmetic */ case Token("-", child :: Nil) => UnaryMinus(nodeToExpr(child)) @@ -1047,6 +1052,9 @@ private[hive] object HiveQl { case ast: ASTNode if ast.getType == HiveParser.StringLiteral => Literal(BaseSemanticAnalyzer.unescapeSQLString(ast.getText)) + case ast: ASTNode if ast.getType == HiveParser.TOK_DATELITERAL => + Literal(Date.valueOf(ast.getText.substring(1, ast.getText.length - 1))) + case a: ASTNode => throw new NotImplementedError( s"""No parse rules for ASTNode type: ${a.getType}, text: ${a.getText} : diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala index 84fafcde63d05..0de29d5cffd0e 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.hive import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{Path, PathFilter} +import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.metastore.api.hive_metastoreConstants._ 
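// ---------------------------------------------------------------------------
// [Editor's aside - illustrative sketch, not part of the patch] The HiveQl
// change shown just above maps Hive's TOK_DATELITERAL (a quoted token such as
// '2011-01-01') to a Catalyst literal by stripping the surrounding quotes and
// calling java.sql.Date.valueOf. A minimal standalone model of that conversion,
// assuming only the JDK; the object and method names here are hypothetical:
import java.sql.Date

object DateLiteralSketch {
  // The token text arrives with single quotes included, e.g. "'2011-01-01'".
  def parseDateLiteral(tokenText: String): Date =
    Date.valueOf(tokenText.substring(1, tokenText.length - 1))

  def main(args: Array[String]): Unit = {
    println(parseDateLiteral("'2011-01-01'")) // prints 2011-01-01
  }
}
// ---------------------------------------------------------------------------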
import org.apache.hadoop.hive.ql.exec.Utilities import org.apache.hadoop.hive.ql.metadata.{Partition => HivePartition, Table => HiveTable} @@ -52,7 +53,8 @@ private[hive] class HadoopTableReader( @transient attributes: Seq[Attribute], @transient relation: MetastoreRelation, - @transient sc: HiveContext) + @transient sc: HiveContext, + @transient hiveExtraConf: HiveConf) extends TableReader { // Choose the minimum number of splits. If mapred.map.tasks is set, then use that unless @@ -63,7 +65,7 @@ class HadoopTableReader( // TODO: set aws s3 credentials. private val _broadcastedHiveConf = - sc.sparkContext.broadcast(new SerializableWritable(sc.hiveconf)) + sc.sparkContext.broadcast(new SerializableWritable(hiveExtraConf)) def broadcastedHiveConf = _broadcastedHiveConf diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala index a4354c1379c63..9a9e2eda6bcd4 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala @@ -31,6 +31,7 @@ import org.apache.hadoop.hive.serde2.`lazy`.LazySimpleSerDe import org.apache.hadoop.hive.serde2.avro.AvroSerDe import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.util.Utils import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.plans.logical.{CacheTableCommand, LogicalPlan, NativeCommand} import org.apache.spark.sql.catalyst.util._ @@ -71,11 +72,14 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { setConf("javax.jdo.option.ConnectionURL", s"jdbc:derby:;databaseName=$metastorePath;create=true") setConf("hive.metastore.warehouse.dir", warehousePath) + Utils.registerShutdownDeleteDir(new File(warehousePath)) + Utils.registerShutdownDeleteDir(new File(metastorePath)) } val testTempDir = File.createTempFile("testTempFiles", "spark.hive.tmp") testTempDir.delete() testTempDir.mkdir() + Utils.registerShutdownDeleteDir(testTempDir) // For some hive test case which contain ${system:test.tmp.dir} System.setProperty("test.tmp.dir", testTempDir.getCanonicalPath) @@ -121,8 +125,7 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { val hiveFilesTemp = File.createTempFile("catalystHiveFiles", "") hiveFilesTemp.delete() hiveFilesTemp.mkdir() - hiveFilesTemp.deleteOnExit() - + Utils.registerShutdownDeleteDir(hiveFilesTemp) val inRepoTests = if (System.getProperty("user.dir").endsWith("sql" + File.separator + "hive")) { new File("src" + File.separator + "test" + File.separator + "resources" + File.separator) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala index 577ca928b43b6..5b83b77d80a22 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala @@ -64,8 +64,14 @@ case class HiveTableScan( BindReferences.bindReference(pred, relation.partitionKeys) } + // Create a local copy of hiveconf,so that scan specific modifications should not impact + // other queries @transient - private[this] val hadoopReader = new HadoopTableReader(attributes, relation, context) + private[this] val hiveExtraConf = new HiveConf(context.hiveconf) + + @transient + private[this] val hadoopReader = + new HadoopTableReader(attributes, relation, context, hiveExtraConf) private[this] def 
castFromString(value: String, dataType: DataType) = { Cast(Literal(value), dataType).eval(null) @@ -80,10 +86,14 @@ case class HiveTableScan( ColumnProjectionUtils.appendReadColumnIDs(hiveConf, neededColumnIDs) ColumnProjectionUtils.appendReadColumnNames(hiveConf, attributes.map(_.name)) + val tableDesc = relation.tableDesc + val deserializer = tableDesc.getDeserializerClass.newInstance + deserializer.initialize(hiveConf, tableDesc.getProperties) + // Specifies types and object inspectors of columns to be scanned. val structOI = ObjectInspectorUtils .getStandardObjectInspector( - relation.tableDesc.getDeserializer.getObjectInspector, + deserializer.getObjectInspector, ObjectInspectorCopyOption.JAVA) .asInstanceOf[StructObjectInspector] @@ -97,7 +107,7 @@ case class HiveTableScan( hiveConf.set(serdeConstants.LIST_COLUMNS, relation.attributes.map(_.name).mkString(",")) } - addColumnMetadataToConf(context.hiveconf) + addColumnMetadataToConf(hiveExtraConf) /** * Prunes partitions not involve the query plan. diff --git a/sql/hive/src/test/resources/golden/date_1-0-23edf29bf7376c70d5ecf12720f4b1eb b/sql/hive/src/test/resources/golden/date_1-0-23edf29bf7376c70d5ecf12720f4b1eb new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/date_1-1-4ebe3571c13a8b0c03096fbd972b7f1b b/sql/hive/src/test/resources/golden/date_1-1-4ebe3571c13a8b0c03096fbd972b7f1b new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/date_1-10-d964bec7e5632091ab5cb6f6786dbbf9 b/sql/hive/src/test/resources/golden/date_1-10-d964bec7e5632091ab5cb6f6786dbbf9 new file mode 100644 index 0000000000000..8fb5edae63c6f --- /dev/null +++ b/sql/hive/src/test/resources/golden/date_1-10-d964bec7e5632091ab5cb6f6786dbbf9 @@ -0,0 +1 @@ +2011-01-01 1 diff --git a/sql/hive/src/test/resources/golden/date_1-11-480c5f024a28232b7857be327c992509 b/sql/hive/src/test/resources/golden/date_1-11-480c5f024a28232b7857be327c992509 new file mode 100644 index 0000000000000..5a368ab170261 --- /dev/null +++ b/sql/hive/src/test/resources/golden/date_1-11-480c5f024a28232b7857be327c992509 @@ -0,0 +1 @@ +2012-01-01 2011-01-01 2011-01-01 00:00:00 2011-01-01 2011-01-01 diff --git a/sql/hive/src/test/resources/golden/date_1-12-4c0ed7fcb75770d8790575b586bf14f4 b/sql/hive/src/test/resources/golden/date_1-12-4c0ed7fcb75770d8790575b586bf14f4 new file mode 100644 index 0000000000000..edb4b1f84001b --- /dev/null +++ b/sql/hive/src/test/resources/golden/date_1-12-4c0ed7fcb75770d8790575b586bf14f4 @@ -0,0 +1 @@ +NULL NULL NULL NULL NULL NULL NULL diff --git a/sql/hive/src/test/resources/golden/date_1-13-44fc74c1993062c0a9522199ff27fea b/sql/hive/src/test/resources/golden/date_1-13-44fc74c1993062c0a9522199ff27fea new file mode 100644 index 0000000000000..2af0b9ed3a68c --- /dev/null +++ b/sql/hive/src/test/resources/golden/date_1-13-44fc74c1993062c0a9522199ff27fea @@ -0,0 +1 @@ +true true true true true true true true true true diff --git a/sql/hive/src/test/resources/golden/date_1-14-4855a66124b16d1d0d003235995ac06b b/sql/hive/src/test/resources/golden/date_1-14-4855a66124b16d1d0d003235995ac06b new file mode 100644 index 0000000000000..d8dfbf60007bd --- /dev/null +++ b/sql/hive/src/test/resources/golden/date_1-14-4855a66124b16d1d0d003235995ac06b @@ -0,0 +1 @@ +2001-01-28 2001-02-28 2001-03-28 2001-04-28 2001-05-28 2001-06-28 2001-07-28 2001-08-28 2001-09-28 2001-10-28 2001-11-28 2001-12-28 diff --git a/sql/hive/src/test/resources/golden/date_1-15-8bc190dba0f641840b5e1e198a14c55b 
b/sql/hive/src/test/resources/golden/date_1-15-8bc190dba0f641840b5e1e198a14c55b new file mode 100644 index 0000000000000..4f6a1bc4273e0 --- /dev/null +++ b/sql/hive/src/test/resources/golden/date_1-15-8bc190dba0f641840b5e1e198a14c55b @@ -0,0 +1 @@ +true true true true true true true true true true true true diff --git a/sql/hive/src/test/resources/golden/date_1-16-23edf29bf7376c70d5ecf12720f4b1eb b/sql/hive/src/test/resources/golden/date_1-16-23edf29bf7376c70d5ecf12720f4b1eb new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/date_1-2-abdce0c0d14d3fc7441b7c134b02f99a b/sql/hive/src/test/resources/golden/date_1-2-abdce0c0d14d3fc7441b7c134b02f99a new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/date_1-3-df16364a220ff96a6ea1cd478cbc1d0b b/sql/hive/src/test/resources/golden/date_1-3-df16364a220ff96a6ea1cd478cbc1d0b new file mode 100644 index 0000000000000..963bc42fdee07 --- /dev/null +++ b/sql/hive/src/test/resources/golden/date_1-3-df16364a220ff96a6ea1cd478cbc1d0b @@ -0,0 +1 @@ +2011-01-01 diff --git a/sql/hive/src/test/resources/golden/date_1-4-d964bec7e5632091ab5cb6f6786dbbf9 b/sql/hive/src/test/resources/golden/date_1-4-d964bec7e5632091ab5cb6f6786dbbf9 new file mode 100644 index 0000000000000..8fb5edae63c6f --- /dev/null +++ b/sql/hive/src/test/resources/golden/date_1-4-d964bec7e5632091ab5cb6f6786dbbf9 @@ -0,0 +1 @@ +2011-01-01 1 diff --git a/sql/hive/src/test/resources/golden/date_1-5-5e70fc74158fbfca38134174360de12d b/sql/hive/src/test/resources/golden/date_1-5-5e70fc74158fbfca38134174360de12d new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/date_1-6-df16364a220ff96a6ea1cd478cbc1d0b b/sql/hive/src/test/resources/golden/date_1-6-df16364a220ff96a6ea1cd478cbc1d0b new file mode 100644 index 0000000000000..963bc42fdee07 --- /dev/null +++ b/sql/hive/src/test/resources/golden/date_1-6-df16364a220ff96a6ea1cd478cbc1d0b @@ -0,0 +1 @@ +2011-01-01 diff --git a/sql/hive/src/test/resources/golden/date_1-7-d964bec7e5632091ab5cb6f6786dbbf9 b/sql/hive/src/test/resources/golden/date_1-7-d964bec7e5632091ab5cb6f6786dbbf9 new file mode 100644 index 0000000000000..8fb5edae63c6f --- /dev/null +++ b/sql/hive/src/test/resources/golden/date_1-7-d964bec7e5632091ab5cb6f6786dbbf9 @@ -0,0 +1 @@ +2011-01-01 1 diff --git a/sql/hive/src/test/resources/golden/date_1-8-1d5c58095cd52ea539d869f2ab1ab67d b/sql/hive/src/test/resources/golden/date_1-8-1d5c58095cd52ea539d869f2ab1ab67d new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/date_1-9-df16364a220ff96a6ea1cd478cbc1d0b b/sql/hive/src/test/resources/golden/date_1-9-df16364a220ff96a6ea1cd478cbc1d0b new file mode 100644 index 0000000000000..963bc42fdee07 --- /dev/null +++ b/sql/hive/src/test/resources/golden/date_1-9-df16364a220ff96a6ea1cd478cbc1d0b @@ -0,0 +1 @@ +2011-01-01 diff --git a/sql/hive/src/test/resources/golden/date_2-3-eedb73e0a622c2ab760b524f395dd4ba b/sql/hive/src/test/resources/golden/date_2-3-eedb73e0a622c2ab760b524f395dd4ba new file mode 100644 index 0000000000000..db973ab292d5b --- /dev/null +++ b/sql/hive/src/test/resources/golden/date_2-3-eedb73e0a622c2ab760b524f395dd4ba @@ -0,0 +1,137 @@ +2010-10-20 7291 +2010-10-20 3198 +2010-10-20 3014 +2010-10-20 2630 +2010-10-20 1610 +2010-10-20 1599 +2010-10-20 1531 +2010-10-20 1142 +2010-10-20 1064 +2010-10-20 897 +2010-10-20 361 +2010-10-21 7291 +2010-10-21 3198 +2010-10-21 3014 +2010-10-21 2646 
+2010-10-21 2630 +2010-10-21 1610 +2010-10-21 1599 +2010-10-21 1531 +2010-10-21 1142 +2010-10-21 1064 +2010-10-21 897 +2010-10-21 361 +2010-10-22 3198 +2010-10-22 3014 +2010-10-22 2646 +2010-10-22 2630 +2010-10-22 1610 +2010-10-22 1599 +2010-10-22 1531 +2010-10-22 1142 +2010-10-22 1064 +2010-10-22 897 +2010-10-22 361 +2010-10-23 7274 +2010-10-23 5917 +2010-10-23 5904 +2010-10-23 5832 +2010-10-23 3171 +2010-10-23 3085 +2010-10-23 2932 +2010-10-23 1805 +2010-10-23 650 +2010-10-23 426 +2010-10-23 384 +2010-10-23 272 +2010-10-24 7282 +2010-10-24 3198 +2010-10-24 3014 +2010-10-24 2646 +2010-10-24 2630 +2010-10-24 2571 +2010-10-24 2254 +2010-10-24 1610 +2010-10-24 1599 +2010-10-24 1531 +2010-10-24 897 +2010-10-24 361 +2010-10-25 7291 +2010-10-25 3198 +2010-10-25 3014 +2010-10-25 2646 +2010-10-25 2630 +2010-10-25 1610 +2010-10-25 1599 +2010-10-25 1531 +2010-10-25 1142 +2010-10-25 1064 +2010-10-25 897 +2010-10-25 361 +2010-10-26 7291 +2010-10-26 3198 +2010-10-26 3014 +2010-10-26 2662 +2010-10-26 2646 +2010-10-26 2630 +2010-10-26 1610 +2010-10-26 1599 +2010-10-26 1531 +2010-10-26 1142 +2010-10-26 1064 +2010-10-26 897 +2010-10-26 361 +2010-10-27 7291 +2010-10-27 3198 +2010-10-27 3014 +2010-10-27 2630 +2010-10-27 1610 +2010-10-27 1599 +2010-10-27 1531 +2010-10-27 1142 +2010-10-27 1064 +2010-10-27 897 +2010-10-27 361 +2010-10-28 7291 +2010-10-28 3198 +2010-10-28 3014 +2010-10-28 2646 +2010-10-28 2630 +2010-10-28 1610 +2010-10-28 1599 +2010-10-28 1531 +2010-10-28 1142 +2010-10-28 1064 +2010-10-28 897 +2010-10-28 361 +2010-10-29 7291 +2010-10-29 3198 +2010-10-29 3014 +2010-10-29 2646 +2010-10-29 2630 +2010-10-29 1610 +2010-10-29 1599 +2010-10-29 1531 +2010-10-29 1142 +2010-10-29 1064 +2010-10-29 897 +2010-10-29 361 +2010-10-30 5917 +2010-10-30 5904 +2010-10-30 3171 +2010-10-30 3085 +2010-10-30 2932 +2010-10-30 2018 +2010-10-30 1805 +2010-10-30 650 +2010-10-30 426 +2010-10-30 384 +2010-10-30 272 +2010-10-31 7282 +2010-10-31 3198 +2010-10-31 2571 +2010-10-31 1610 +2010-10-31 1599 +2010-10-31 1531 +2010-10-31 897 +2010-10-31 361 diff --git a/sql/hive/src/test/resources/golden/date_2-4-3618dfde8da7c26f03bca72970db9ef7 b/sql/hive/src/test/resources/golden/date_2-4-3618dfde8da7c26f03bca72970db9ef7 new file mode 100644 index 0000000000000..1b0ea7b9eec84 --- /dev/null +++ b/sql/hive/src/test/resources/golden/date_2-4-3618dfde8da7c26f03bca72970db9ef7 @@ -0,0 +1,137 @@ +2010-10-31 361 +2010-10-31 897 +2010-10-31 1531 +2010-10-31 1599 +2010-10-31 1610 +2010-10-31 2571 +2010-10-31 3198 +2010-10-31 7282 +2010-10-30 272 +2010-10-30 384 +2010-10-30 426 +2010-10-30 650 +2010-10-30 1805 +2010-10-30 2018 +2010-10-30 2932 +2010-10-30 3085 +2010-10-30 3171 +2010-10-30 5904 +2010-10-30 5917 +2010-10-29 361 +2010-10-29 897 +2010-10-29 1064 +2010-10-29 1142 +2010-10-29 1531 +2010-10-29 1599 +2010-10-29 1610 +2010-10-29 2630 +2010-10-29 2646 +2010-10-29 3014 +2010-10-29 3198 +2010-10-29 7291 +2010-10-28 361 +2010-10-28 897 +2010-10-28 1064 +2010-10-28 1142 +2010-10-28 1531 +2010-10-28 1599 +2010-10-28 1610 +2010-10-28 2630 +2010-10-28 2646 +2010-10-28 3014 +2010-10-28 3198 +2010-10-28 7291 +2010-10-27 361 +2010-10-27 897 +2010-10-27 1064 +2010-10-27 1142 +2010-10-27 1531 +2010-10-27 1599 +2010-10-27 1610 +2010-10-27 2630 +2010-10-27 3014 +2010-10-27 3198 +2010-10-27 7291 +2010-10-26 361 +2010-10-26 897 +2010-10-26 1064 +2010-10-26 1142 +2010-10-26 1531 +2010-10-26 1599 +2010-10-26 1610 +2010-10-26 2630 +2010-10-26 2646 +2010-10-26 2662 +2010-10-26 3014 +2010-10-26 3198 +2010-10-26 7291 +2010-10-25 361 +2010-10-25 897 
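// ---------------------------------------------------------------------------
// [Editor's aside - illustrative sketch, not part of the golden files] The
// date_2-5 golden output a little further below appears to hold per-date row
// counts that line up with the (date, flight) rows listed here (for example,
// 2010-10-20 occurs 11 times and date_2-5 records "2010-10-20 11"). One way to
// recompute such counts, assuming the rows have already been parsed into
// (java.sql.Date, Int) pairs; names here are hypothetical:
import java.sql.Date

object GoldenRowCountSketch {
  def countsPerDate(rows: Seq[(Date, Int)]): Seq[(Date, Int)] =
    rows.groupBy(_._1).mapValues(_.size).toSeq.sortBy(_._1.getTime)

  def main(args: Array[String]): Unit = {
    val rows = Seq(
      (Date.valueOf("2010-10-20"), 7291),
      (Date.valueOf("2010-10-20"), 3198),
      (Date.valueOf("2010-10-21"), 7291))
    countsPerDate(rows).foreach(println) // prints (2010-10-20,2) then (2010-10-21,1)
  }
}
// ---------------------------------------------------------------------------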
+2010-10-25 1064 +2010-10-25 1142 +2010-10-25 1531 +2010-10-25 1599 +2010-10-25 1610 +2010-10-25 2630 +2010-10-25 2646 +2010-10-25 3014 +2010-10-25 3198 +2010-10-25 7291 +2010-10-24 361 +2010-10-24 897 +2010-10-24 1531 +2010-10-24 1599 +2010-10-24 1610 +2010-10-24 2254 +2010-10-24 2571 +2010-10-24 2630 +2010-10-24 2646 +2010-10-24 3014 +2010-10-24 3198 +2010-10-24 7282 +2010-10-23 272 +2010-10-23 384 +2010-10-23 426 +2010-10-23 650 +2010-10-23 1805 +2010-10-23 2932 +2010-10-23 3085 +2010-10-23 3171 +2010-10-23 5832 +2010-10-23 5904 +2010-10-23 5917 +2010-10-23 7274 +2010-10-22 361 +2010-10-22 897 +2010-10-22 1064 +2010-10-22 1142 +2010-10-22 1531 +2010-10-22 1599 +2010-10-22 1610 +2010-10-22 2630 +2010-10-22 2646 +2010-10-22 3014 +2010-10-22 3198 +2010-10-21 361 +2010-10-21 897 +2010-10-21 1064 +2010-10-21 1142 +2010-10-21 1531 +2010-10-21 1599 +2010-10-21 1610 +2010-10-21 2630 +2010-10-21 2646 +2010-10-21 3014 +2010-10-21 3198 +2010-10-21 7291 +2010-10-20 361 +2010-10-20 897 +2010-10-20 1064 +2010-10-20 1142 +2010-10-20 1531 +2010-10-20 1599 +2010-10-20 1610 +2010-10-20 2630 +2010-10-20 3014 +2010-10-20 3198 +2010-10-20 7291 diff --git a/sql/hive/src/test/resources/golden/date_2-5-fe9bebfc8994ddd8d7cd0208c1f0af3c b/sql/hive/src/test/resources/golden/date_2-5-fe9bebfc8994ddd8d7cd0208c1f0af3c new file mode 100644 index 0000000000000..0f2a6f7a99237 --- /dev/null +++ b/sql/hive/src/test/resources/golden/date_2-5-fe9bebfc8994ddd8d7cd0208c1f0af3c @@ -0,0 +1,12 @@ +2010-10-20 11 +2010-10-21 12 +2010-10-22 11 +2010-10-23 12 +2010-10-24 12 +2010-10-25 12 +2010-10-26 13 +2010-10-27 11 +2010-10-28 12 +2010-10-29 12 +2010-10-30 11 +2010-10-31 8 diff --git a/sql/hive/src/test/resources/golden/date_2-6-f4edce7cb20f325e8b69e787b2ae8882 b/sql/hive/src/test/resources/golden/date_2-6-f4edce7cb20f325e8b69e787b2ae8882 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/date_3-3-4cf49e71b636df754871a675f9e4e24 b/sql/hive/src/test/resources/golden/date_3-3-4cf49e71b636df754871a675f9e4e24 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/date_3-4-e009f358964f6d1236cfc03283e2b06f b/sql/hive/src/test/resources/golden/date_3-4-e009f358964f6d1236cfc03283e2b06f new file mode 100644 index 0000000000000..66d2220d06de2 --- /dev/null +++ b/sql/hive/src/test/resources/golden/date_3-4-e009f358964f6d1236cfc03283e2b06f @@ -0,0 +1 @@ +1 2011-01-01 diff --git a/sql/hive/src/test/resources/golden/date_3-5-c26de4559926ddb0127d2dc5ea154774 b/sql/hive/src/test/resources/golden/date_3-5-c26de4559926ddb0127d2dc5ea154774 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/date_4-0-b84f7e931d710dcbe3c5126d998285a8 b/sql/hive/src/test/resources/golden/date_4-0-b84f7e931d710dcbe3c5126d998285a8 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/date_4-1-6272f5e518f6a20bc96a5870ff315c4f b/sql/hive/src/test/resources/golden/date_4-1-6272f5e518f6a20bc96a5870ff315c4f new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/date_4-2-4a0e7bde447ef616b98e0f55d2886de0 b/sql/hive/src/test/resources/golden/date_4-2-4a0e7bde447ef616b98e0f55d2886de0 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/date_4-3-a23faa56b5d3ca9063a21f72b4278b00 b/sql/hive/src/test/resources/golden/date_4-3-a23faa56b5d3ca9063a21f72b4278b00 new file mode 100644 
index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/date_4-4-bee09a7384666043621f68297cee2e68 b/sql/hive/src/test/resources/golden/date_4-4-bee09a7384666043621f68297cee2e68 new file mode 100644 index 0000000000000..b61affde4ffce --- /dev/null +++ b/sql/hive/src/test/resources/golden/date_4-4-bee09a7384666043621f68297cee2e68 @@ -0,0 +1 @@ +2011-01-01 2011-01-01 diff --git a/sql/hive/src/test/resources/golden/date_4-5-b84f7e931d710dcbe3c5126d998285a8 b/sql/hive/src/test/resources/golden/date_4-5-b84f7e931d710dcbe3c5126d998285a8 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/date_comparison-0-69eec445bd045c9dc899fafa348d8495 b/sql/hive/src/test/resources/golden/date_comparison-0-69eec445bd045c9dc899fafa348d8495 new file mode 100644 index 0000000000000..c508d5366f70b --- /dev/null +++ b/sql/hive/src/test/resources/golden/date_comparison-0-69eec445bd045c9dc899fafa348d8495 @@ -0,0 +1 @@ +false diff --git a/sql/hive/src/test/resources/golden/date_comparison-1-fcc400871a502009c8680509e3869ec1 b/sql/hive/src/test/resources/golden/date_comparison-1-fcc400871a502009c8680509e3869ec1 new file mode 100644 index 0000000000000..c508d5366f70b --- /dev/null +++ b/sql/hive/src/test/resources/golden/date_comparison-1-fcc400871a502009c8680509e3869ec1 @@ -0,0 +1 @@ +false diff --git a/sql/hive/src/test/resources/golden/date_comparison-10-a9f2560c273163e11306d4f1dd1d9d54 b/sql/hive/src/test/resources/golden/date_comparison-10-a9f2560c273163e11306d4f1dd1d9d54 new file mode 100644 index 0000000000000..c508d5366f70b --- /dev/null +++ b/sql/hive/src/test/resources/golden/date_comparison-10-a9f2560c273163e11306d4f1dd1d9d54 @@ -0,0 +1 @@ +false diff --git a/sql/hive/src/test/resources/golden/date_comparison-11-4a7bac9ddcf40db6329faaec8e426543 b/sql/hive/src/test/resources/golden/date_comparison-11-4a7bac9ddcf40db6329faaec8e426543 new file mode 100644 index 0000000000000..27ba77ddaf615 --- /dev/null +++ b/sql/hive/src/test/resources/golden/date_comparison-11-4a7bac9ddcf40db6329faaec8e426543 @@ -0,0 +1 @@ +true diff --git a/sql/hive/src/test/resources/golden/date_comparison-2-b8598a4d0c948c2ddcf3eeef0abf2264 b/sql/hive/src/test/resources/golden/date_comparison-2-b8598a4d0c948c2ddcf3eeef0abf2264 new file mode 100644 index 0000000000000..27ba77ddaf615 --- /dev/null +++ b/sql/hive/src/test/resources/golden/date_comparison-2-b8598a4d0c948c2ddcf3eeef0abf2264 @@ -0,0 +1 @@ +true diff --git a/sql/hive/src/test/resources/golden/date_comparison-3-14d35f266be9cceb11a2ae09ec8b3835 b/sql/hive/src/test/resources/golden/date_comparison-3-14d35f266be9cceb11a2ae09ec8b3835 new file mode 100644 index 0000000000000..c508d5366f70b --- /dev/null +++ b/sql/hive/src/test/resources/golden/date_comparison-3-14d35f266be9cceb11a2ae09ec8b3835 @@ -0,0 +1 @@ +false diff --git a/sql/hive/src/test/resources/golden/date_comparison-4-c8865b14d53f2c2496fb69ee8191bf37 b/sql/hive/src/test/resources/golden/date_comparison-4-c8865b14d53f2c2496fb69ee8191bf37 new file mode 100644 index 0000000000000..27ba77ddaf615 --- /dev/null +++ b/sql/hive/src/test/resources/golden/date_comparison-4-c8865b14d53f2c2496fb69ee8191bf37 @@ -0,0 +1 @@ +true diff --git a/sql/hive/src/test/resources/golden/date_comparison-5-f2c907e64da8166a731ddc0ed19bad6c b/sql/hive/src/test/resources/golden/date_comparison-5-f2c907e64da8166a731ddc0ed19bad6c new file mode 100644 index 0000000000000..27ba77ddaf615 --- /dev/null +++ 
b/sql/hive/src/test/resources/golden/date_comparison-5-f2c907e64da8166a731ddc0ed19bad6c @@ -0,0 +1 @@ +true diff --git a/sql/hive/src/test/resources/golden/date_comparison-6-5606505a92bad10023ad9a3ef77eacc9 b/sql/hive/src/test/resources/golden/date_comparison-6-5606505a92bad10023ad9a3ef77eacc9 new file mode 100644 index 0000000000000..c508d5366f70b --- /dev/null +++ b/sql/hive/src/test/resources/golden/date_comparison-6-5606505a92bad10023ad9a3ef77eacc9 @@ -0,0 +1 @@ +false diff --git a/sql/hive/src/test/resources/golden/date_comparison-7-47913d4aaf0d468ab3764cc3bfd68eb b/sql/hive/src/test/resources/golden/date_comparison-7-47913d4aaf0d468ab3764cc3bfd68eb new file mode 100644 index 0000000000000..27ba77ddaf615 --- /dev/null +++ b/sql/hive/src/test/resources/golden/date_comparison-7-47913d4aaf0d468ab3764cc3bfd68eb @@ -0,0 +1 @@ +true diff --git a/sql/hive/src/test/resources/golden/date_comparison-8-1e5ce4f833b6fba45618437c8fb7643c b/sql/hive/src/test/resources/golden/date_comparison-8-1e5ce4f833b6fba45618437c8fb7643c new file mode 100644 index 0000000000000..c508d5366f70b --- /dev/null +++ b/sql/hive/src/test/resources/golden/date_comparison-8-1e5ce4f833b6fba45618437c8fb7643c @@ -0,0 +1 @@ +false diff --git a/sql/hive/src/test/resources/golden/date_comparison-9-bcd987341fc1c38047a27d29dac6ae7c b/sql/hive/src/test/resources/golden/date_comparison-9-bcd987341fc1c38047a27d29dac6ae7c new file mode 100644 index 0000000000000..27ba77ddaf615 --- /dev/null +++ b/sql/hive/src/test/resources/golden/date_comparison-9-bcd987341fc1c38047a27d29dac6ae7c @@ -0,0 +1 @@ +true diff --git a/sql/hive/src/test/resources/golden/date_join1-3-f71c7be760fb4de4eff8225f2c6614b2 b/sql/hive/src/test/resources/golden/date_join1-3-f71c7be760fb4de4eff8225f2c6614b2 new file mode 100644 index 0000000000000..b7305b903edca --- /dev/null +++ b/sql/hive/src/test/resources/golden/date_join1-3-f71c7be760fb4de4eff8225f2c6614b2 @@ -0,0 +1,22 @@ +1064 2010-10-20 1064 2010-10-20 +1142 2010-10-21 1142 2010-10-21 +1599 2010-10-22 1599 2010-10-22 +361 2010-10-23 361 2010-10-23 +897 2010-10-24 897 2010-10-24 +1531 2010-10-25 1531 2010-10-25 +1610 2010-10-26 1610 2010-10-26 +3198 2010-10-27 3198 2010-10-27 +1064 2010-10-28 1064 2010-10-28 +1142 2010-10-29 1142 2010-10-29 +1064 2000-11-20 1064 2000-11-20 +1142 2000-11-21 1142 2000-11-21 +1599 2000-11-22 1599 2000-11-22 +361 2000-11-23 361 2000-11-23 +897 2000-11-24 897 2000-11-24 +1531 2000-11-25 1531 2000-11-25 +1610 2000-11-26 1610 2000-11-26 +3198 2000-11-27 3198 2000-11-27 +1064 2000-11-28 1064 2000-11-28 +1142 2000-11-28 1064 2000-11-28 +1064 2000-11-28 1142 2000-11-28 +1142 2000-11-28 1142 2000-11-28 diff --git a/sql/hive/src/test/resources/golden/date_join1-4-70b9b49c55699fe94cfde069f5d197c b/sql/hive/src/test/resources/golden/date_join1-4-70b9b49c55699fe94cfde069f5d197c new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/date_serde-10-d80e681519dcd8f5078c5602bb5befa9 b/sql/hive/src/test/resources/golden/date_serde-10-d80e681519dcd8f5078c5602bb5befa9 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/date_serde-11-29540200936bba47f17553547b409af7 b/sql/hive/src/test/resources/golden/date_serde-11-29540200936bba47f17553547b409af7 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/date_serde-12-c3c3275658b89d31fc504db31ae9f99c b/sql/hive/src/test/resources/golden/date_serde-12-c3c3275658b89d31fc504db31ae9f99c new file mode 100644 
index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/date_serde-13-6c546456c81e635b6753e1552fac9129 b/sql/hive/src/test/resources/golden/date_serde-13-6c546456c81e635b6753e1552fac9129 new file mode 100644 index 0000000000000..9f2238d57d6f5 --- /dev/null +++ b/sql/hive/src/test/resources/golden/date_serde-13-6c546456c81e635b6753e1552fac9129 @@ -0,0 +1 @@ +2010-10-20 1064 diff --git a/sql/hive/src/test/resources/golden/date_serde-14-f8ba18cc7b0225b4022299c44d435101 b/sql/hive/src/test/resources/golden/date_serde-14-f8ba18cc7b0225b4022299c44d435101 new file mode 100644 index 0000000000000..9f2238d57d6f5 --- /dev/null +++ b/sql/hive/src/test/resources/golden/date_serde-14-f8ba18cc7b0225b4022299c44d435101 @@ -0,0 +1 @@ +2010-10-20 1064 diff --git a/sql/hive/src/test/resources/golden/date_serde-15-66fadc9bcea7d107a610758aa6f50ff3 b/sql/hive/src/test/resources/golden/date_serde-15-66fadc9bcea7d107a610758aa6f50ff3 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/date_serde-16-1bd3345b46f77e17810978e56f9f7c6b b/sql/hive/src/test/resources/golden/date_serde-16-1bd3345b46f77e17810978e56f9f7c6b new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/date_serde-17-a0df43062f8ab676ef728c9968443f12 b/sql/hive/src/test/resources/golden/date_serde-17-a0df43062f8ab676ef728c9968443f12 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/date_serde-18-b50ecc72ce9018ab12fb17568fef038a b/sql/hive/src/test/resources/golden/date_serde-18-b50ecc72ce9018ab12fb17568fef038a new file mode 100644 index 0000000000000..9f2238d57d6f5 --- /dev/null +++ b/sql/hive/src/test/resources/golden/date_serde-18-b50ecc72ce9018ab12fb17568fef038a @@ -0,0 +1 @@ +2010-10-20 1064 diff --git a/sql/hive/src/test/resources/golden/date_serde-19-28f1cf92bdd6b2e5d328cd9d10f828b6 b/sql/hive/src/test/resources/golden/date_serde-19-28f1cf92bdd6b2e5d328cd9d10f828b6 new file mode 100644 index 0000000000000..9f2238d57d6f5 --- /dev/null +++ b/sql/hive/src/test/resources/golden/date_serde-19-28f1cf92bdd6b2e5d328cd9d10f828b6 @@ -0,0 +1 @@ +2010-10-20 1064 diff --git a/sql/hive/src/test/resources/golden/date_serde-20-588516368d8c1533cb7bfb2157fd58c1 b/sql/hive/src/test/resources/golden/date_serde-20-588516368d8c1533cb7bfb2157fd58c1 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/date_serde-21-dfe166fe053468e738dca23ebe043091 b/sql/hive/src/test/resources/golden/date_serde-21-dfe166fe053468e738dca23ebe043091 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/date_serde-22-45240a488fb708e432d2f45b74ef7e63 b/sql/hive/src/test/resources/golden/date_serde-22-45240a488fb708e432d2f45b74ef7e63 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/date_serde-23-1742a51e4967a8d263572d890cd8d4a8 b/sql/hive/src/test/resources/golden/date_serde-23-1742a51e4967a8d263572d890cd8d4a8 new file mode 100644 index 0000000000000..9f2238d57d6f5 --- /dev/null +++ b/sql/hive/src/test/resources/golden/date_serde-23-1742a51e4967a8d263572d890cd8d4a8 @@ -0,0 +1 @@ +2010-10-20 1064 diff --git a/sql/hive/src/test/resources/golden/date_serde-24-14fd49bd6fee907c1699f7b4e26685b b/sql/hive/src/test/resources/golden/date_serde-24-14fd49bd6fee907c1699f7b4e26685b new file mode 100644 index 0000000000000..9f2238d57d6f5 --- /dev/null +++ 
b/sql/hive/src/test/resources/golden/date_serde-24-14fd49bd6fee907c1699f7b4e26685b @@ -0,0 +1 @@ +2010-10-20 1064 diff --git a/sql/hive/src/test/resources/golden/date_serde-25-a199cf185184a25190d65c123d0694ee b/sql/hive/src/test/resources/golden/date_serde-25-a199cf185184a25190d65c123d0694ee new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/date_serde-26-c5fa68d9aff36f22e5edc1b54332d0ab b/sql/hive/src/test/resources/golden/date_serde-26-c5fa68d9aff36f22e5edc1b54332d0ab new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/date_serde-27-4d86c79f858866acec3c37f6598c2638 b/sql/hive/src/test/resources/golden/date_serde-27-4d86c79f858866acec3c37f6598c2638 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/date_serde-28-16a41fc9e0f51eb417c763bae8e9cadb b/sql/hive/src/test/resources/golden/date_serde-28-16a41fc9e0f51eb417c763bae8e9cadb new file mode 100644 index 0000000000000..9f2238d57d6f5 --- /dev/null +++ b/sql/hive/src/test/resources/golden/date_serde-28-16a41fc9e0f51eb417c763bae8e9cadb @@ -0,0 +1 @@ +2010-10-20 1064 diff --git a/sql/hive/src/test/resources/golden/date_serde-29-bd1cb09aacd906527b0bbf43bbded812 b/sql/hive/src/test/resources/golden/date_serde-29-bd1cb09aacd906527b0bbf43bbded812 new file mode 100644 index 0000000000000..9f2238d57d6f5 --- /dev/null +++ b/sql/hive/src/test/resources/golden/date_serde-29-bd1cb09aacd906527b0bbf43bbded812 @@ -0,0 +1 @@ +2010-10-20 1064 diff --git a/sql/hive/src/test/resources/golden/date_serde-30-7c80741f9f485729afc68609c55423a0 b/sql/hive/src/test/resources/golden/date_serde-30-7c80741f9f485729afc68609c55423a0 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/date_serde-31-da36cd1654aee055cb3650133c9d11f b/sql/hive/src/test/resources/golden/date_serde-31-da36cd1654aee055cb3650133c9d11f new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/date_serde-32-bb2f76bd307ed616a3c797f8dd45a8d1 b/sql/hive/src/test/resources/golden/date_serde-32-bb2f76bd307ed616a3c797f8dd45a8d1 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/date_serde-33-a742813b024e6dcfb4a358aa4e9fcdb6 b/sql/hive/src/test/resources/golden/date_serde-33-a742813b024e6dcfb4a358aa4e9fcdb6 new file mode 100644 index 0000000000000..9f2238d57d6f5 --- /dev/null +++ b/sql/hive/src/test/resources/golden/date_serde-33-a742813b024e6dcfb4a358aa4e9fcdb6 @@ -0,0 +1 @@ +2010-10-20 1064 diff --git a/sql/hive/src/test/resources/golden/date_serde-34-6485841336c097895ad5b34f42c0745f b/sql/hive/src/test/resources/golden/date_serde-34-6485841336c097895ad5b34f42c0745f new file mode 100644 index 0000000000000..9f2238d57d6f5 --- /dev/null +++ b/sql/hive/src/test/resources/golden/date_serde-34-6485841336c097895ad5b34f42c0745f @@ -0,0 +1 @@ +2010-10-20 1064 diff --git a/sql/hive/src/test/resources/golden/date_serde-35-8651a7c351cbc07fb1af6193f6885de8 b/sql/hive/src/test/resources/golden/date_serde-35-8651a7c351cbc07fb1af6193f6885de8 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/date_serde-36-36e6041f53433482631018410bb62a99 b/sql/hive/src/test/resources/golden/date_serde-36-36e6041f53433482631018410bb62a99 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git 
a/sql/hive/src/test/resources/golden/date_serde-37-3ddfd8ecb28991aeed588f1ea852c427 b/sql/hive/src/test/resources/golden/date_serde-37-3ddfd8ecb28991aeed588f1ea852c427 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/date_serde-38-e6167e27465514356c557a77d956ea46 b/sql/hive/src/test/resources/golden/date_serde-38-e6167e27465514356c557a77d956ea46 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/date_serde-39-c1e17c93582656c12970c37bac153bf2 b/sql/hive/src/test/resources/golden/date_serde-39-c1e17c93582656c12970c37bac153bf2 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/date_serde-40-4a17944b9ec8999bb20c5ba5d4cb877c b/sql/hive/src/test/resources/golden/date_serde-40-4a17944b9ec8999bb20c5ba5d4cb877c new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/date_serde-8-cace4f60a08342f58fbe816a9c3a73cf b/sql/hive/src/test/resources/golden/date_serde-8-cace4f60a08342f58fbe816a9c3a73cf new file mode 100644 index 0000000000000..16c03e7276fec --- /dev/null +++ b/sql/hive/src/test/resources/golden/date_serde-8-cace4f60a08342f58fbe816a9c3a73cf @@ -0,0 +1,137 @@ +Baltimore New York 2010-10-20 -30.0 1064 +Baltimore New York 2010-10-20 23.0 1142 +Baltimore New York 2010-10-20 6.0 1599 +Chicago New York 2010-10-20 42.0 361 +Chicago New York 2010-10-20 24.0 897 +Chicago New York 2010-10-20 15.0 1531 +Chicago New York 2010-10-20 -6.0 1610 +Chicago New York 2010-10-20 -2.0 3198 +Baltimore New York 2010-10-21 17.0 1064 +Baltimore New York 2010-10-21 105.0 1142 +Baltimore New York 2010-10-21 28.0 1599 +Chicago New York 2010-10-21 142.0 361 +Chicago New York 2010-10-21 77.0 897 +Chicago New York 2010-10-21 53.0 1531 +Chicago New York 2010-10-21 -5.0 1610 +Chicago New York 2010-10-21 51.0 3198 +Baltimore New York 2010-10-22 -12.0 1064 +Baltimore New York 2010-10-22 54.0 1142 +Baltimore New York 2010-10-22 18.0 1599 +Chicago New York 2010-10-22 2.0 361 +Chicago New York 2010-10-22 24.0 897 +Chicago New York 2010-10-22 16.0 1531 +Chicago New York 2010-10-22 -6.0 1610 +Chicago New York 2010-10-22 -11.0 3198 +Baltimore New York 2010-10-23 18.0 272 +Baltimore New York 2010-10-23 -10.0 1805 +Baltimore New York 2010-10-23 6.0 3171 +Chicago New York 2010-10-23 3.0 384 +Chicago New York 2010-10-23 32.0 426 +Chicago New York 2010-10-23 1.0 650 +Chicago New York 2010-10-23 11.0 3085 +Baltimore New York 2010-10-24 12.0 1599 +Baltimore New York 2010-10-24 20.0 2571 +Chicago New York 2010-10-24 10.0 361 +Chicago New York 2010-10-24 113.0 897 +Chicago New York 2010-10-24 -5.0 1531 +Chicago New York 2010-10-24 -17.0 1610 +Chicago New York 2010-10-24 -3.0 3198 +Baltimore New York 2010-10-25 -25.0 1064 +Baltimore New York 2010-10-25 92.0 1142 +Baltimore New York 2010-10-25 106.0 1599 +Chicago New York 2010-10-25 31.0 361 +Chicago New York 2010-10-25 -1.0 897 +Chicago New York 2010-10-25 43.0 1531 +Chicago New York 2010-10-25 6.0 1610 +Chicago New York 2010-10-25 -16.0 3198 +Baltimore New York 2010-10-26 -22.0 1064 +Baltimore New York 2010-10-26 123.0 1142 +Baltimore New York 2010-10-26 90.0 1599 +Chicago New York 2010-10-26 12.0 361 +Chicago New York 2010-10-26 0.0 897 +Chicago New York 2010-10-26 29.0 1531 +Chicago New York 2010-10-26 -17.0 1610 +Chicago New York 2010-10-26 6.0 3198 +Baltimore New York 2010-10-27 -18.0 1064 +Baltimore New York 2010-10-27 49.0 1142 +Baltimore New York 2010-10-27 92.0 1599 +Chicago New 
York 2010-10-27 148.0 361 +Chicago New York 2010-10-27 -11.0 897 +Chicago New York 2010-10-27 70.0 1531 +Chicago New York 2010-10-27 8.0 1610 +Chicago New York 2010-10-27 21.0 3198 +Baltimore New York 2010-10-28 -4.0 1064 +Baltimore New York 2010-10-28 -14.0 1142 +Baltimore New York 2010-10-28 -14.0 1599 +Chicago New York 2010-10-28 2.0 361 +Chicago New York 2010-10-28 2.0 897 +Chicago New York 2010-10-28 -11.0 1531 +Chicago New York 2010-10-28 3.0 1610 +Chicago New York 2010-10-28 -18.0 3198 +Baltimore New York 2010-10-29 -24.0 1064 +Baltimore New York 2010-10-29 21.0 1142 +Baltimore New York 2010-10-29 -2.0 1599 +Chicago New York 2010-10-29 -12.0 361 +Chicago New York 2010-10-29 -11.0 897 +Chicago New York 2010-10-29 15.0 1531 +Chicago New York 2010-10-29 -18.0 1610 +Chicago New York 2010-10-29 -4.0 3198 +Baltimore New York 2010-10-30 14.0 272 +Baltimore New York 2010-10-30 -1.0 1805 +Baltimore New York 2010-10-30 5.0 3171 +Chicago New York 2010-10-30 -6.0 384 +Chicago New York 2010-10-30 -10.0 426 +Chicago New York 2010-10-30 -5.0 650 +Chicago New York 2010-10-30 -5.0 3085 +Baltimore New York 2010-10-31 -1.0 1599 +Baltimore New York 2010-10-31 -14.0 2571 +Chicago New York 2010-10-31 -25.0 361 +Chicago New York 2010-10-31 -18.0 897 +Chicago New York 2010-10-31 -4.0 1531 +Chicago New York 2010-10-31 -22.0 1610 +Chicago New York 2010-10-31 -15.0 3198 +Cleveland New York 2010-10-30 -23.0 2018 +Cleveland New York 2010-10-30 -12.0 2932 +Cleveland New York 2010-10-29 -4.0 2630 +Cleveland New York 2010-10-29 -19.0 2646 +Cleveland New York 2010-10-29 -12.0 3014 +Cleveland New York 2010-10-28 3.0 2630 +Cleveland New York 2010-10-28 -6.0 2646 +Cleveland New York 2010-10-28 1.0 3014 +Cleveland New York 2010-10-27 16.0 2630 +Cleveland New York 2010-10-27 27.0 3014 +Cleveland New York 2010-10-26 4.0 2630 +Cleveland New York 2010-10-26 -27.0 2646 +Cleveland New York 2010-10-26 -11.0 2662 +Cleveland New York 2010-10-26 13.0 3014 +Cleveland New York 2010-10-25 -4.0 2630 +Cleveland New York 2010-10-25 81.0 2646 +Cleveland New York 2010-10-25 42.0 3014 +Cleveland New York 2010-10-24 5.0 2254 +Cleveland New York 2010-10-24 -11.0 2630 +Cleveland New York 2010-10-24 -20.0 2646 +Cleveland New York 2010-10-24 -9.0 3014 +Cleveland New York 2010-10-23 -21.0 2932 +Cleveland New York 2010-10-22 1.0 2630 +Cleveland New York 2010-10-22 -25.0 2646 +Cleveland New York 2010-10-22 -3.0 3014 +Cleveland New York 2010-10-21 3.0 2630 +Cleveland New York 2010-10-21 29.0 2646 +Cleveland New York 2010-10-21 72.0 3014 +Cleveland New York 2010-10-20 -8.0 2630 +Cleveland New York 2010-10-20 -15.0 3014 +Washington New York 2010-10-23 -25.0 5832 +Washington New York 2010-10-23 -21.0 5904 +Washington New York 2010-10-23 -18.0 5917 +Washington New York 2010-10-30 -27.0 5904 +Washington New York 2010-10-30 -16.0 5917 +Washington New York 2010-10-20 -2.0 7291 +Washington New York 2010-10-21 22.0 7291 +Washington New York 2010-10-23 -16.0 7274 +Washington New York 2010-10-24 -26.0 7282 +Washington New York 2010-10-25 9.0 7291 +Washington New York 2010-10-26 4.0 7291 +Washington New York 2010-10-27 26.0 7291 +Washington New York 2010-10-28 45.0 7291 +Washington New York 2010-10-29 1.0 7291 +Washington New York 2010-10-31 -18.0 7282 diff --git a/sql/hive/src/test/resources/golden/date_serde-9-436c3c61cc4278b54ac79c53c88ff422 b/sql/hive/src/test/resources/golden/date_serde-9-436c3c61cc4278b54ac79c53c88ff422 new file mode 100644 index 0000000000000..0f2a6f7a99237 --- /dev/null +++ 
b/sql/hive/src/test/resources/golden/date_serde-9-436c3c61cc4278b54ac79c53c88ff422 @@ -0,0 +1,12 @@ +2010-10-20 11 +2010-10-21 12 +2010-10-22 11 +2010-10-23 12 +2010-10-24 12 +2010-10-25 12 +2010-10-26 13 +2010-10-27 11 +2010-10-28 12 +2010-10-29 12 +2010-10-30 11 +2010-10-31 8 diff --git a/sql/hive/src/test/resources/golden/date_udf-0-84604a42a5d7f2842f1eec10c689d447 b/sql/hive/src/test/resources/golden/date_udf-0-84604a42a5d7f2842f1eec10c689d447 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/date_udf-1-5e8136f6a6503ae9bef9beca80fada13 b/sql/hive/src/test/resources/golden/date_udf-1-5e8136f6a6503ae9bef9beca80fada13 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/date_udf-10-988ad9744096a29a3672a2d4c121299b b/sql/hive/src/test/resources/golden/date_udf-10-988ad9744096a29a3672a2d4c121299b new file mode 100644 index 0000000000000..83c33400edb47 --- /dev/null +++ b/sql/hive/src/test/resources/golden/date_udf-10-988ad9744096a29a3672a2d4c121299b @@ -0,0 +1 @@ +0 3333 -3333 -3332 3332 diff --git a/sql/hive/src/test/resources/golden/date_udf-11-a5100dd42201b5bc035a9d684cc21bdc b/sql/hive/src/test/resources/golden/date_udf-11-a5100dd42201b5bc035a9d684cc21bdc new file mode 100644 index 0000000000000..4a2462bb3929b --- /dev/null +++ b/sql/hive/src/test/resources/golden/date_udf-11-a5100dd42201b5bc035a9d684cc21bdc @@ -0,0 +1 @@ +NULL 2011 5 6 6 18 2011-05-06 diff --git a/sql/hive/src/test/resources/golden/date_udf-12-eb7280a1f191344a99eaa0f805e8faff b/sql/hive/src/test/resources/golden/date_udf-12-eb7280a1f191344a99eaa0f805e8faff new file mode 100644 index 0000000000000..19497254f8f7e --- /dev/null +++ b/sql/hive/src/test/resources/golden/date_udf-12-eb7280a1f191344a99eaa0f805e8faff @@ -0,0 +1 @@ +2011-05-11 2011-04-26 diff --git a/sql/hive/src/test/resources/golden/date_udf-13-cc99e4f14fd092994b006ee7ebe4fc92 b/sql/hive/src/test/resources/golden/date_udf-13-cc99e4f14fd092994b006ee7ebe4fc92 new file mode 100644 index 0000000000000..977f0d24c58cc --- /dev/null +++ b/sql/hive/src/test/resources/golden/date_udf-13-cc99e4f14fd092994b006ee7ebe4fc92 @@ -0,0 +1 @@ +0 3333 -3333 -3333 3333 diff --git a/sql/hive/src/test/resources/golden/date_udf-14-a6a5ce5134cc1125355a4bdf0a73d97 b/sql/hive/src/test/resources/golden/date_udf-14-a6a5ce5134cc1125355a4bdf0a73d97 new file mode 100644 index 0000000000000..44d1f45e4eb73 --- /dev/null +++ b/sql/hive/src/test/resources/golden/date_udf-14-a6a5ce5134cc1125355a4bdf0a73d97 @@ -0,0 +1 @@ +1970-01-01 08:00:00 1969-12-31 16:00:00 2013-06-19 07:00:00 2013-06-18 17:00:00 diff --git a/sql/hive/src/test/resources/golden/date_udf-15-d031ee50c119d7c6acafd53543dbd0c4 b/sql/hive/src/test/resources/golden/date_udf-15-d031ee50c119d7c6acafd53543dbd0c4 new file mode 100644 index 0000000000000..645b71d8d61e7 --- /dev/null +++ b/sql/hive/src/test/resources/golden/date_udf-15-d031ee50c119d7c6acafd53543dbd0c4 @@ -0,0 +1 @@ +true true true true diff --git a/sql/hive/src/test/resources/golden/date_udf-16-dc59f69e1685e8d923b187ec50d80f06 b/sql/hive/src/test/resources/golden/date_udf-16-dc59f69e1685e8d923b187ec50d80f06 new file mode 100644 index 0000000000000..51863e9a14e4b --- /dev/null +++ b/sql/hive/src/test/resources/golden/date_udf-16-dc59f69e1685e8d923b187ec50d80f06 @@ -0,0 +1 @@ +2010-10-20 diff --git a/sql/hive/src/test/resources/golden/date_udf-17-7d046d4efc568049cf3792470b6feab9 b/sql/hive/src/test/resources/golden/date_udf-17-7d046d4efc568049cf3792470b6feab9 new file 
mode 100644 index 0000000000000..4043ee1cbdd40 --- /dev/null +++ b/sql/hive/src/test/resources/golden/date_udf-17-7d046d4efc568049cf3792470b6feab9 @@ -0,0 +1 @@ +2010-10-31 diff --git a/sql/hive/src/test/resources/golden/date_udf-18-84604a42a5d7f2842f1eec10c689d447 b/sql/hive/src/test/resources/golden/date_udf-18-84604a42a5d7f2842f1eec10c689d447 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/date_udf-19-5e8136f6a6503ae9bef9beca80fada13 b/sql/hive/src/test/resources/golden/date_udf-19-5e8136f6a6503ae9bef9beca80fada13 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/date_udf-2-10e337c34d1e82a360b8599988f4b266 b/sql/hive/src/test/resources/golden/date_udf-2-10e337c34d1e82a360b8599988f4b266 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/date_udf-20-10e337c34d1e82a360b8599988f4b266 b/sql/hive/src/test/resources/golden/date_udf-20-10e337c34d1e82a360b8599988f4b266 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/date_udf-3-29e406e613c0284b3e16a8943a4d31bd b/sql/hive/src/test/resources/golden/date_udf-3-29e406e613c0284b3e16a8943a4d31bd new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/date_udf-4-23653315213f578856ab5c3bd80c0264 b/sql/hive/src/test/resources/golden/date_udf-4-23653315213f578856ab5c3bd80c0264 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/date_udf-5-891fd92a4787b9789f6d1f51c1eddc8a b/sql/hive/src/test/resources/golden/date_udf-5-891fd92a4787b9789f6d1f51c1eddc8a new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/date_udf-6-3473c118d20783eafb456043a2ee5d5b b/sql/hive/src/test/resources/golden/date_udf-6-3473c118d20783eafb456043a2ee5d5b new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/date_udf-7-9fb5165824e161074565e7500959c1b2 b/sql/hive/src/test/resources/golden/date_udf-7-9fb5165824e161074565e7500959c1b2 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/date_udf-8-badfe833681362092fc6345f888b1c21 b/sql/hive/src/test/resources/golden/date_udf-8-badfe833681362092fc6345f888b1c21 new file mode 100644 index 0000000000000..18d17ea11b53e --- /dev/null +++ b/sql/hive/src/test/resources/golden/date_udf-8-badfe833681362092fc6345f888b1c21 @@ -0,0 +1 @@ +1304665200 2011 5 6 6 18 2011-05-06 diff --git a/sql/hive/src/test/resources/golden/date_udf-9-a8cbb039661d796beaa0d1564c58c563 b/sql/hive/src/test/resources/golden/date_udf-9-a8cbb039661d796beaa0d1564c58c563 new file mode 100644 index 0000000000000..19497254f8f7e --- /dev/null +++ b/sql/hive/src/test/resources/golden/date_udf-9-a8cbb039661d796beaa0d1564c58c563 @@ -0,0 +1 @@ +2011-05-11 2011-04-26 diff --git a/sql/hive/src/test/resources/golden/partition_date-0-7ec1f3a845e2c49191460e15af30aa30 b/sql/hive/src/test/resources/golden/partition_date-0-7ec1f3a845e2c49191460e15af30aa30 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/partition_date-1-916193405ce5e020dcd32c58325db6fe b/sql/hive/src/test/resources/golden/partition_date-1-916193405ce5e020dcd32c58325db6fe new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git 
a/sql/hive/src/test/resources/golden/partition_date-10-a8dde9c0b5746dd770c9c262d23ffb10 b/sql/hive/src/test/resources/golden/partition_date-10-a8dde9c0b5746dd770c9c262d23ffb10 new file mode 100644 index 0000000000000..7ed6ff82de6bc --- /dev/null +++ b/sql/hive/src/test/resources/golden/partition_date-10-a8dde9c0b5746dd770c9c262d23ffb10 @@ -0,0 +1 @@ +5 diff --git a/sql/hive/src/test/resources/golden/partition_date-11-fdface2fb6eef67f15bb7d0de2294957 b/sql/hive/src/test/resources/golden/partition_date-11-fdface2fb6eef67f15bb7d0de2294957 new file mode 100644 index 0000000000000..b4de394767536 --- /dev/null +++ b/sql/hive/src/test/resources/golden/partition_date-11-fdface2fb6eef67f15bb7d0de2294957 @@ -0,0 +1 @@ +11 diff --git a/sql/hive/src/test/resources/golden/partition_date-12-9b945f8ece6e09ad28c866ff3a10cc24 b/sql/hive/src/test/resources/golden/partition_date-12-9b945f8ece6e09ad28c866ff3a10cc24 new file mode 100644 index 0000000000000..64bb6b746dcea --- /dev/null +++ b/sql/hive/src/test/resources/golden/partition_date-12-9b945f8ece6e09ad28c866ff3a10cc24 @@ -0,0 +1 @@ +30 diff --git a/sql/hive/src/test/resources/golden/partition_date-13-b7cb91c7c459798078a79071d329dbf b/sql/hive/src/test/resources/golden/partition_date-13-b7cb91c7c459798078a79071d329dbf new file mode 100644 index 0000000000000..573541ac9702d --- /dev/null +++ b/sql/hive/src/test/resources/golden/partition_date-13-b7cb91c7c459798078a79071d329dbf @@ -0,0 +1 @@ +0 diff --git a/sql/hive/src/test/resources/golden/partition_date-14-e4366325f3a0c4a8e92be59f4de73fce b/sql/hive/src/test/resources/golden/partition_date-14-e4366325f3a0c4a8e92be59f4de73fce new file mode 100644 index 0000000000000..573541ac9702d --- /dev/null +++ b/sql/hive/src/test/resources/golden/partition_date-14-e4366325f3a0c4a8e92be59f4de73fce @@ -0,0 +1 @@ +0 diff --git a/sql/hive/src/test/resources/golden/partition_date-15-a062a6e87867d8c8cfbdad97bedcbe5f b/sql/hive/src/test/resources/golden/partition_date-15-a062a6e87867d8c8cfbdad97bedcbe5f new file mode 100644 index 0000000000000..209e3ef4b6247 --- /dev/null +++ b/sql/hive/src/test/resources/golden/partition_date-15-a062a6e87867d8c8cfbdad97bedcbe5f @@ -0,0 +1 @@ +20 diff --git a/sql/hive/src/test/resources/golden/partition_date-16-22a5627d9ac112665eae01d07a91c89c b/sql/hive/src/test/resources/golden/partition_date-16-22a5627d9ac112665eae01d07a91c89c new file mode 100644 index 0000000000000..f599e28b8ab0d --- /dev/null +++ b/sql/hive/src/test/resources/golden/partition_date-16-22a5627d9ac112665eae01d07a91c89c @@ -0,0 +1 @@ +10 diff --git a/sql/hive/src/test/resources/golden/partition_date-17-b9ce94ef93cb16d629af7d7f8ee637e b/sql/hive/src/test/resources/golden/partition_date-17-b9ce94ef93cb16d629af7d7f8ee637e new file mode 100644 index 0000000000000..209e3ef4b6247 --- /dev/null +++ b/sql/hive/src/test/resources/golden/partition_date-17-b9ce94ef93cb16d629af7d7f8ee637e @@ -0,0 +1 @@ +20 diff --git a/sql/hive/src/test/resources/golden/partition_date-18-72c6e9a4e0b434cef67144825346c687 b/sql/hive/src/test/resources/golden/partition_date-18-72c6e9a4e0b434cef67144825346c687 new file mode 100644 index 0000000000000..f599e28b8ab0d --- /dev/null +++ b/sql/hive/src/test/resources/golden/partition_date-18-72c6e9a4e0b434cef67144825346c687 @@ -0,0 +1 @@ +10 diff --git a/sql/hive/src/test/resources/golden/partition_date-19-44e5165eb210559e420105073bc96125 b/sql/hive/src/test/resources/golden/partition_date-19-44e5165eb210559e420105073bc96125 new file mode 100644 index 0000000000000..209e3ef4b6247 --- /dev/null +++ 
b/sql/hive/src/test/resources/golden/partition_date-19-44e5165eb210559e420105073bc96125 @@ -0,0 +1 @@ +20 diff --git a/sql/hive/src/test/resources/golden/partition_date-2-e2e70ac0f4e0ea987b49b86f73d819c9 b/sql/hive/src/test/resources/golden/partition_date-2-e2e70ac0f4e0ea987b49b86f73d819c9 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/partition_date-20-7ec1f3a845e2c49191460e15af30aa30 b/sql/hive/src/test/resources/golden/partition_date-20-7ec1f3a845e2c49191460e15af30aa30 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/partition_date-3-c938b08f57d588926a5d5fbfa4531012 b/sql/hive/src/test/resources/golden/partition_date-3-c938b08f57d588926a5d5fbfa4531012 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/partition_date-4-a93eff99ce43bb939ec1d6464c0ef0b3 b/sql/hive/src/test/resources/golden/partition_date-4-a93eff99ce43bb939ec1d6464c0ef0b3 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/partition_date-5-a855aba47876561fd4fb095e09580686 b/sql/hive/src/test/resources/golden/partition_date-5-a855aba47876561fd4fb095e09580686 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/partition_date-6-1405c311915f27b0cc616c83d39eaacc b/sql/hive/src/test/resources/golden/partition_date-6-1405c311915f27b0cc616c83d39eaacc new file mode 100644 index 0000000000000..051ca3d3c28e7 --- /dev/null +++ b/sql/hive/src/test/resources/golden/partition_date-6-1405c311915f27b0cc616c83d39eaacc @@ -0,0 +1,2 @@ +2000-01-01 +2013-08-08 diff --git a/sql/hive/src/test/resources/golden/partition_date-7-2ac950d8d5656549dd453e5464cb8530 b/sql/hive/src/test/resources/golden/partition_date-7-2ac950d8d5656549dd453e5464cb8530 new file mode 100644 index 0000000000000..24192eefd2caf --- /dev/null +++ b/sql/hive/src/test/resources/golden/partition_date-7-2ac950d8d5656549dd453e5464cb8530 @@ -0,0 +1,5 @@ +165 val_165 2000-01-01 2 +238 val_238 2000-01-01 2 +27 val_27 2000-01-01 2 +311 val_311 2000-01-01 2 +86 val_86 2000-01-01 2 diff --git a/sql/hive/src/test/resources/golden/partition_date-8-a425c11c12c9ce4c9c43d4fbccee5347 b/sql/hive/src/test/resources/golden/partition_date-8-a425c11c12c9ce4c9c43d4fbccee5347 new file mode 100644 index 0000000000000..60d3b2f4a4cd5 --- /dev/null +++ b/sql/hive/src/test/resources/golden/partition_date-8-a425c11c12c9ce4c9c43d4fbccee5347 @@ -0,0 +1 @@ +15 diff --git a/sql/hive/src/test/resources/golden/partition_date-9-aad6078a09b7bd8f5141437e86bb229f b/sql/hive/src/test/resources/golden/partition_date-9-aad6078a09b7bd8f5141437e86bb229f new file mode 100644 index 0000000000000..60d3b2f4a4cd5 --- /dev/null +++ b/sql/hive/src/test/resources/golden/partition_date-9-aad6078a09b7bd8f5141437e86bb229f @@ -0,0 +1 @@ +15 diff --git a/sql/hive/src/test/resources/golden/partition_type_check-12-7e053ba4f9dea1e74c1d04c557c3adac b/sql/hive/src/test/resources/golden/partition_type_check-12-7e053ba4f9dea1e74c1d04c557c3adac new file mode 100644 index 0000000000000..91ba621412d72 --- /dev/null +++ b/sql/hive/src/test/resources/golden/partition_type_check-12-7e053ba4f9dea1e74c1d04c557c3adac @@ -0,0 +1,6 @@ +1 11 2008-01-01 +2 12 2008-01-01 +3 13 2008-01-01 +7 17 2008-01-01 +8 18 2008-01-01 +8 28 2008-01-01 diff --git a/sql/hive/src/test/resources/golden/partition_type_check-13-45fb706ff448da1fe609c7ff76a80d4d 
b/sql/hive/src/test/resources/golden/partition_type_check-13-45fb706ff448da1fe609c7ff76a80d4d new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/union_date-6-f4d5c71145a9b7464685aa7d09cd4dfd b/sql/hive/src/test/resources/golden/union_date-6-f4d5c71145a9b7464685aa7d09cd4dfd new file mode 100644 index 0000000000000..7941f53d8d4c7 --- /dev/null +++ b/sql/hive/src/test/resources/golden/union_date-6-f4d5c71145a9b7464685aa7d09cd4dfd @@ -0,0 +1,40 @@ +1064 2000-11-20 +1064 2000-11-20 +1142 2000-11-21 +1142 2000-11-21 +1599 2000-11-22 +1599 2000-11-22 +361 2000-11-23 +361 2000-11-23 +897 2000-11-24 +897 2000-11-24 +1531 2000-11-25 +1531 2000-11-25 +1610 2000-11-26 +1610 2000-11-26 +3198 2000-11-27 +3198 2000-11-27 +1064 2000-11-28 +1064 2000-11-28 +1142 2000-11-28 +1142 2000-11-28 +1064 2010-10-20 +1064 2010-10-20 +1142 2010-10-21 +1142 2010-10-21 +1599 2010-10-22 +1599 2010-10-22 +361 2010-10-23 +361 2010-10-23 +897 2010-10-24 +897 2010-10-24 +1531 2010-10-25 +1531 2010-10-25 +1610 2010-10-26 +1610 2010-10-26 +3198 2010-10-27 +3198 2010-10-27 +1064 2010-10-28 +1064 2010-10-28 +1142 2010-10-29 +1142 2010-10-29 diff --git a/sql/hive/src/test/resources/golden/union_date-7-a0bade1c77338d4f72962389a1f5bea2 b/sql/hive/src/test/resources/golden/union_date-7-a0bade1c77338d4f72962389a1f5bea2 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/union_date-8-21306adbd8be8ad75174ad9d3e42b73c b/sql/hive/src/test/resources/golden/union_date-8-21306adbd8be8ad75174ad9d3e42b73c new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index 2829105f43716..3e100775e4981 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -802,6 +802,9 @@ class HiveQuerySuite extends HiveComparisonTest { clear() } + createQueryTest("select from thrift based table", + "SELECT * from src_thrift") + // Put tests that depend on specific Hive settings before these last two test, // since they modify /clear stuff. } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala index a6184de4e83c1..2a7004e56ef53 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala @@ -167,7 +167,7 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T new JavaPairDStream(dstream.flatMap(fn)(cm))(fakeClassTag[K2], fakeClassTag[V2]) } - /** + /** * Return a new DStream in which each RDD is generated by applying mapPartitions() to each RDDs * of this DStream. Applying mapPartitions() to an RDD applies a function to each partition * of the RDD. 
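Note: the JavaDStreamLike hunk above only trims stray whitespace in the mapPartitions() Scaladoc. As a minimal, hedged sketch of the behaviour that comment describes (not part of this patch; the socket source, host, and port are illustrative assumptions):

import org.apache.spark.SparkConf
import org.apache.spark.streaming.{Seconds, StreamingContext}

object MapPartitionsSketch {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setAppName("MapPartitionsSketch").setMaster("local[2]")
    val ssc = new StreamingContext(conf, Seconds(1))

    // A text stream; each batch interval produces one RDD of lines.
    val lines = ssc.socketTextStream("localhost", 9999)

    // mapPartitions() is applied to every RDD of the DStream: the function runs once per
    // partition, so per-partition setup cost is paid once per partition, not once per record.
    val lineLengths = lines.mapPartitions { iter =>
      iter.map(_.length)
    }

    lineLengths.print()
    ssc.start()
    ssc.awaitTermination()
  }
}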
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala new file mode 100644 index 0000000000000..213dff6a76354 --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala @@ -0,0 +1,316 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.api.python + +import java.io.{ObjectInputStream, ObjectOutputStream} +import java.lang.reflect.Proxy +import java.util.{ArrayList => JArrayList, List => JList} +import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ +import scala.language.existentials + +import py4j.GatewayServer + +import org.apache.spark.api.java._ +import org.apache.spark.api.python._ +import org.apache.spark.rdd.RDD +import org.apache.spark.storage.StorageLevel +import org.apache.spark.streaming.{Interval, Duration, Time} +import org.apache.spark.streaming.dstream._ +import org.apache.spark.streaming.api.java._ + + +/** + * Interface for Python callback function which is used to transform RDDs + */ +private[python] trait PythonTransformFunction { + def call(time: Long, rdds: JList[_]): JavaRDD[Array[Byte]] +} + +/** + * Interface for Python Serializer to serialize PythonTransformFunction + */ +private[python] trait PythonTransformFunctionSerializer { + def dumps(id: String): Array[Byte] + def loads(bytes: Array[Byte]): PythonTransformFunction +} + +/** + * Wraps a PythonTransformFunction (which is a Python object accessed through Py4J) + * so that it looks like a Scala function and can be transparently serialized and + * deserialized by Java. 
+ */ +private[python] class TransformFunction(@transient var pfunc: PythonTransformFunction) + extends function.Function2[JList[JavaRDD[_]], Time, JavaRDD[Array[Byte]]] { + + def apply(rdd: Option[RDD[_]], time: Time): Option[RDD[Array[Byte]]] = { + Option(pfunc.call(time.milliseconds, List(rdd.map(JavaRDD.fromRDD(_)).orNull).asJava)) + .map(_.rdd) + } + + def apply(rdd: Option[RDD[_]], rdd2: Option[RDD[_]], time: Time): Option[RDD[Array[Byte]]] = { + val rdds = List(rdd.map(JavaRDD.fromRDD(_)).orNull, rdd2.map(JavaRDD.fromRDD(_)).orNull).asJava + Option(pfunc.call(time.milliseconds, rdds)).map(_.rdd) + } + + // for function.Function2 + def call(rdds: JList[JavaRDD[_]], time: Time): JavaRDD[Array[Byte]] = { + pfunc.call(time.milliseconds, rdds) + } + + private def writeObject(out: ObjectOutputStream): Unit = { + val bytes = PythonTransformFunctionSerializer.serialize(pfunc) + out.writeInt(bytes.length) + out.write(bytes) + } + + private def readObject(in: ObjectInputStream): Unit = { + val length = in.readInt() + val bytes = new Array[Byte](length) + in.readFully(bytes) + pfunc = PythonTransformFunctionSerializer.deserialize(bytes) + } +} + +/** + * Helpers for PythonTransformFunctionSerializer + * + * PythonTransformFunctionSerializer is logically a singleton that happens to be + * implemented as a Python object. + */ +private[python] object PythonTransformFunctionSerializer { + + /** + * A serializer in Python, used to serialize PythonTransformFunction + */ + private var serializer: PythonTransformFunctionSerializer = _ + + /* + * Register a serializer from Python, should be called during initialization + */ + def register(ser: PythonTransformFunctionSerializer): Unit = { + serializer = ser + } + + def serialize(func: PythonTransformFunction): Array[Byte] = { + assert(serializer != null, "Serializer has not been registered!") + // get the id of PythonTransformFunction in py4j + val h = Proxy.getInvocationHandler(func.asInstanceOf[Proxy]) + val f = h.getClass().getDeclaredField("id") + f.setAccessible(true) + val id = f.get(h).asInstanceOf[String] + serializer.dumps(id) + } + + def deserialize(bytes: Array[Byte]): PythonTransformFunction = { + assert(serializer != null, "Serializer has not been registered!") + serializer.loads(bytes) + } +} + +/** + * Helper functions, which are called from Python via Py4J.
+ */ +private[python] object PythonDStream { + + /** + * cannot access PythonTransformFunctionSerializer.register() via Py4j + * Py4JError: PythonTransformFunctionSerializerregister does not exist in the JVM + */ + def registerSerializer(ser: PythonTransformFunctionSerializer): Unit = { + PythonTransformFunctionSerializer.register(ser) + } + + /** + * Update the port of callback client to `port` + */ + def updatePythonGatewayPort(gws: GatewayServer, port: Int): Unit = { + val cl = gws.getCallbackClient + val f = cl.getClass.getDeclaredField("port") + f.setAccessible(true) + f.setInt(cl, port) + } + + /** + * Helper function for DStream.foreachRDD(); + * it cannot be named `foreachRDD` because that would confuse py4j + */ + def callForeachRDD(jdstream: JavaDStream[Array[Byte]], pfunc: PythonTransformFunction) { + val func = new TransformFunction(pfunc) + jdstream.dstream.foreachRDD((rdd, time) => func(Some(rdd), time)) + } + + /** + * Convert a list of RDDs into a queue of RDDs, for ssc.queueStream() + */ + def toRDDQueue(rdds: JArrayList[JavaRDD[Array[Byte]]]): java.util.Queue[JavaRDD[Array[Byte]]] = { + val queue = new java.util.LinkedList[JavaRDD[Array[Byte]]] + rdds.forall(queue.add(_)) + queue + } +} + +/** + * Base class for PythonDStream with some common methods + */ +private[python] abstract class PythonDStream( + parent: DStream[_], + @transient pfunc: PythonTransformFunction) + extends DStream[Array[Byte]] (parent.ssc) { + + val func = new TransformFunction(pfunc) + + override def dependencies = List(parent) + + override def slideDuration: Duration = parent.slideDuration + + val asJavaDStream = JavaDStream.fromDStream(this) +} + +/** + * Transformed DStream in Python. + */ +private[python] class PythonTransformedDStream ( + parent: DStream[_], + @transient pfunc: PythonTransformFunction) + extends PythonDStream(parent, pfunc) { + + override def compute(validTime: Time): Option[RDD[Array[Byte]]] = { + val rdd = parent.getOrCompute(validTime) + if (rdd.isDefined) { + func(rdd, validTime) + } else { + None + } + } +} + +/** + * Transformed from two DStreams in Python.
+ */ +private[python] class PythonTransformed2DStream( + parent: DStream[_], + parent2: DStream[_], + @transient pfunc: PythonTransformFunction) + extends DStream[Array[Byte]] (parent.ssc) { + + val func = new TransformFunction(pfunc) + + override def dependencies = List(parent, parent2) + + override def slideDuration: Duration = parent.slideDuration + + override def compute(validTime: Time): Option[RDD[Array[Byte]]] = { + val empty: RDD[_] = ssc.sparkContext.emptyRDD + val rdd1 = parent.getOrCompute(validTime).getOrElse(empty) + val rdd2 = parent2.getOrCompute(validTime).getOrElse(empty) + func(Some(rdd1), Some(rdd2), validTime) + } + + val asJavaDStream = JavaDStream.fromDStream(this) +} + +/** + * similar to StateDStream + */ +private[python] class PythonStateDStream( + parent: DStream[Array[Byte]], + @transient reduceFunc: PythonTransformFunction) + extends PythonDStream(parent, reduceFunc) { + + super.persist(StorageLevel.MEMORY_ONLY) + override val mustCheckpoint = true + + override def compute(validTime: Time): Option[RDD[Array[Byte]]] = { + val lastState = getOrCompute(validTime - slideDuration) + val rdd = parent.getOrCompute(validTime) + if (rdd.isDefined) { + func(lastState, rdd, validTime) + } else { + lastState + } + } +} + +/** + * similar to ReducedWindowedDStream + */ +private[python] class PythonReducedWindowedDStream( + parent: DStream[Array[Byte]], + @transient preduceFunc: PythonTransformFunction, + @transient pinvReduceFunc: PythonTransformFunction, + _windowDuration: Duration, + _slideDuration: Duration) + extends PythonDStream(parent, preduceFunc) { + + super.persist(StorageLevel.MEMORY_ONLY) + override val mustCheckpoint = true + + val invReduceFunc = new TransformFunction(pinvReduceFunc) + + def windowDuration: Duration = _windowDuration + override def slideDuration: Duration = _slideDuration + override def parentRememberDuration: Duration = rememberDuration + windowDuration + + override def compute(validTime: Time): Option[RDD[Array[Byte]]] = { + val currentTime = validTime + val current = new Interval(currentTime - windowDuration, currentTime) + val previous = current - slideDuration + + // _____________________________ + // | previous window _________|___________________ + // |___________________| current window | --------------> Time + // |_____________________________| + // + // |________ _________| |________ _________| + // | | + // V V + // old RDDs new RDDs + // + val previousRDD = getOrCompute(previous.endTime) + + // for small window, reduce once will be better than twice + if (pinvReduceFunc != null && previousRDD.isDefined + && windowDuration >= slideDuration * 5) { + + // subtract the values from old RDDs + val oldRDDs = parent.slice(previous.beginTime + parent.slideDuration, current.beginTime) + val subtracted = if (oldRDDs.size > 0) { + invReduceFunc(previousRDD, Some(ssc.sc.union(oldRDDs)), validTime) + } else { + previousRDD + } + + // add the RDDs of the reduced values in "new time steps" + val newRDDs = parent.slice(previous.endTime + parent.slideDuration, current.endTime) + if (newRDDs.size > 0) { + func(subtracted, Some(ssc.sc.union(newRDDs)), validTime) + } else { + subtracted + } + } else { + // Get the RDDs of the reduced values in current window + val currentRDDs = parent.slice(current.beginTime + parent.slideDuration, current.endTime) + if (currentRDDs.size > 0) { + func(None, Some(ssc.sc.union(currentRDDs)), validTime) + } else { + None + } + } + } +} diff --git a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/Client.scala 
b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index 5a20532315e59..5c7bca4541222 100644 --- a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -122,7 +122,7 @@ private[spark] class Client( * ApplicationReport#getClientToken is renamed `getClientToAMToken` in the stable API. */ override def getClientToken(report: ApplicationReport): String = - Option(report.getClientToken).getOrElse("") + Option(report.getClientToken).map(_.toString).getOrElse("") } object Client { diff --git a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala index 14a0386b78978..0efac4ea63702 100644 --- a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala +++ b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala @@ -143,7 +143,8 @@ private[spark] trait ClientBase extends Logging { val nns = getNameNodesToAccess(sparkConf) + dst obtainTokensForNamenodes(nns, hadoopConf, credentials) - val replication = sparkConf.getInt("spark.yarn.submit.file.replication", 3).toShort + val replication = sparkConf.getInt("spark.yarn.submit.file.replication", + fs.getDefaultReplication(dst)).toShort val localResources = HashMap[String, LocalResource]() FileSystem.mkdirs(fs, dst, new FsPermission(STAGING_DIR_PERMISSION))
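Note: the ClientBase hunk above stops hard-coding a replication factor of 3 and instead falls back to the destination filesystem's own default. A standalone sketch of that pattern (illustration only; the staging path and bare SparkConf here are assumptions, not the code paths Client actually uses):

import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path
import org.apache.spark.SparkConf

object ReplicationSketch {
  def main(args: Array[String]): Unit = {
    val sparkConf = new SparkConf()
    val hadoopConf = new Configuration()

    // Hypothetical staging directory; in ClientBase this is the application's staging dir.
    val dst = new Path("/user/spark/.sparkStaging/app_0001")
    val fs = dst.getFileSystem(hadoopConf)

    // Honor an explicit spark.yarn.submit.file.replication if set; otherwise use whatever
    // replication the destination filesystem reports as its default.
    val replication = sparkConf.getInt("spark.yarn.submit.file.replication",
      fs.getDefaultReplication(dst)).toShort

    println(s"Using replication factor $replication for $dst")
  }
}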
<tr><th>Property Name</th><th>Default</th><th>Meaning</th></tr>
<tr>
  <td>spark.executor.memory</td>
  <td>512m</td>
  <td>
-    Amount of memory to use per executor process, in the same format as JVM memory strings
-    (e.g. 512m, 2g).
-
</tr>
<tr>
  <td>spark.executor.extraJavaOptions</td>
  <td>(none)</td>
</tr>
<tr>
  <td>spark.ui.port</td>
  <td>4040</td>
  <td>
-    Port for your application's dashboard, which shows memory and workload data
+    Port for your application's dashboard, which shows memory and workload data.
  </td>
</tr>
<tr>
  <td>spark.akka.heartbeat.pauses</td>
-  <td>600</td>
+  <td>6000</td>
  <td>
    This is set to a larger value to disable failure detector that comes inbuilt akka. It can be
    enabled again, if you plan to use this feature (Not recommended). Acceptable heart beat pause
@@ -880,8 +872,8 @@ Apart from these, the following properties are also available, and may be useful
<tr>
  <td>spark.scheduler.revive.interval</td>
  <td>1000</td>
  <td>
-    The interval length for the scheduler to revive the worker resource offers to run tasks.
-    (in milliseconds)
+    The interval length for the scheduler to revive the worker resource offers to run tasks
+    (in milliseconds).
  </td>
</tr>
<tr>
  <td>spark.history.fs.logDirectory</td>
  <td>(none)</td>
  <td>
+    Directory that contains application event logs to be loaded by the history server
+
  </td>
</tr>
<tr>
  <td>spark.history.fs.updateInterval</td>
  <td>10</td>
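As a usage note for the documentation rows above, an application can set the non-history properties programmatically. A sketch with illustrative values (the spark.history.* settings normally belong in the history server's spark-defaults.conf or SPARK_HISTORY_OPTS rather than in application code):

import org.apache.spark.{SparkConf, SparkContext}

object ConfigSketch {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf()
      .setAppName("ConfigSketch")
      .setMaster("local[*]")                           // local master, for illustration only
      .set("spark.executor.memory", "2g")              // JVM memory string, e.g. 512m, 2g
      .set("spark.ui.port", "4040")                    // port for the application's dashboard
      .set("spark.scheduler.revive.interval", "1000")  // in milliseconds

    val sc = new SparkContext(conf)
    println(sc.getConf.get("spark.executor.memory"))
    sc.stop()
  }
}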