diff --git a/bin/load-spark-env.cmd b/bin/load-spark-env.cmd new file mode 100644 index 0000000000000..36d932c453b6f --- /dev/null +++ b/bin/load-spark-env.cmd @@ -0,0 +1,59 @@ +@echo off + +rem +rem Licensed to the Apache Software Foundation (ASF) under one or more +rem contributor license agreements. See the NOTICE file distributed with +rem this work for additional information regarding copyright ownership. +rem The ASF licenses this file to You under the Apache License, Version 2.0 +rem (the "License"); you may not use this file except in compliance with +rem the License. You may obtain a copy of the License at +rem +rem http://www.apache.org/licenses/LICENSE-2.0 +rem +rem Unless required by applicable law or agreed to in writing, software +rem distributed under the License is distributed on an "AS IS" BASIS, +rem WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +rem See the License for the specific language governing permissions and +rem limitations under the License. +rem + +rem This script loads spark-env.cmd if it exists, and ensures it is only loaded once. +rem spark-env.cmd is loaded from SPARK_CONF_DIR if set, or within the current directory's +rem conf/ subdirectory. + +if [%SPARK_ENV_LOADED%] == [] ( + set SPARK_ENV_LOADED=1 + + if not [%SPARK_CONF_DIR%] == [] ( + set user_conf_dir=%SPARK_CONF_DIR% + ) else ( + set user_conf_dir=%~dp0..\..\conf + ) + + call :LoadSparkEnv +) + +rem Setting SPARK_SCALA_VERSION if not already set. + +set ASSEMBLY_DIR2=%SPARK_HOME%/assembly/target/scala-2.11 +set ASSEMBLY_DIR1=%SPARK_HOME%/assembly/target/scala-2.10 + +if [%SPARK_SCALA_VERSION%] == [] ( + + if exist %ASSEMBLY_DIR2% if exist %ASSEMBLY_DIR1% ( + echo "Presence of build for both scala versions(SCALA 2.10 and SCALA 2.11) detected." + echo "Either clean one of them or, set SPARK_SCALA_VERSION=2.11 in spark-env.cmd." + exit 1 + ) + if exist %ASSEMBLY_DIR2% ( + set SPARK_SCALA_VERSION=2.11 + ) else ( + set SPARK_SCALA_VERSION=2.10 + ) +) +exit /b 0 + +:LoadSparkEnv +if exist "%user_conf_dir%\spark-env.cmd" ( + call "%user_conf_dir%\spark-env.cmd" +) diff --git a/bin/pyspark2.cmd b/bin/pyspark2.cmd index 4f5eb5e20614d..09b4149c2a439 100644 --- a/bin/pyspark2.cmd +++ b/bin/pyspark2.cmd @@ -20,8 +20,7 @@ rem rem Figure out where the Spark framework is installed set SPARK_HOME=%~dp0.. -rem Load environment variables from conf\spark-env.cmd, if it exists -if exist "%SPARK_HOME%\conf\spark-env.cmd" call "%SPARK_HOME%\conf\spark-env.cmd" +call %SPARK_HOME%\bin\load-spark-env.cmd rem Figure out which Python to use. if "x%PYSPARK_DRIVER_PYTHON%"=="x" ( diff --git a/bin/run-example2.cmd b/bin/run-example2.cmd index b49d0dcb4ff2d..c3e0221fb62e3 100644 --- a/bin/run-example2.cmd +++ b/bin/run-example2.cmd @@ -25,8 +25,7 @@ set FWDIR=%~dp0..\ rem Export this as SPARK_HOME set SPARK_HOME=%FWDIR% -rem Load environment variables from conf\spark-env.cmd, if it exists -if exist "%FWDIR%conf\spark-env.cmd" call "%FWDIR%conf\spark-env.cmd" +call %SPARK_HOME%\bin\load-spark-env.cmd rem Test that an argument was given if not "x%1"=="x" goto arg_given diff --git a/bin/spark-class b/bin/spark-class index e29b234afaf96..c03946d92e2e4 100755 --- a/bin/spark-class +++ b/bin/spark-class @@ -40,35 +40,46 @@ else fi fi -# Look for the launcher. In non-release mode, add the compiled classes directly to the classpath -# instead of looking for a jar file. -SPARK_LAUNCHER_CP= -if [ -f $SPARK_HOME/RELEASE ]; then - LAUNCHER_DIR="$SPARK_HOME/lib" - num_jars="$(ls -1 "$LAUNCHER_DIR" | grep "^spark-launcher.*\.jar$" | wc -l)" - if [ "$num_jars" -eq "0" -a -z "$SPARK_LAUNCHER_CP" ]; then - echo "Failed to find Spark launcher in $LAUNCHER_DIR." 1>&2 - echo "You need to build Spark before running this program." 1>&2 - exit 1 - fi +# Find assembly jar +SPARK_ASSEMBLY_JAR= +if [ -f "$SPARK_HOME/RELEASE" ]; then + ASSEMBLY_DIR="$SPARK_HOME/lib" +else + ASSEMBLY_DIR="$SPARK_HOME/assembly/target/scala-$SPARK_SCALA_VERSION" +fi - LAUNCHER_JARS="$(ls -1 "$LAUNCHER_DIR" | grep "^spark-launcher.*\.jar$" || true)" - if [ "$num_jars" -gt "1" ]; then - echo "Found multiple Spark launcher jars in $LAUNCHER_DIR:" 1>&2 - echo "$LAUNCHER_JARS" 1>&2 - echo "Please remove all but one jar." 1>&2 - exit 1 - fi +num_jars="$(ls -1 "$ASSEMBLY_DIR" | grep "^spark-assembly.*hadoop.*\.jar$" | wc -l)" +if [ "$num_jars" -eq "0" -a -z "$SPARK_ASSEMBLY_JAR" ]; then + echo "Failed to find Spark assembly in $ASSEMBLY_DIR." 1>&2 + echo "You need to build Spark before running this program." 1>&2 + exit 1 +fi +ASSEMBLY_JARS="$(ls -1 "$ASSEMBLY_DIR" | grep "^spark-assembly.*hadoop.*\.jar$" || true)" +if [ "$num_jars" -gt "1" ]; then + echo "Found multiple Spark assembly jars in $ASSEMBLY_DIR:" 1>&2 + echo "$ASSEMBLY_JARS" 1>&2 + echo "Please remove all but one jar." 1>&2 + exit 1 +fi - SPARK_LAUNCHER_CP="${LAUNCHER_DIR}/${LAUNCHER_JARS}" +SPARK_ASSEMBLY_JAR="${ASSEMBLY_DIR}/${ASSEMBLY_JARS}" + +# Verify that versions of java used to build the jars and run Spark are compatible +if [ -n "$JAVA_HOME" ]; then + JAR_CMD="$JAVA_HOME/bin/jar" else - LAUNCHER_DIR="$SPARK_HOME/launcher/target/scala-$SPARK_SCALA_VERSION" - if [ ! -d "$LAUNCHER_DIR/classes" ]; then - echo "Failed to find Spark launcher classes in $LAUNCHER_DIR." 1>&2 - echo "You need to build Spark before running this program." 1>&2 + JAR_CMD="jar" +fi + +if [ $(command -v "$JAR_CMD") ] ; then + jar_error_check=$("$JAR_CMD" -tf "$SPARK_ASSEMBLY_JAR" nonexistent/class/path 2>&1) + if [[ "$jar_error_check" =~ "invalid CEN header" ]]; then + echo "Loading Spark jar with '$JAR_CMD' failed. " 1>&2 + echo "This is likely because Spark was compiled with Java 7 and run " 1>&2 + echo "with Java 6. (see SPARK-1703). Please use Java 7 to run Spark " 1>&2 + echo "or build Spark with Java 6." 1>&2 exit 1 fi - SPARK_LAUNCHER_CP="$LAUNCHER_DIR/classes" fi # The launcher library will print arguments separated by a NULL character, to allow arguments with @@ -77,7 +88,7 @@ fi CMD=() while IFS= read -d '' -r ARG; do CMD+=("$ARG") -done < <("$RUNNER" -cp "$SPARK_LAUNCHER_CP" org.apache.spark.launcher.Main "$@") +done < <("$RUNNER" -cp "$SPARK_ASSEMBLY_JAR" org.apache.spark.launcher.Main "$@") if [ "${CMD[0]}" = "usage" ]; then "${CMD[@]}" diff --git a/bin/spark-class2.cmd b/bin/spark-class2.cmd index 37d22215a0e7e..4b3401d745f2a 100644 --- a/bin/spark-class2.cmd +++ b/bin/spark-class2.cmd @@ -20,8 +20,7 @@ rem rem Figure out where the Spark framework is installed set SPARK_HOME=%~dp0.. -rem Load environment variables from conf\spark-env.cmd, if it exists -if exist "%SPARK_HOME%\conf\spark-env.cmd" call "%SPARK_HOME%\conf\spark-env.cmd" +call %SPARK_HOME%\bin\load-spark-env.cmd rem Test that an argument was given if "x%1"=="x" ( @@ -29,31 +28,20 @@ if "x%1"=="x" ( exit /b 1 ) -set LAUNCHER_CP=0 -if exist %SPARK_HOME%\RELEASE goto find_release_launcher +rem Find assembly jar +set SPARK_ASSEMBLY_JAR=0 -rem Look for the Spark launcher in both Scala build directories. The launcher doesn't use Scala so -rem it doesn't really matter which one is picked up. Add the compiled classes directly to the -rem classpath instead of looking for a jar file, since it's very common for people using sbt to use -rem the "assembly" target instead of "package". -set LAUNCHER_CLASSES=%SPARK_HOME%\launcher\target\scala-2.10\classes -if exist %LAUNCHER_CLASSES% ( - set LAUNCHER_CP=%LAUNCHER_CLASSES% +if exist "%SPARK_HOME%\RELEASE" ( + set ASSEMBLY_DIR=%SPARK_HOME%\lib +) else ( + set ASSEMBLY_DIR=%SPARK_HOME%\assembly\target\scala-%SPARK_SCALA_VERSION% ) -set LAUNCHER_CLASSES=%SPARK_HOME%\launcher\target\scala-2.11\classes -if exist %LAUNCHER_CLASSES% ( - set LAUNCHER_CP=%LAUNCHER_CLASSES% -) -goto check_launcher -:find_release_launcher -for %%d in (%SPARK_HOME%\lib\spark-launcher*.jar) do ( - set LAUNCHER_CP=%%d +for %%d in (%ASSEMBLY_DIR%\spark-assembly*hadoop*.jar) do ( + set SPARK_ASSEMBLY_JAR=%%d ) - -:check_launcher -if "%LAUNCHER_CP%"=="0" ( - echo Failed to find Spark launcher JAR. +if "%SPARK_ASSEMBLY_JAR%"=="0" ( + echo Failed to find Spark assembly JAR. echo You need to build Spark before running this program. exit /b 1 ) @@ -64,7 +52,7 @@ if not "x%JAVA_HOME%"=="x" set RUNNER=%JAVA_HOME%\bin\java rem The launcher library prints the command to be executed in a single line suitable for being rem executed by the batch interpreter. So read all the output of the launcher into a variable. -for /f "tokens=*" %%i in ('cmd /C ""%RUNNER%" -cp %LAUNCHER_CP% org.apache.spark.launcher.Main %*"') do ( +for /f "tokens=*" %%i in ('cmd /C ""%RUNNER%" -cp %SPARK_ASSEMBLY_JAR% org.apache.spark.launcher.Main %*"') do ( set SPARK_CMD=%%i ) %SPARK_CMD% diff --git a/core/src/main/resources/org/apache/spark/ui/static/additional-metrics.js b/core/src/main/resources/org/apache/spark/ui/static/additional-metrics.js index 14ba37d7c9bd9..013db8df9b363 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/additional-metrics.js +++ b/core/src/main/resources/org/apache/spark/ui/static/additional-metrics.js @@ -30,7 +30,7 @@ $(function() { stripeSummaryTable(); - $("input:checkbox").click(function() { + $('input[type="checkbox"]').click(function() { var column = "table ." + $(this).attr("name"); $(column).toggle(); stripeSummaryTable(); @@ -39,15 +39,15 @@ $(function() { $("#select-all-metrics").click(function() { if (this.checked) { // Toggle all un-checked options. - $('input:checkbox:not(:checked)').trigger('click'); + $('input[type="checkbox"]:not(:checked)').trigger('click'); } else { // Toggle all checked options. - $('input:checkbox:checked').trigger('click'); + $('input[type="checkbox"]:checked').trigger('click'); } }); // Trigger a click on the checkbox if a user clicks the label next to it. $("span.additional-metric-title").click(function() { - $(this).parent().find('input:checkbox').trigger('click'); + $(this).parent().find('input[type="checkbox"]').trigger('click'); }); }); diff --git a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala index 21c6e6ffa6666..9385f557c4614 100644 --- a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala +++ b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala @@ -17,10 +17,12 @@ package org.apache.spark +import java.util.concurrent.{Executors, TimeUnit} + import scala.collection.mutable import org.apache.spark.scheduler._ -import org.apache.spark.util.{SystemClock, Clock} +import org.apache.spark.util.{Clock, SystemClock, Utils} /** * An agent that dynamically allocates and removes executors based on the workload. @@ -129,6 +131,10 @@ private[spark] class ExecutorAllocationManager( // Listener for Spark events that impact the allocation policy private val listener = new ExecutorAllocationListener + // Executor that handles the scheduling task. + private val executor = Executors.newSingleThreadScheduledExecutor( + Utils.namedThreadFactory("spark-dynamic-executor-allocation")) + /** * Verify that the settings specified through the config are valid. * If not, throw an appropriate exception. @@ -173,32 +179,24 @@ private[spark] class ExecutorAllocationManager( } /** - * Register for scheduler callbacks to decide when to add and remove executors. + * Register for scheduler callbacks to decide when to add and remove executors, and start + * the scheduling task. */ def start(): Unit = { listenerBus.addListener(listener) - startPolling() + + val scheduleTask = new Runnable() { + override def run(): Unit = Utils.logUncaughtExceptions(schedule()) + } + executor.scheduleAtFixedRate(scheduleTask, 0, intervalMillis, TimeUnit.MILLISECONDS) } /** - * Start the main polling thread that keeps track of when to add and remove executors. + * Stop the allocation manager. */ - private def startPolling(): Unit = { - val t = new Thread { - override def run(): Unit = { - while (true) { - try { - schedule() - } catch { - case e: Exception => logError("Exception in dynamic executor allocation thread!", e) - } - Thread.sleep(intervalMillis) - } - } - } - t.setName("spark-dynamic-executor-allocation") - t.setDaemon(true) - t.start() + def stop(): Unit = { + executor.shutdown() + executor.awaitTermination(10, TimeUnit.SECONDS) } /** diff --git a/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala b/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala index 715f292f03469..5871b8c869f03 100644 --- a/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala +++ b/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala @@ -17,15 +17,15 @@ package org.apache.spark -import scala.concurrent.duration._ -import scala.collection.mutable +import java.util.concurrent.{ScheduledFuture, TimeUnit, Executors} -import akka.actor.{Actor, Cancellable} +import scala.collection.mutable import org.apache.spark.executor.TaskMetrics +import org.apache.spark.rpc.{ThreadSafeRpcEndpoint, RpcEnv, RpcCallContext} import org.apache.spark.storage.BlockManagerId import org.apache.spark.scheduler.{SlaveLost, TaskScheduler} -import org.apache.spark.util.ActorLogReceive +import org.apache.spark.util.Utils /** * A heartbeat from executors to the driver. This is a shared message used by several internal @@ -37,6 +37,12 @@ private[spark] case class Heartbeat( taskMetrics: Array[(Long, TaskMetrics)], // taskId -> TaskMetrics blockManagerId: BlockManagerId) +/** + * An event that SparkContext uses to notify HeartbeatReceiver that SparkContext.taskScheduler is + * created. + */ +private[spark] case object TaskSchedulerIsSet + private[spark] case object ExpireDeadHosts private[spark] case class HeartbeatResponse(reregisterBlockManager: Boolean) @@ -44,36 +50,65 @@ private[spark] case class HeartbeatResponse(reregisterBlockManager: Boolean) /** * Lives in the driver to receive heartbeats from executors.. */ -private[spark] class HeartbeatReceiver(sc: SparkContext, scheduler: TaskScheduler) - extends Actor with ActorLogReceive with Logging { +private[spark] class HeartbeatReceiver(sc: SparkContext) + extends ThreadSafeRpcEndpoint with Logging { + + override val rpcEnv: RpcEnv = sc.env.rpcEnv + + private[spark] var scheduler: TaskScheduler = null // executor ID -> timestamp of when the last heartbeat from this executor was received private val executorLastSeen = new mutable.HashMap[String, Long] + + // "spark.network.timeout" uses "seconds", while `spark.storage.blockManagerSlaveTimeoutMs` uses + // "milliseconds" + private val executorTimeoutMs = sc.conf.getOption("spark.network.timeout").map(_.toLong * 1000). + getOrElse(sc.conf.getLong("spark.storage.blockManagerSlaveTimeoutMs", 120000)) + + // "spark.network.timeoutInterval" uses "seconds", while + // "spark.storage.blockManagerTimeoutIntervalMs" uses "milliseconds" + private val checkTimeoutIntervalMs = + sc.conf.getOption("spark.network.timeoutInterval").map(_.toLong * 1000). + getOrElse(sc.conf.getLong("spark.storage.blockManagerTimeoutIntervalMs", 60000)) - private val executorTimeoutMs = sc.conf.getLong("spark.network.timeout", - sc.conf.getLong("spark.storage.blockManagerSlaveTimeoutMs", 120)) * 1000 - - private val checkTimeoutIntervalMs = sc.conf.getLong("spark.network.timeoutInterval", - sc.conf.getLong("spark.storage.blockManagerTimeoutIntervalMs", 60)) * 1000 - - private var timeoutCheckingTask: Cancellable = null - - override def preStart(): Unit = { - import context.dispatcher - timeoutCheckingTask = context.system.scheduler.schedule(0.seconds, - checkTimeoutIntervalMs.milliseconds, self, ExpireDeadHosts) - super.preStart() + private var timeoutCheckingTask: ScheduledFuture[_] = null + + private val timeoutCheckingThread = Executors.newSingleThreadScheduledExecutor( + Utils.namedThreadFactory("heartbeat-timeout-checking-thread")) + + private val killExecutorThread = Executors.newSingleThreadExecutor( + Utils.namedThreadFactory("kill-executor-thread")) + + override def onStart(): Unit = { + timeoutCheckingTask = timeoutCheckingThread.scheduleAtFixedRate(new Runnable { + override def run(): Unit = Utils.tryLogNonFatalError { + Option(self).foreach(_.send(ExpireDeadHosts)) + } + }, 0, checkTimeoutIntervalMs, TimeUnit.MILLISECONDS) } - - override def receiveWithLogging: PartialFunction[Any, Unit] = { - case Heartbeat(executorId, taskMetrics, blockManagerId) => - val unknownExecutor = !scheduler.executorHeartbeatReceived( - executorId, taskMetrics, blockManagerId) - val response = HeartbeatResponse(reregisterBlockManager = unknownExecutor) - executorLastSeen(executorId) = System.currentTimeMillis() - sender ! response + + override def receive: PartialFunction[Any, Unit] = { case ExpireDeadHosts => expireDeadHosts() + case TaskSchedulerIsSet => + scheduler = sc.taskScheduler + } + + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { + case heartbeat @ Heartbeat(executorId, taskMetrics, blockManagerId) => + if (scheduler != null) { + val unknownExecutor = !scheduler.executorHeartbeatReceived( + executorId, taskMetrics, blockManagerId) + val response = HeartbeatResponse(reregisterBlockManager = unknownExecutor) + executorLastSeen(executorId) = System.currentTimeMillis() + context.reply(response) + } else { + // Because Executor will sleep several seconds before sending the first "Heartbeat", this + // case rarely happens. However, if it really happens, log it and ask the executor to + // register itself again. + logWarning(s"Dropping $heartbeat because TaskScheduler is not ready yet") + context.reply(HeartbeatResponse(reregisterBlockManager = true)) + } } private def expireDeadHosts(): Unit = { @@ -84,19 +119,27 @@ private[spark] class HeartbeatReceiver(sc: SparkContext, scheduler: TaskSchedule logWarning(s"Removing executor $executorId with no recent heartbeats: " + s"${now - lastSeenMs} ms exceeds timeout $executorTimeoutMs ms") scheduler.executorLost(executorId, SlaveLost("Executor heartbeat " + - "timed out after ${now - lastSeenMs} ms")) + s"timed out after ${now - lastSeenMs} ms")) if (sc.supportDynamicAllocation) { - sc.killExecutor(executorId) + // Asynchronously kill the executor to avoid blocking the current thread + killExecutorThread.submit(new Runnable { + override def run(): Unit = sc.killExecutor(executorId) + }) } executorLastSeen.remove(executorId) } } } - override def postStop(): Unit = { + override def onStop(): Unit = { if (timeoutCheckingTask != null) { - timeoutCheckingTask.cancel() + timeoutCheckingTask.cancel(true) } - super.postStop() + timeoutCheckingThread.shutdownNow() + killExecutorThread.shutdownNow() } } + +object HeartbeatReceiver { + val ENDPOINT_NAME = "HeartbeatReceiver" +} diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala index c9426c5de23a2..d65c94e410662 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala @@ -21,13 +21,11 @@ import java.io._ import java.util.concurrent.ConcurrentHashMap import java.util.zip.{GZIPInputStream, GZIPOutputStream} -import scala.collection.mutable.{HashSet, HashMap, Map} -import scala.concurrent.Await +import scala.collection.mutable.{HashSet, Map} import scala.collection.JavaConversions._ +import scala.reflect.ClassTag -import akka.actor._ -import akka.pattern.ask - +import org.apache.spark.rpc.{RpcEndpointRef, RpcEnv, RpcCallContext, RpcEndpoint} import org.apache.spark.scheduler.MapStatus import org.apache.spark.shuffle.MetadataFetchFailedException import org.apache.spark.storage.BlockManagerId @@ -38,14 +36,15 @@ private[spark] case class GetMapOutputStatuses(shuffleId: Int) extends MapOutputTrackerMessage private[spark] case object StopMapOutputTracker extends MapOutputTrackerMessage -/** Actor class for MapOutputTrackerMaster */ -private[spark] class MapOutputTrackerMasterActor(tracker: MapOutputTrackerMaster, conf: SparkConf) - extends Actor with ActorLogReceive with Logging { +/** RpcEndpoint class for MapOutputTrackerMaster */ +private[spark] class MapOutputTrackerMasterEndpoint( + override val rpcEnv: RpcEnv, tracker: MapOutputTrackerMaster, conf: SparkConf) + extends RpcEndpoint with Logging { val maxAkkaFrameSize = AkkaUtils.maxFrameSizeBytes(conf) - override def receiveWithLogging: PartialFunction[Any, Unit] = { + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { case GetMapOutputStatuses(shuffleId: Int) => - val hostPort = sender.path.address.hostPort + val hostPort = context.sender.address.hostPort logInfo("Asked to send map output locations for shuffle " + shuffleId + " to " + hostPort) val mapOutputStatuses = tracker.getSerializedMapOutputStatuses(shuffleId) val serializedSize = mapOutputStatuses.size @@ -53,19 +52,19 @@ private[spark] class MapOutputTrackerMasterActor(tracker: MapOutputTrackerMaster val msg = s"Map output statuses were $serializedSize bytes which " + s"exceeds spark.akka.frameSize ($maxAkkaFrameSize bytes)." - /* For SPARK-1244 we'll opt for just logging an error and then throwing an exception. - * Note that on exception the actor will just restart. A bigger refactoring (SPARK-1239) - * will ultimately remove this entire code path. */ + /* For SPARK-1244 we'll opt for just logging an error and then sending it to the sender. + * A bigger refactoring (SPARK-1239) will ultimately remove this entire code path. */ val exception = new SparkException(msg) logError(msg, exception) - throw exception + context.sendFailure(exception) + } else { + context.reply(mapOutputStatuses) } - sender ! mapOutputStatuses case StopMapOutputTracker => - logInfo("MapOutputTrackerActor stopped!") - sender ! true - context.stop(self) + logInfo("MapOutputTrackerMasterEndpoint stopped!") + context.reply(true) + stop() } } @@ -75,12 +74,9 @@ private[spark] class MapOutputTrackerMasterActor(tracker: MapOutputTrackerMaster * (driver and executor) use different HashMap to store its metadata. */ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging { - private val timeout = AkkaUtils.askTimeout(conf) - private val retryAttempts = AkkaUtils.numRetries(conf) - private val retryIntervalMs = AkkaUtils.retryWaitMs(conf) - /** Set to the MapOutputTrackerActor living on the driver. */ - var trackerActor: ActorRef = _ + /** Set to the MapOutputTrackerMasterEndpoint living on the driver. */ + var trackerEndpoint: RpcEndpointRef = _ /** * This HashMap has different behavior for the driver and the executors. @@ -105,12 +101,12 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging private val fetching = new HashSet[Int] /** - * Send a message to the trackerActor and get its result within a default timeout, or + * Send a message to the trackerEndpoint and get its result within a default timeout, or * throw a SparkException if this fails. */ - protected def askTracker(message: Any): Any = { + protected def askTracker[T: ClassTag](message: Any): T = { try { - AkkaUtils.askWithReply(message, trackerActor, retryAttempts, retryIntervalMs, timeout) + trackerEndpoint.askWithReply[T](message) } catch { case e: Exception => logError("Error communicating with MapOutputTracker", e) @@ -118,9 +114,9 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging } } - /** Send a one-way message to the trackerActor, to which we expect it to reply with true. */ + /** Send a one-way message to the trackerEndpoint, to which we expect it to reply with true. */ protected def sendTracker(message: Any) { - val response = askTracker(message) + val response = askTracker[Boolean](message) if (response != true) { throw new SparkException( "Error reply received from MapOutputTracker. Expecting true, got " + response.toString) @@ -157,11 +153,10 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging if (fetchedStatuses == null) { // We won the race to fetch the output locs; do so - logInfo("Doing the fetch; tracker actor = " + trackerActor) + logInfo("Doing the fetch; tracker endpoint = " + trackerEndpoint) // This try-finally prevents hangs due to timeouts: try { - val fetchedBytes = - askTracker(GetMapOutputStatuses(shuffleId)).asInstanceOf[Array[Byte]] + val fetchedBytes = askTracker[Array[Byte]](GetMapOutputStatuses(shuffleId)) fetchedStatuses = MapOutputTracker.deserializeMapStatuses(fetchedBytes) logInfo("Got the output locations") mapStatuses.put(shuffleId, fetchedStatuses) @@ -328,7 +323,7 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf) override def stop() { sendTracker(StopMapOutputTracker) mapStatuses.clear() - trackerActor = null + trackerEndpoint = null metadataCleaner.cancel() cachedSerializedStatuses.clear() } @@ -350,17 +345,22 @@ private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTr private[spark] object MapOutputTracker extends Logging { + val ENDPOINT_NAME = "MapOutputTracker" + // Serialize an array of map output locations into an efficient byte format so that we can send // it to reduce tasks. We do this by compressing the serialized bytes using GZIP. They will // generally be pretty compressible because many map outputs will be on the same hostname. def serializeMapStatuses(statuses: Array[MapStatus]): Array[Byte] = { val out = new ByteArrayOutputStream val objOut = new ObjectOutputStream(new GZIPOutputStream(out)) - // Since statuses can be modified in parallel, sync on it - statuses.synchronized { - objOut.writeObject(statuses) + Utils.tryWithSafeFinally { + // Since statuses can be modified in parallel, sync on it + statuses.synchronized { + objOut.writeObject(statuses) + } + } { + objOut.close() } - objOut.close() out.toByteArray } diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index a70be16f77eeb..3f1a7dd99d635 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -23,7 +23,7 @@ import java.io._ import java.lang.reflect.Constructor import java.net.URI import java.util.{Arrays, Properties, UUID} -import java.util.concurrent.atomic.AtomicInteger +import java.util.concurrent.atomic.{AtomicBoolean, AtomicInteger} import java.util.UUID.randomUUID import scala.collection.{Map, Set} @@ -32,8 +32,6 @@ import scala.collection.generic.Growable import scala.collection.mutable.HashMap import scala.reflect.{ClassTag, classTag} -import akka.actor.Props - import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path import org.apache.hadoop.io.{ArrayWritable, BooleanWritable, BytesWritable, DoubleWritable, @@ -48,12 +46,13 @@ import org.apache.mesos.MesosNativeLibrary import org.apache.spark.annotation.{DeveloperApi, Experimental} import org.apache.spark.broadcast.Broadcast import org.apache.spark.deploy.{LocalSparkCluster, SparkHadoopUtil} -import org.apache.spark.executor.TriggerThreadDump +import org.apache.spark.executor.{ExecutorEndpoint, TriggerThreadDump} import org.apache.spark.input.{StreamInputFormat, PortableDataStream, WholeTextFileInputFormat, FixedLengthBinaryInputFormat} import org.apache.spark.io.CompressionCodec import org.apache.spark.partial.{ApproximateEvaluator, PartialResult} import org.apache.spark.rdd._ +import org.apache.spark.rpc.RpcAddress import org.apache.spark.scheduler._ import org.apache.spark.scheduler.cluster.{CoarseGrainedSchedulerBackend, SparkDeploySchedulerBackend, SimrSchedulerBackend} @@ -95,10 +94,10 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli val startTime = System.currentTimeMillis() - @volatile private var stopped: Boolean = false + private val stopped: AtomicBoolean = new AtomicBoolean(false) private def assertNotStopped(): Unit = { - if (stopped) { + if (stopped.get()) { throw new IllegalStateException("Cannot call methods on a stopped SparkContext") } } @@ -227,9 +226,11 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli val appName = conf.get("spark.app.name") private[spark] val isEventLogEnabled = conf.getBoolean("spark.eventLog.enabled", false) - private[spark] val eventLogDir: Option[String] = { + private[spark] val eventLogDir: Option[URI] = { if (isEventLogEnabled) { - Some(conf.get("spark.eventLog.dir", EventLoggingListener.DEFAULT_LOG_DIR).stripSuffix("/")) + val unresolvedDir = conf.get("spark.eventLog.dir", EventLoggingListener.DEFAULT_LOG_DIR) + .stripSuffix("/") + Some(Utils.resolveURI(unresolvedDir)) } else { None } @@ -356,11 +357,17 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli val sparkUser = Utils.getCurrentUserName() executorEnvs("SPARK_USER") = sparkUser + // We need to register "HeartbeatReceiver" before "createTaskScheduler" because Executor will + // retrieve "HeartbeatReceiver" in the constructor. (SPARK-6640) + private val heartbeatReceiver = env.rpcEnv.setupEndpoint( + HeartbeatReceiver.ENDPOINT_NAME, new HeartbeatReceiver(this)) + // Create and start the scheduler private[spark] var (schedulerBackend, taskScheduler) = SparkContext.createTaskScheduler(this, master) - private val heartbeatReceiver = env.actorSystem.actorOf( - Props(new HeartbeatReceiver(this, taskScheduler)), "HeartbeatReceiver") + + heartbeatReceiver.send(TaskSchedulerIsSet) + @volatile private[spark] var dagScheduler: DAGScheduler = _ try { dagScheduler = new DAGScheduler(this) @@ -433,6 +440,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli // Thread Local variable that can be used by users to pass information down the stack private val localProperties = new InheritableThreadLocal[Properties] { override protected def childValue(parent: Properties): Properties = new Properties(parent) + override protected def initialValue(): Properties = new Properties() } /** @@ -446,10 +454,12 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli if (executorId == SparkContext.DRIVER_IDENTIFIER) { Some(Utils.getThreadDump()) } else { - val (host, port) = env.blockManager.master.getActorSystemHostPortForExecutor(executorId).get - val actorRef = AkkaUtils.makeExecutorRef("ExecutorActor", conf, host, port, env.actorSystem) - Some(AkkaUtils.askWithReply[Array[ThreadStackTrace]](TriggerThreadDump, actorRef, - AkkaUtils.numRetries(conf), AkkaUtils.retryWaitMs(conf), AkkaUtils.askTimeout(conf))) + val (host, port) = env.blockManager.master.getRpcHostPortForExecutor(executorId).get + val endpointRef = env.rpcEnv.setupEndpointRef( + SparkEnv.executorActorSystemName, + RpcAddress(host, port), + ExecutorEndpoint.EXECUTOR_ENDPOINT_NAME) + Some(endpointRef.askWithReply[Array[ThreadStackTrace]](TriggerThreadDump)) } } catch { case e: Exception => @@ -474,9 +484,6 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * Spark fair scheduler pool. */ def setLocalProperty(key: String, value: String) { - if (localProperties.get() == null) { - localProperties.set(new Properties()) - } if (value == null) { localProperties.get.remove(key) } else { @@ -1138,7 +1145,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * Return whether dynamically adjusting the amount of resources allocated to * this application is supported. This is currently only available for YARN. */ - private[spark] def supportDynamicAllocation = + private[spark] def supportDynamicAllocation = master.contains("yarn") || dynamicAllocationTesting /** @@ -1392,32 +1399,34 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli addedJars.clear() } - /** Shut down the SparkContext. */ + // Shut down the SparkContext. def stop() { - SparkContext.SPARK_CONTEXT_CONSTRUCTOR_LOCK.synchronized { - if (!stopped) { - stopped = true - postApplicationEnd() - ui.foreach(_.stop()) - env.metricsSystem.report() - metadataCleaner.cancel() - cleaner.foreach(_.stop()) - dagScheduler.stop() - dagScheduler = null - listenerBus.stop() - eventLogger.foreach(_.stop()) - env.actorSystem.stop(heartbeatReceiver) - progressBar.foreach(_.stop()) - taskScheduler = null - // TODO: Cache.stop()? - env.stop() - SparkEnv.set(null) - logInfo("Successfully stopped SparkContext") - SparkContext.clearActiveContext() - } else { - logInfo("SparkContext already stopped") - } + // Use the stopping variable to ensure no contention for the stop scenario. + // Still track the stopped variable for use elsewhere in the code. + + if (!stopped.compareAndSet(false, true)) { + logInfo("SparkContext already stopped.") + return } + + postApplicationEnd() + ui.foreach(_.stop()) + env.metricsSystem.report() + metadataCleaner.cancel() + cleaner.foreach(_.stop()) + executorAllocationManager.foreach(_.stop()) + dagScheduler.stop() + dagScheduler = null + listenerBus.stop() + eventLogger.foreach(_.stop()) + env.rpcEnv.stop(heartbeatReceiver) + progressBar.foreach(_.stop()) + taskScheduler = null + // TODO: Cache.stop()? + env.stop() + SparkEnv.set(null) + SparkContext.clearActiveContext() + logInfo("Successfully stopped SparkContext") } @@ -1479,7 +1488,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli partitions: Seq[Int], allowLocal: Boolean, resultHandler: (Int, U) => Unit) { - if (stopped) { + if (stopped.get()) { throw new IllegalStateException("SparkContext has been shutdown") } val callSite = getCallSite @@ -1892,7 +1901,17 @@ object SparkContext extends Logging { private[spark] val SPARK_JOB_INTERRUPT_ON_CANCEL = "spark.job.interruptOnCancel" - private[spark] val DRIVER_IDENTIFIER = "" + /** + * Executor id for the driver. In earlier versions of Spark, this was ``, but this was + * changed to `driver` because the angle brackets caused escaping issues in URLs and XML (see + * SPARK-6716 for more details). + */ + private[spark] val DRIVER_IDENTIFIER = "driver" + + /** + * Legacy version of DRIVER_IDENTIFIER, retained for backwards-compatibility. + */ + private[spark] val LEGACY_DRIVER_IDENTIFIER = "" // The following deprecated objects have already been copied to `object AccumulatorParam` to // make the compiler find them automatically. They are duplicate codes only for backward diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index 2a0c7e756dd3a..0171488e09562 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -24,7 +24,6 @@ import scala.collection.JavaConversions._ import scala.collection.mutable import scala.util.Properties -import akka.actor._ import com.google.common.collect.MapMaker import org.apache.spark.annotation.DeveloperApi @@ -34,12 +33,14 @@ import org.apache.spark.metrics.MetricsSystem import org.apache.spark.network.BlockTransferService import org.apache.spark.network.netty.NettyBlockTransferService import org.apache.spark.network.nio.NioBlockTransferService +import org.apache.spark.rpc.{RpcEndpointRef, RpcEndpoint, RpcEnv} +import org.apache.spark.rpc.akka.AkkaRpcEnv import org.apache.spark.scheduler.{OutputCommitCoordinator, LiveListenerBus} -import org.apache.spark.scheduler.OutputCommitCoordinator.OutputCommitCoordinatorActor +import org.apache.spark.scheduler.OutputCommitCoordinator.OutputCommitCoordinatorEndpoint import org.apache.spark.serializer.Serializer import org.apache.spark.shuffle.{ShuffleMemoryManager, ShuffleManager} import org.apache.spark.storage._ -import org.apache.spark.util.{AkkaUtils, Utils} +import org.apache.spark.util.{RpcUtils, Utils} /** * :: DeveloperApi :: @@ -54,7 +55,7 @@ import org.apache.spark.util.{AkkaUtils, Utils} @DeveloperApi class SparkEnv ( val executorId: String, - val actorSystem: ActorSystem, + private[spark] val rpcEnv: RpcEnv, val serializer: Serializer, val closureSerializer: Serializer, val cacheManager: CacheManager, @@ -71,6 +72,9 @@ class SparkEnv ( val outputCommitCoordinator: OutputCommitCoordinator, val conf: SparkConf) extends Logging { + // TODO Remove actorSystem + val actorSystem = rpcEnv.asInstanceOf[AkkaRpcEnv].actorSystem + private[spark] var isStopped = false private val pythonWorkers = mutable.HashMap[(String, Map[String, String]), PythonWorkerFactory]() @@ -91,7 +95,8 @@ class SparkEnv ( blockManager.master.stop() metricsSystem.stop() outputCommitCoordinator.stop() - actorSystem.shutdown() + rpcEnv.shutdown() + // Unfortunately Akka's awaitTermination doesn't actually wait for the Netty server to shut // down, but let's call it anyway in case it gets fixed in a later release // UPDATE: In Akka 2.1.x, this hangs if there are remote actors, so we can't call it. @@ -236,16 +241,15 @@ object SparkEnv extends Logging { val securityManager = new SecurityManager(conf) // Create the ActorSystem for Akka and get the port it binds to. - val (actorSystem, boundPort) = { - val actorSystemName = if (isDriver) driverActorSystemName else executorActorSystemName - AkkaUtils.createActorSystem(actorSystemName, hostname, port, conf, securityManager) - } + val actorSystemName = if (isDriver) driverActorSystemName else executorActorSystemName + val rpcEnv = RpcEnv.create(actorSystemName, hostname, port, conf, securityManager) + val actorSystem = rpcEnv.asInstanceOf[AkkaRpcEnv].actorSystem // Figure out which port Akka actually bound to in case the original port is 0 or occupied. if (isDriver) { - conf.set("spark.driver.port", boundPort.toString) + conf.set("spark.driver.port", rpcEnv.address.port.toString) } else { - conf.set("spark.executor.port", boundPort.toString) + conf.set("spark.executor.port", rpcEnv.address.port.toString) } // Create an instance of the class with the given name, possibly initializing it with our conf @@ -281,12 +285,14 @@ object SparkEnv extends Logging { val closureSerializer = instantiateClassFromConf[Serializer]( "spark.closure.serializer", "org.apache.spark.serializer.JavaSerializer") - def registerOrLookup(name: String, newActor: => Actor): ActorRef = { + def registerOrLookupEndpoint( + name: String, endpointCreator: => RpcEndpoint): + RpcEndpointRef = { if (isDriver) { logInfo("Registering " + name) - actorSystem.actorOf(Props(newActor), name = name) + rpcEnv.setupEndpoint(name, endpointCreator) } else { - AkkaUtils.makeDriverRef(name, conf, actorSystem) + RpcUtils.makeDriverRef(name, conf, rpcEnv) } } @@ -298,9 +304,9 @@ object SparkEnv extends Logging { // Have to assign trackerActor after initialization as MapOutputTrackerActor // requires the MapOutputTracker itself - mapOutputTracker.trackerActor = registerOrLookup( - "MapOutputTracker", - new MapOutputTrackerMasterActor(mapOutputTracker.asInstanceOf[MapOutputTrackerMaster], conf)) + mapOutputTracker.trackerEndpoint = registerOrLookupEndpoint(MapOutputTracker.ENDPOINT_NAME, + new MapOutputTrackerMasterEndpoint( + rpcEnv, mapOutputTracker.asInstanceOf[MapOutputTrackerMaster], conf)) // Let the user specify short names for shuffle managers val shortShuffleMgrNames = Map( @@ -320,12 +326,13 @@ object SparkEnv extends Logging { new NioBlockTransferService(conf, securityManager) } - val blockManagerMaster = new BlockManagerMaster(registerOrLookup( - "BlockManagerMaster", - new BlockManagerMasterActor(isLocal, conf, listenerBus)), conf, isDriver) + val blockManagerMaster = new BlockManagerMaster(registerOrLookupEndpoint( + BlockManagerMaster.DRIVER_ENDPOINT_NAME, + new BlockManagerMasterEndpoint(rpcEnv, isLocal, conf, listenerBus)), + conf, isDriver) // NB: blockManager is not valid until initialize() is called later. - val blockManager = new BlockManager(executorId, actorSystem, blockManagerMaster, + val blockManager = new BlockManager(executorId, rpcEnv, blockManagerMaster, serializer, conf, mapOutputTracker, shuffleManager, blockTransferService, securityManager, numUsableCores) @@ -377,13 +384,13 @@ object SparkEnv extends Logging { val outputCommitCoordinator = mockOutputCommitCoordinator.getOrElse { new OutputCommitCoordinator(conf) } - val outputCommitCoordinatorActor = registerOrLookup("OutputCommitCoordinator", - new OutputCommitCoordinatorActor(outputCommitCoordinator)) - outputCommitCoordinator.coordinatorActor = Some(outputCommitCoordinatorActor) + val outputCommitCoordinatorRef = registerOrLookupEndpoint("OutputCommitCoordinator", + new OutputCommitCoordinatorEndpoint(rpcEnv, outputCommitCoordinator)) + outputCommitCoordinator.coordinatorRef = Some(outputCommitCoordinatorRef) val envInstance = new SparkEnv( executorId, - actorSystem, + rpcEnv, serializer, closureSerializer, cacheManager, diff --git a/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala b/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala index 6eb4537d10477..2ec42d3aea169 100644 --- a/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala +++ b/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala @@ -26,7 +26,6 @@ import org.apache.hadoop.mapred._ import org.apache.hadoop.fs.FileSystem import org.apache.hadoop.fs.Path -import org.apache.spark.executor.CommitDeniedException import org.apache.spark.mapred.SparkHadoopMapRedUtil import org.apache.spark.rdd.HadoopRDD @@ -104,55 +103,8 @@ class SparkHadoopWriter(@transient jobConf: JobConf) } def commit() { - val taCtxt = getTaskContext() - val cmtr = getOutputCommitter() - - // Called after we have decided to commit - def performCommit(): Unit = { - try { - cmtr.commitTask(taCtxt) - logInfo (s"$taID: Committed") - } catch { - case e: IOException => - logError("Error committing the output of task: " + taID.value, e) - cmtr.abortTask(taCtxt) - throw e - } - } - - // First, check whether the task's output has already been committed by some other attempt - if (cmtr.needsTaskCommit(taCtxt)) { - // The task output needs to be committed, but we don't know whether some other task attempt - // might be racing to commit the same output partition. Therefore, coordinate with the driver - // in order to determine whether this attempt can commit (see SPARK-4879). - val shouldCoordinateWithDriver: Boolean = { - val sparkConf = SparkEnv.get.conf - // We only need to coordinate with the driver if there are multiple concurrent task - // attempts, which should only occur if speculation is enabled - val speculationEnabled = sparkConf.getBoolean("spark.speculation", false) - // This (undocumented) setting is an escape-hatch in case the commit code introduces bugs - sparkConf.getBoolean("spark.hadoop.outputCommitCoordination.enabled", speculationEnabled) - } - if (shouldCoordinateWithDriver) { - val outputCommitCoordinator = SparkEnv.get.outputCommitCoordinator - val canCommit = outputCommitCoordinator.canCommit(jobID, splitID, attemptID) - if (canCommit) { - performCommit() - } else { - val msg = s"$taID: Not committed because the driver did not authorize commit" - logInfo(msg) - // We need to abort the task so that the driver can reschedule new attempts, if necessary - cmtr.abortTask(taCtxt) - throw new CommitDeniedException(msg, jobID, splitID, attemptID) - } - } else { - // Speculation is disabled or a user has chosen to manually bypass the commit coordination - performCommit() - } - } else { - // Some other attempt committed the output, so we do nothing and signal success - logInfo(s"No need to commit output of task because needsTaskCommit=false: ${taID.value}") - } + SparkHadoopMapRedUtil.commitTask( + getOutputCommitter(), getTaskContext(), jobID, splitID, attemptID) } def commitJob() { diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala index a023712be1166..8441bb3a3047e 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala @@ -661,7 +661,7 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) */ def flatMapValues[U](f: JFunction[V, java.lang.Iterable[U]]): JavaPairRDD[K, U] = { import scala.collection.JavaConverters._ - def fn = (x: V) => f.call(x).asScala + def fn: (V) => Iterable[U] = (x: V) => f.call(x).asScala implicit val ctag: ClassTag[U] = fakeClassTag fromRDD(rdd.flatMapValues(fn)) } diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala index 18ccd625fc8d1..db4e996feb31c 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala @@ -192,7 +192,7 @@ class JavaRDD[T](val rdd: RDD[T])(implicit val classTag: ClassTag[T]) */ def sortBy[S](f: JFunction[T, S], ascending: Boolean, numPartitions: Int): JavaRDD[T] = { import scala.collection.JavaConverters._ - def fn = (x: T) => f.call(x) + def fn: (T) => S = (x: T) => f.call(x) import com.google.common.collect.Ordering // shadows scala.math.Ordering implicit val ordering = Ordering.natural().asInstanceOf[Ordering[S]] implicit val ctag: ClassTag[S] = fakeClassTag diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala index 8da42934a7d96..8bf0627fc420d 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala @@ -17,8 +17,9 @@ package org.apache.spark.api.java -import java.util.{Comparator, List => JList, Iterator => JIterator} +import java.{lang => jl} import java.lang.{Iterable => JIterable, Long => JLong} +import java.util.{Comparator, List => JList, Iterator => JIterator} import scala.collection.JavaConversions._ import scala.collection.JavaConverters._ @@ -93,7 +94,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { * of the original partition. */ def mapPartitionsWithIndex[R]( - f: JFunction2[java.lang.Integer, java.util.Iterator[T], java.util.Iterator[R]], + f: JFunction2[jl.Integer, java.util.Iterator[T], java.util.Iterator[R]], preservesPartitioning: Boolean = false): JavaRDD[R] = new JavaRDD(rdd.mapPartitionsWithIndex(((a,b) => f(a,asJavaIterator(b))), preservesPartitioning)(fakeClassTag))(fakeClassTag) @@ -109,7 +110,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { * Return a new RDD by applying a function to all elements of this RDD. */ def mapToPair[K2, V2](f: PairFunction[T, K2, V2]): JavaPairRDD[K2, V2] = { - def cm = implicitly[ClassTag[(K2, V2)]] + def cm: ClassTag[(K2, V2)] = implicitly[ClassTag[(K2, V2)]] new JavaPairRDD(rdd.map[(K2, V2)](f)(cm))(fakeClassTag[K2], fakeClassTag[V2]) } @@ -119,7 +120,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { */ def flatMap[U](f: FlatMapFunction[T, U]): JavaRDD[U] = { import scala.collection.JavaConverters._ - def fn = (x: T) => f.call(x).asScala + def fn: (T) => Iterable[U] = (x: T) => f.call(x).asScala JavaRDD.fromRDD(rdd.flatMap(fn)(fakeClassTag[U]))(fakeClassTag[U]) } @@ -129,8 +130,8 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { */ def flatMapToDouble(f: DoubleFlatMapFunction[T]): JavaDoubleRDD = { import scala.collection.JavaConverters._ - def fn = (x: T) => f.call(x).asScala - new JavaDoubleRDD(rdd.flatMap(fn).map((x: java.lang.Double) => x.doubleValue())) + def fn: (T) => Iterable[jl.Double] = (x: T) => f.call(x).asScala + new JavaDoubleRDD(rdd.flatMap(fn).map((x: jl.Double) => x.doubleValue())) } /** @@ -139,8 +140,8 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { */ def flatMapToPair[K2, V2](f: PairFlatMapFunction[T, K2, V2]): JavaPairRDD[K2, V2] = { import scala.collection.JavaConverters._ - def fn = (x: T) => f.call(x).asScala - def cm = implicitly[ClassTag[(K2, V2)]] + def fn: (T) => Iterable[(K2, V2)] = (x: T) => f.call(x).asScala + def cm: ClassTag[(K2, V2)] = implicitly[ClassTag[(K2, V2)]] JavaPairRDD.fromRDD(rdd.flatMap(fn)(cm))(fakeClassTag[K2], fakeClassTag[V2]) } @@ -148,7 +149,9 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { * Return a new RDD by applying a function to each partition of this RDD. */ def mapPartitions[U](f: FlatMapFunction[java.util.Iterator[T], U]): JavaRDD[U] = { - def fn = (x: Iterator[T]) => asScalaIterator(f.call(asJavaIterator(x)).iterator()) + def fn: (Iterator[T]) => Iterator[U] = { + (x: Iterator[T]) => asScalaIterator(f.call(asJavaIterator(x)).iterator()) + } JavaRDD.fromRDD(rdd.mapPartitions(fn)(fakeClassTag[U]))(fakeClassTag[U]) } @@ -157,7 +160,9 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { */ def mapPartitions[U](f: FlatMapFunction[java.util.Iterator[T], U], preservesPartitioning: Boolean): JavaRDD[U] = { - def fn = (x: Iterator[T]) => asScalaIterator(f.call(asJavaIterator(x)).iterator()) + def fn: (Iterator[T]) => Iterator[U] = { + (x: Iterator[T]) => asScalaIterator(f.call(asJavaIterator(x)).iterator()) + } JavaRDD.fromRDD( rdd.mapPartitions(fn, preservesPartitioning)(fakeClassTag[U]))(fakeClassTag[U]) } @@ -166,8 +171,10 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { * Return a new RDD by applying a function to each partition of this RDD. */ def mapPartitionsToDouble(f: DoubleFlatMapFunction[java.util.Iterator[T]]): JavaDoubleRDD = { - def fn = (x: Iterator[T]) => asScalaIterator(f.call(asJavaIterator(x)).iterator()) - new JavaDoubleRDD(rdd.mapPartitions(fn).map((x: java.lang.Double) => x.doubleValue())) + def fn: (Iterator[T]) => Iterator[jl.Double] = { + (x: Iterator[T]) => asScalaIterator(f.call(asJavaIterator(x)).iterator()) + } + new JavaDoubleRDD(rdd.mapPartitions(fn).map((x: jl.Double) => x.doubleValue())) } /** @@ -175,7 +182,9 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { */ def mapPartitionsToPair[K2, V2](f: PairFlatMapFunction[java.util.Iterator[T], K2, V2]): JavaPairRDD[K2, V2] = { - def fn = (x: Iterator[T]) => asScalaIterator(f.call(asJavaIterator(x)).iterator()) + def fn: (Iterator[T]) => Iterator[(K2, V2)] = { + (x: Iterator[T]) => asScalaIterator(f.call(asJavaIterator(x)).iterator()) + } JavaPairRDD.fromRDD(rdd.mapPartitions(fn))(fakeClassTag[K2], fakeClassTag[V2]) } @@ -184,7 +193,9 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { */ def mapPartitionsToDouble(f: DoubleFlatMapFunction[java.util.Iterator[T]], preservesPartitioning: Boolean): JavaDoubleRDD = { - def fn = (x: Iterator[T]) => asScalaIterator(f.call(asJavaIterator(x)).iterator()) + def fn: (Iterator[T]) => Iterator[jl.Double] = { + (x: Iterator[T]) => asScalaIterator(f.call(asJavaIterator(x)).iterator()) + } new JavaDoubleRDD(rdd.mapPartitions(fn, preservesPartitioning) .map(x => x.doubleValue())) } @@ -194,7 +205,9 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { */ def mapPartitionsToPair[K2, V2](f: PairFlatMapFunction[java.util.Iterator[T], K2, V2], preservesPartitioning: Boolean): JavaPairRDD[K2, V2] = { - def fn = (x: Iterator[T]) => asScalaIterator(f.call(asJavaIterator(x)).iterator()) + def fn: (Iterator[T]) => Iterator[(K2, V2)] = { + (x: Iterator[T]) => asScalaIterator(f.call(asJavaIterator(x)).iterator()) + } JavaPairRDD.fromRDD( rdd.mapPartitions(fn, preservesPartitioning))(fakeClassTag[K2], fakeClassTag[V2]) } @@ -277,8 +290,10 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { def zipPartitions[U, V]( other: JavaRDDLike[U, _], f: FlatMapFunction2[java.util.Iterator[T], java.util.Iterator[U], V]): JavaRDD[V] = { - def fn = (x: Iterator[T], y: Iterator[U]) => asScalaIterator( - f.call(asJavaIterator(x), asJavaIterator(y)).iterator()) + def fn: (Iterator[T], Iterator[U]) => Iterator[V] = { + (x: Iterator[T], y: Iterator[U]) => asScalaIterator( + f.call(asJavaIterator(x), asJavaIterator(y)).iterator()) + } JavaRDD.fromRDD( rdd.zipPartitions(other.rdd)(fn)(other.classTag, fakeClassTag[V]))(fakeClassTag[V]) } @@ -441,8 +456,8 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { * Return the count of each unique value in this RDD as a map of (value, count) pairs. The final * combine step happens locally on the master, equivalent to running a single reduce task. */ - def countByValue(): java.util.Map[T, java.lang.Long] = - mapAsSerializableJavaMap(rdd.countByValue().map((x => (x._1, new java.lang.Long(x._2))))) + def countByValue(): java.util.Map[T, jl.Long] = + mapAsSerializableJavaMap(rdd.countByValue().map((x => (x._1, new jl.Long(x._2))))) /** * (Experimental) Approximate version of countByValue(). diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index 19f4c95fcad74..b1ffba4c546bf 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -605,7 +605,6 @@ private[spark] object PythonRDD extends Logging { */ private def serveIterator[T](items: Iterator[T], threadName: String): Int = { val serverSocket = new ServerSocket(0, 1) - serverSocket.setReuseAddress(true) // Close the socket if no connection in 3 seconds serverSocket.setSoTimeout(3000) @@ -615,9 +614,9 @@ private[spark] object PythonRDD extends Logging { try { val sock = serverSocket.accept() val out = new DataOutputStream(new BufferedOutputStream(sock.getOutputStream)) - try { + Utils.tryWithSafeFinally { writeIteratorToStream(items, out) - } finally { + } { out.close() } } catch { @@ -863,9 +862,9 @@ private[spark] class PythonBroadcast(@transient var path: String) extends Serial val file = File.createTempFile("broadcast", "", dir) path = file.getAbsolutePath val out = new FileOutputStream(file) - try { + Utils.tryWithSafeFinally { Utils.copyStream(in, out) - } finally { + } { out.close() } } diff --git a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala index 74ccfa6d3c9a3..4457c75e8b0fc 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala @@ -165,7 +165,7 @@ private[broadcast] object HttpBroadcast extends Logging { private def write(id: Long, value: Any) { val file = getFile(id) val fileOutputStream = new FileOutputStream(file) - try { + Utils.tryWithSafeFinally { val out: OutputStream = { if (compress) { compressionCodec.compressedOutputStream(fileOutputStream) @@ -175,10 +175,13 @@ private[broadcast] object HttpBroadcast extends Logging { } val ser = SparkEnv.get.serializer.newInstance() val serOut = ser.serializeStream(out) - serOut.writeObject(value) - serOut.close() + Utils.tryWithSafeFinally { + serOut.writeObject(value) + } { + serOut.close() + } files += file - } finally { + } { fileOutputStream.close() } } @@ -212,9 +215,11 @@ private[broadcast] object HttpBroadcast extends Logging { } val ser = SparkEnv.get.serializer.newInstance() val serIn = ser.deserializeStream(in) - val obj = serIn.readObject[T]() - serIn.close() - obj + Utils.tryWithSafeFinally { + serIn.readObject[T]() + } { + serIn.close() + } } /** diff --git a/core/src/main/scala/org/apache/spark/deploy/ApplicationDescription.scala b/core/src/main/scala/org/apache/spark/deploy/ApplicationDescription.scala index 3d0d68de8f495..b7ae9c1fc0a23 100644 --- a/core/src/main/scala/org/apache/spark/deploy/ApplicationDescription.scala +++ b/core/src/main/scala/org/apache/spark/deploy/ApplicationDescription.scala @@ -17,13 +17,15 @@ package org.apache.spark.deploy +import java.net.URI + private[spark] class ApplicationDescription( val name: String, val maxCores: Option[Int], val memoryPerSlave: Int, val command: Command, var appUiUrl: String, - val eventLogDir: Option[String] = None, + val eventLogDir: Option[URI] = None, // short name of compression codec used when writing event logs, if any (e.g. lzf) val eventLogCodec: Option[String] = None) extends Serializable { @@ -36,7 +38,7 @@ private[spark] class ApplicationDescription( memoryPerSlave: Int = memoryPerSlave, command: Command = command, appUiUrl: String = appUiUrl, - eventLogDir: Option[String] = eventLogDir, + eventLogDir: Option[URI] = eventLogDir, eventLogCodec: Option[String] = eventLogCodec): ApplicationDescription = new ApplicationDescription( name, maxCores, memoryPerSlave, command, appUiUrl, eventLogDir, eventLogCodec) diff --git a/core/src/main/scala/org/apache/spark/deploy/Client.scala b/core/src/main/scala/org/apache/spark/deploy/Client.scala index 65238af2caa24..8d13b2a2cd4f3 100644 --- a/core/src/main/scala/org/apache/spark/deploy/Client.scala +++ b/core/src/main/scala/org/apache/spark/deploy/Client.scala @@ -89,7 +89,7 @@ private class ClientActor(driverArgs: ClientArguments, conf: SparkConf) /* Find out driver status then exit the JVM */ def pollAndReportStatus(driverId: String) { - println(s"... waiting before polling master for driver state") + println("... waiting before polling master for driver state") Thread.sleep(5000) println("... polling master for driver state") val statusFuture = (masterActor ? RequestDriverStatus(driverId))(timeout) diff --git a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala index 80c9c13ddec1e..9d40d8c8fd7a8 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala @@ -118,7 +118,7 @@ private[history] class FsHistoryProvider(conf: SparkConf) extends ApplicationHis if (!fs.exists(path)) { var msg = s"Log directory specified does not exist: $logDir." if (logDir == DEFAULT_LOG_DIR) { - msg += " Did you configure the correct one through spark.fs.history.logDirectory?" + msg += " Did you configure the correct one through spark.history.fs.logDirectory?" } throw new IllegalArgumentException(msg) } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala b/core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala index 32499b3a784a1..f459ed5b3a1a1 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala @@ -24,6 +24,7 @@ import scala.reflect.ClassTag import akka.serialization.Serialization import org.apache.spark.Logging +import org.apache.spark.util.Utils /** @@ -59,9 +60,9 @@ private[master] class FileSystemPersistenceEngine( val serializer = serialization.findSerializerFor(value) val serialized = serializer.toBinary(value) val out = new FileOutputStream(file) - try { + Utils.tryWithSafeFinally { out.write(serialized) - } finally { + } { out.close() } } diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestClient.scala b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestClient.scala index 420442f7564cc..b8fd406fb6f9a 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestClient.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestClient.scala @@ -27,6 +27,7 @@ import com.fasterxml.jackson.core.JsonProcessingException import com.google.common.base.Charsets import org.apache.spark.{Logging, SparkConf, SPARK_VERSION => sparkVersion} +import org.apache.spark.util.Utils /** * A client that submits applications to the standalone Master using a REST protocol. @@ -148,8 +149,11 @@ private[deploy] class StandaloneRestClient extends Logging { conn.setRequestProperty("charset", "utf-8") conn.setDoOutput(true) val out = new DataOutputStream(conn.getOutputStream) - out.write(json.getBytes(Charsets.UTF_8)) - out.close() + Utils.tryWithSafeFinally { + out.write(json.getBytes(Charsets.UTF_8)) + } { + out.close() + } readResponse(conn) } @@ -241,7 +245,7 @@ private[deploy] class StandaloneRestClient extends Logging { } } else { val failMessage = Option(submitResponse.message).map { ": " + _ }.getOrElse("") - logError("Application submission failed" + failMessage) + logError(s"Application submission failed$failMessage") } } diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala b/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala index deef6ef9043c6..d1a12b01e78f7 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala @@ -19,10 +19,9 @@ package org.apache.spark.deploy.worker import java.io.File -import akka.actor._ - import org.apache.spark.{SecurityManager, SparkConf} -import org.apache.spark.util.{AkkaUtils, ChildFirstURLClassLoader, MutableURLClassLoader, Utils} +import org.apache.spark.rpc.RpcEnv +import org.apache.spark.util.{ChildFirstURLClassLoader, MutableURLClassLoader, Utils} /** * Utility object for launching driver programs such that they share fate with the Worker process. @@ -39,9 +38,9 @@ object DriverWrapper { */ case workerUrl :: userJar :: mainClass :: extraArgs => val conf = new SparkConf() - val (actorSystem, _) = AkkaUtils.createActorSystem("Driver", + val rpcEnv = RpcEnv.create("Driver", Utils.localHostName(), 0, conf, new SecurityManager(conf)) - actorSystem.actorOf(Props(classOf[WorkerWatcher], workerUrl), name = "workerWatcher") + rpcEnv.setupEndpoint("workerWatcher", new WorkerWatcher(rpcEnv, workerUrl)) val currentLoader = Thread.currentThread.getContextClassLoader val userJarUrl = new File(userJar).toURI().toURL() @@ -58,7 +57,7 @@ object DriverWrapper { val mainMethod = clazz.getMethod("main", classOf[Array[String]]) mainMethod.invoke(null, extraArgs.toArray[String]) - actorSystem.shutdown() + rpcEnv.shutdown() case _ => System.err.println("Usage: DriverWrapper [options]") diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala index e0790274d7d3e..83fb991891a41 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala @@ -17,58 +17,63 @@ package org.apache.spark.deploy.worker -import akka.actor.{Actor, Address, AddressFromURIString} -import akka.remote.{AssociatedEvent, AssociationErrorEvent, AssociationEvent, DisassociatedEvent, RemotingLifecycleEvent} - import org.apache.spark.Logging import org.apache.spark.deploy.DeployMessages.SendHeartbeat -import org.apache.spark.util.ActorLogReceive +import org.apache.spark.rpc._ /** * Actor which connects to a worker process and terminates the JVM if the connection is severed. * Provides fate sharing between a worker and its associated child processes. */ -private[spark] class WorkerWatcher(workerUrl: String) - extends Actor with ActorLogReceive with Logging { - - override def preStart() { - context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent]) +private[spark] class WorkerWatcher(override val rpcEnv: RpcEnv, workerUrl: String) + extends RpcEndpoint with Logging { + override def onStart() { logInfo(s"Connecting to worker $workerUrl") - val worker = context.actorSelection(workerUrl) - worker ! SendHeartbeat // need to send a message here to initiate connection + if (!isTesting) { + rpcEnv.asyncSetupEndpointRefByURI(workerUrl) + } } // Used to avoid shutting down JVM during tests + // In the normal case, exitNonZero will call `System.exit(-1)` to shutdown the JVM. In the unit + // test, the user should call `setTesting(true)` so that `exitNonZero` will set `isShutDown` to + // true rather than calling `System.exit`. The user can check `isShutDown` to know if + // `exitNonZero` is called. private[deploy] var isShutDown = false private[deploy] def setTesting(testing: Boolean) = isTesting = testing private var isTesting = false // Lets us filter events only from the worker's actor system - private val expectedHostPort = AddressFromURIString(workerUrl).hostPort - private def isWorker(address: Address) = address.hostPort == expectedHostPort + private val expectedAddress = RpcAddress.fromURIString(workerUrl) + private def isWorker(address: RpcAddress) = expectedAddress == address private def exitNonZero() = if (isTesting) isShutDown = true else System.exit(-1) - override def receiveWithLogging: PartialFunction[Any, Unit] = { - case AssociatedEvent(localAddress, remoteAddress, inbound) if isWorker(remoteAddress) => - logInfo(s"Successfully connected to $workerUrl") + override def receive: PartialFunction[Any, Unit] = { + case e => logWarning(s"Received unexpected message: $e") + } - case AssociationErrorEvent(cause, localAddress, remoteAddress, inbound, _) - if isWorker(remoteAddress) => - // These logs may not be seen if the worker (and associated pipe) has died - logError(s"Could not initialize connection to worker $workerUrl. Exiting.") - logError(s"Error was: $cause") - exitNonZero() + override def onConnected(remoteAddress: RpcAddress): Unit = { + if (isWorker(remoteAddress)) { + logInfo(s"Successfully connected to $workerUrl") + } + } - case DisassociatedEvent(localAddress, remoteAddress, inbound) if isWorker(remoteAddress) => + override def onDisconnected(remoteAddress: RpcAddress): Unit = { + if (isWorker(remoteAddress)) { // This log message will never be seen logError(s"Lost connection to worker actor $workerUrl. Exiting.") exitNonZero() + } + } - case e: AssociationEvent => - // pass through association events relating to other remote actor systems - - case e => logWarning(s"Received unexpected actor system event: $e") + override def onNetworkError(cause: Throwable, remoteAddress: RpcAddress): Unit = { + if (isWorker(remoteAddress)) { + // These logs may not be seen if the worker (and associated pipe) has died + logError(s"Could not initialize connection to worker $workerUrl. Exiting.") + logError(s"Error was: $cause") + exitNonZero() + } } } diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala index b5205d4e997ae..8300f9f2190b9 100644 --- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala @@ -21,39 +21,45 @@ import java.net.URL import java.nio.ByteBuffer import scala.collection.mutable -import scala.concurrent.Await +import scala.util.{Failure, Success} -import akka.actor.{Actor, ActorSelection, Props} -import akka.pattern.Patterns -import akka.remote.{RemotingLifecycleEvent, DisassociatedEvent} - -import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkEnv} +import org.apache.spark.rpc._ +import org.apache.spark._ import org.apache.spark.TaskState.TaskState import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.deploy.worker.WorkerWatcher import org.apache.spark.scheduler.TaskDescription import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._ -import org.apache.spark.util.{ActorLogReceive, AkkaUtils, SignalLogger, Utils} +import org.apache.spark.util.{SignalLogger, Utils} private[spark] class CoarseGrainedExecutorBackend( + override val rpcEnv: RpcEnv, driverUrl: String, executorId: String, hostPort: String, cores: Int, userClassPath: Seq[URL], env: SparkEnv) - extends Actor with ActorLogReceive with ExecutorBackend with Logging { + extends ThreadSafeRpcEndpoint with ExecutorBackend with Logging { Utils.checkHostPort(hostPort, "Expected hostport") var executor: Executor = null - var driver: ActorSelection = null + @volatile var driver: Option[RpcEndpointRef] = None - override def preStart() { + override def onStart() { + import scala.concurrent.ExecutionContext.Implicits.global logInfo("Connecting to driver: " + driverUrl) - driver = context.actorSelection(driverUrl) - driver ! RegisterExecutor(executorId, hostPort, cores, extractLogUrls) - context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent]) + rpcEnv.asyncSetupEndpointRefByURI(driverUrl).flatMap { ref => + driver = Some(ref) + ref.sendWithReply[RegisteredExecutor.type]( + RegisterExecutor(executorId, self, hostPort, cores, extractLogUrls)) + } onComplete { + case Success(msg) => Utils.tryLogNonFatalError { + Option(self).foreach(_.send(msg)) // msg must be RegisteredExecutor + } + case Failure(e) => logError(s"Cannot register with driver: $driverUrl", e) + } } def extractLogUrls: Map[String, String] = { @@ -62,7 +68,7 @@ private[spark] class CoarseGrainedExecutorBackend( .map(e => (e._1.substring(prefix.length).toLowerCase, e._2)) } - override def receiveWithLogging: PartialFunction[Any, Unit] = { + override def receive: PartialFunction[Any, Unit] = { case RegisteredExecutor => logInfo("Successfully registered with driver") val (hostname, _) = Utils.parseHostPort(hostPort) @@ -92,23 +98,28 @@ private[spark] class CoarseGrainedExecutorBackend( executor.killTask(taskId, interruptThread) } - case x: DisassociatedEvent => - if (x.remoteAddress == driver.anchorPath.address) { - logError(s"Driver $x disassociated! Shutting down.") - System.exit(1) - } else { - logWarning(s"Received irrelevant DisassociatedEvent $x") - } - case StopExecutor => logInfo("Driver commanded a shutdown") executor.stop() - context.stop(self) - context.system.shutdown() + stop() + rpcEnv.shutdown() + } + + override def onDisconnected(remoteAddress: RpcAddress): Unit = { + if (driver.exists(_.address == remoteAddress)) { + logError(s"Driver $remoteAddress disassociated! Shutting down.") + System.exit(1) + } else { + logWarning(s"An unknown ($remoteAddress) driver disconnected.") + } } override def statusUpdate(taskId: Long, state: TaskState, data: ByteBuffer) { - driver ! StatusUpdate(executorId, taskId, state, data) + val msg = StatusUpdate(executorId, taskId, state, data) + driver match { + case Some(driverRef) => driverRef.send(msg) + case None => logWarning(s"Drop $msg because has not yet connected to driver") + } } } @@ -132,16 +143,14 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging { // Bootstrap to fetch the driver's Spark properties. val executorConf = new SparkConf val port = executorConf.getInt("spark.executor.port", 0) - val (fetcher, _) = AkkaUtils.createActorSystem( + val fetcher = RpcEnv.create( "driverPropsFetcher", hostname, port, executorConf, new SecurityManager(executorConf)) - val driver = fetcher.actorSelection(driverUrl) - val timeout = AkkaUtils.askTimeout(executorConf) - val fut = Patterns.ask(driver, RetrieveSparkProps, timeout) - val props = Await.result(fut, timeout).asInstanceOf[Seq[(String, String)]] ++ + val driver = fetcher.setupEndpointRefByURI(driverUrl) + val props = driver.askWithReply[Seq[(String, String)]](RetrieveSparkProps) ++ Seq[(String, String)](("spark.app.id", appId)) fetcher.shutdown() @@ -162,16 +171,14 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging { val boundPort = env.conf.getInt("spark.executor.port", 0) assert(boundPort != 0) - // Start the CoarseGrainedExecutorBackend actor. + // Start the CoarseGrainedExecutorBackend endpoint. val sparkHostPort = hostname + ":" + boundPort - env.actorSystem.actorOf( - Props(classOf[CoarseGrainedExecutorBackend], - driverUrl, executorId, sparkHostPort, cores, userClassPath, env), - name = "Executor") + env.rpcEnv.setupEndpoint("Executor", new CoarseGrainedExecutorBackend( + env.rpcEnv, driverUrl, executorId, sparkHostPort, cores, userClassPath, env)) workerUrl.foreach { url => - env.actorSystem.actorOf(Props(classOf[WorkerWatcher], url), name = "WorkerWatcher") + env.rpcEnv.setupEndpoint("WorkerWatcher", new WorkerWatcher(env.rpcEnv, url)) } - env.actorSystem.awaitTermination() + env.rpcEnv.awaitTermination() } } diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index bf3135ef081c1..14f99a464b6e9 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -27,8 +27,6 @@ import scala.collection.JavaConversions._ import scala.collection.mutable.{ArrayBuffer, HashMap} import scala.util.control.NonFatal -import akka.actor.Props - import org.apache.spark._ import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.scheduler.{DirectTaskResult, IndirectTaskResult, Task} @@ -88,9 +86,9 @@ private[spark] class Executor( env.blockManager.initialize(conf.getAppId) } - // Create an actor for receiving RPCs from the driver - private val executorActor = env.actorSystem.actorOf( - Props(new ExecutorActor(executorId)), "ExecutorActor") + // Create an RpcEndpoint for receiving RPCs from the driver + private val executorEndpoint = env.rpcEnv.setupEndpoint( + ExecutorEndpoint.EXECUTOR_ENDPOINT_NAME, new ExecutorEndpoint(env.rpcEnv, executorId)) // Whether to load classes in user jars before those in Spark jars private val userClassPathFirst: Boolean = { @@ -139,7 +137,7 @@ private[spark] class Executor( def stop(): Unit = { env.metricsSystem.report() - env.actorSystem.stop(executorActor) + env.rpcEnv.stop(executorEndpoint) isStopped = true threadPool.shutdown() if (!isLocal) { @@ -391,11 +389,8 @@ private[spark] class Executor( } } - private val timeout = AkkaUtils.lookupTimeout(conf) - private val retryAttempts = AkkaUtils.numRetries(conf) - private val retryIntervalMs = AkkaUtils.retryWaitMs(conf) private val heartbeatReceiverRef = - AkkaUtils.makeDriverRef("HeartbeatReceiver", conf, env.actorSystem) + RpcUtils.makeDriverRef(HeartbeatReceiver.ENDPOINT_NAME, conf, env.rpcEnv) /** Reports heartbeat and metrics for active tasks to the driver. */ private def reportHeartBeat(): Unit = { @@ -426,8 +421,7 @@ private[spark] class Executor( val message = Heartbeat(executorId, tasksMetrics.toArray, env.blockManager.blockManagerId) try { - val response = AkkaUtils.askWithReply[HeartbeatResponse](message, heartbeatReceiverRef, - retryAttempts, retryIntervalMs, timeout) + val response = heartbeatReceiverRef.askWithReply[HeartbeatResponse](message) if (response.reregisterBlockManager) { logWarning("Told to re-register on heartbeat") env.blockManager.reregister() diff --git a/core/src/main/scala/org/apache/spark/executor/ExecutorActor.scala b/core/src/main/scala/org/apache/spark/executor/ExecutorEndpoint.scala similarity index 67% rename from core/src/main/scala/org/apache/spark/executor/ExecutorActor.scala rename to core/src/main/scala/org/apache/spark/executor/ExecutorEndpoint.scala index 3e47d13f7545d..cf362f8464735 100644 --- a/core/src/main/scala/org/apache/spark/executor/ExecutorActor.scala +++ b/core/src/main/scala/org/apache/spark/executor/ExecutorEndpoint.scala @@ -17,10 +17,8 @@ package org.apache.spark.executor -import akka.actor.Actor -import org.apache.spark.Logging - -import org.apache.spark.util.{Utils, ActorLogReceive} +import org.apache.spark.rpc.{RpcEnv, RpcCallContext, RpcEndpoint} +import org.apache.spark.util.Utils /** * Driver -> Executor message to trigger a thread dump. @@ -28,14 +26,18 @@ import org.apache.spark.util.{Utils, ActorLogReceive} private[spark] case object TriggerThreadDump /** - * Actor that runs inside of executors to enable driver -> executor RPC. + * [[RpcEndpoint]] that runs inside of executors to enable driver -> executor RPC. */ private[spark] -class ExecutorActor(executorId: String) extends Actor with ActorLogReceive with Logging { +class ExecutorEndpoint(override val rpcEnv: RpcEnv, executorId: String) extends RpcEndpoint { - override def receiveWithLogging: PartialFunction[Any, Unit] = { + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { case TriggerThreadDump => - sender ! Utils.getThreadDump() + context.reply(Utils.getThreadDump()) } } + +object ExecutorEndpoint { + val EXECUTOR_ENDPOINT_NAME = "ExecutorEndpoint" +} diff --git a/core/src/main/scala/org/apache/spark/mapred/SparkHadoopMapRedUtil.scala b/core/src/main/scala/org/apache/spark/mapred/SparkHadoopMapRedUtil.scala index 87c2aa481095d..818f7a4c8d422 100644 --- a/core/src/main/scala/org/apache/spark/mapred/SparkHadoopMapRedUtil.scala +++ b/core/src/main/scala/org/apache/spark/mapred/SparkHadoopMapRedUtil.scala @@ -17,9 +17,15 @@ package org.apache.spark.mapred +import java.io.IOException import java.lang.reflect.Modifier -import org.apache.hadoop.mapred.{TaskAttemptID, JobID, JobConf, JobContext, TaskAttemptContext} +import org.apache.hadoop.mapred._ +import org.apache.hadoop.mapreduce.{TaskAttemptContext => MapReduceTaskAttemptContext} +import org.apache.hadoop.mapreduce.{OutputCommitter => MapReduceOutputCommitter} + +import org.apache.spark.executor.CommitDeniedException +import org.apache.spark.{Logging, SparkEnv, TaskContext} private[spark] trait SparkHadoopMapRedUtil { @@ -65,3 +71,86 @@ trait SparkHadoopMapRedUtil { } } } + +object SparkHadoopMapRedUtil extends Logging { + /** + * Commits a task output. Before committing the task output, we need to know whether some other + * task attempt might be racing to commit the same output partition. Therefore, coordinate with + * the driver in order to determine whether this attempt can commit (please see SPARK-4879 for + * details). + * + * Output commit coordinator is only contacted when the following two configurations are both set + * to `true`: + * + * - `spark.speculation` + * - `spark.hadoop.outputCommitCoordination.enabled` + */ + def commitTask( + committer: MapReduceOutputCommitter, + mrTaskContext: MapReduceTaskAttemptContext, + jobId: Int, + splitId: Int, + attemptId: Int): Unit = { + + val mrTaskAttemptID = mrTaskContext.getTaskAttemptID + + // Called after we have decided to commit + def performCommit(): Unit = { + try { + committer.commitTask(mrTaskContext) + logInfo(s"$mrTaskAttemptID: Committed") + } catch { + case cause: IOException => + logError(s"Error committing the output of task: $mrTaskAttemptID", cause) + committer.abortTask(mrTaskContext) + throw cause + } + } + + // First, check whether the task's output has already been committed by some other attempt + if (committer.needsTaskCommit(mrTaskContext)) { + val shouldCoordinateWithDriver: Boolean = { + val sparkConf = SparkEnv.get.conf + // We only need to coordinate with the driver if there are multiple concurrent task + // attempts, which should only occur if speculation is enabled + val speculationEnabled = sparkConf.getBoolean("spark.speculation", defaultValue = false) + // This (undocumented) setting is an escape-hatch in case the commit code introduces bugs + sparkConf.getBoolean("spark.hadoop.outputCommitCoordination.enabled", speculationEnabled) + } + + if (shouldCoordinateWithDriver) { + val outputCommitCoordinator = SparkEnv.get.outputCommitCoordinator + val canCommit = outputCommitCoordinator.canCommit(jobId, splitId, attemptId) + + if (canCommit) { + performCommit() + } else { + val message = + s"$mrTaskAttemptID: Not committed because the driver did not authorize commit" + logInfo(message) + // We need to abort the task so that the driver can reschedule new attempts, if necessary + committer.abortTask(mrTaskContext) + throw new CommitDeniedException(message, jobId, splitId, attemptId) + } + } else { + // Speculation is disabled or a user has chosen to manually bypass the commit coordination + performCommit() + } + } else { + // Some other attempt committed the output, so we do nothing and signal success + logInfo(s"No need to commit output of task because needsTaskCommit=false: $mrTaskAttemptID") + } + } + + def commitTask( + committer: MapReduceOutputCommitter, + mrTaskContext: MapReduceTaskAttemptContext, + sparkTaskContext: TaskContext): Unit = { + commitTask( + committer, + mrTaskContext, + sparkTaskContext.stageId(), + sparkTaskContext.partitionId(), + sparkTaskContext.attemptNumber()) + } +} diff --git a/core/src/main/scala/org/apache/spark/network/nio/Connection.scala b/core/src/main/scala/org/apache/spark/network/nio/Connection.scala index 04eb2bf9ba4ab..6b898bd4bfc1b 100644 --- a/core/src/main/scala/org/apache/spark/network/nio/Connection.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/Connection.scala @@ -181,7 +181,7 @@ abstract class Connection(val channel: SocketChannel, val selector: Selector, buffer.get(bytes) bytes.foreach(x => print(x + " ")) buffer.position(curPosition) - print(" (" + bytes.size + ")") + print(" (" + bytes.length + ")") } def printBuffer(buffer: ByteBuffer, position: Int, length: Int) { diff --git a/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala b/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala index 646df283ac069..3406a7e97e368 100644 --- a/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala @@ -45,7 +45,7 @@ class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Loggi } result }, - Range(0, self.partitions.size), + Range(0, self.partitions.length), (index: Int, data: Long) => totalCount.addAndGet(data), totalCount.get()) } @@ -54,8 +54,8 @@ class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Loggi * Returns a future for retrieving all elements of this RDD. */ def collectAsync(): FutureAction[Seq[T]] = { - val results = new Array[Array[T]](self.partitions.size) - self.context.submitJob[T, Array[T], Seq[T]](self, _.toArray, Range(0, self.partitions.size), + val results = new Array[Array[T]](self.partitions.length) + self.context.submitJob[T, Array[T], Seq[T]](self, _.toArray, Range(0, self.partitions.length), (index, data) => results(index) = data, results.flatten.toSeq) } @@ -111,7 +111,7 @@ class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Loggi */ def foreachAsync(f: T => Unit): FutureAction[Unit] = { val cleanF = self.context.clean(f) - self.context.submitJob[T, Unit, Unit](self, _.foreach(cleanF), Range(0, self.partitions.size), + self.context.submitJob[T, Unit, Unit](self, _.foreach(cleanF), Range(0, self.partitions.length), (index, data) => Unit, Unit) } @@ -119,7 +119,7 @@ class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Loggi * Applies a function f to each partition of this RDD. */ def foreachPartitionAsync(f: Iterator[T] => Unit): FutureAction[Unit] = { - self.context.submitJob[T, Unit, Unit](self, f, Range(0, self.partitions.size), + self.context.submitJob[T, Unit, Unit](self, f, Range(0, self.partitions.length), (index, data) => Unit, Unit) } } diff --git a/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala b/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala index fffa1911f5bc2..71578d1210fde 100644 --- a/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala @@ -36,7 +36,7 @@ class BlockRDD[T: ClassTag](@transient sc: SparkContext, @transient val blockIds override def getPartitions: Array[Partition] = { assertValid() - (0 until blockIds.size).map(i => { + (0 until blockIds.length).map(i => { new BlockRDDPartition(blockIds(i), i).asInstanceOf[Partition] }).toArray } diff --git a/core/src/main/scala/org/apache/spark/rdd/CartesianRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CartesianRDD.scala index 9059eb13bb5d8..c1d6971787572 100644 --- a/core/src/main/scala/org/apache/spark/rdd/CartesianRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/CartesianRDD.scala @@ -53,11 +53,11 @@ class CartesianRDD[T: ClassTag, U: ClassTag]( extends RDD[Pair[T, U]](sc, Nil) with Serializable { - val numPartitionsInRdd2 = rdd2.partitions.size + val numPartitionsInRdd2 = rdd2.partitions.length override def getPartitions: Array[Partition] = { // create the cross product split - val array = new Array[Partition](rdd1.partitions.size * rdd2.partitions.size) + val array = new Array[Partition](rdd1.partitions.length * rdd2.partitions.length) for (s1 <- rdd1.partitions; s2 <- rdd2.partitions) { val idx = s1.index * numPartitionsInRdd2 + s2.index array(idx) = new CartesianPartition(idx, rdd1, rdd2, s1.index, s2.index) diff --git a/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala index 1c13e2c372845..0d130dd4c7a60 100644 --- a/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala @@ -27,6 +27,7 @@ import org.apache.hadoop.fs.Path import org.apache.spark._ import org.apache.spark.broadcast.Broadcast import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.util.Utils private[spark] class CheckpointRDDPartition(val index: Int) extends Partition {} @@ -48,7 +49,7 @@ class CheckpointRDD[T: ClassTag](sc: SparkContext, val checkpointPath: String) if (fs.exists(cpath)) { val dirContents = fs.listStatus(cpath).map(_.getPath) val partitionFiles = dirContents.filter(_.getName.startsWith("part-")).map(_.toString).sorted - val numPart = partitionFiles.size + val numPart = partitionFiles.length if (numPart > 0 && (! partitionFiles(0).endsWith(CheckpointRDD.splitIdToFile(0)) || ! partitionFiles(numPart-1).endsWith(CheckpointRDD.splitIdToFile(numPart-1)))) { throw new SparkException("Invalid checkpoint directory: " + checkpointPath) @@ -112,8 +113,11 @@ private[spark] object CheckpointRDD extends Logging { } val serializer = env.serializer.newInstance() val serializeStream = serializer.serializeStream(fileOutputStream) - serializeStream.writeAll(iterator) - serializeStream.close() + Utils.tryWithSafeFinally { + serializeStream.writeAll(iterator) + } { + serializeStream.close() + } if (!fs.rename(tempOutputPath, finalOutputPath)) { if (!fs.exists(finalOutputPath)) { diff --git a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala index 07398a6fa62f6..7021a339e879b 100644 --- a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala @@ -99,7 +99,7 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part: override def getPartitions: Array[Partition] = { val array = new Array[Partition](part.numPartitions) - for (i <- 0 until array.size) { + for (i <- 0 until array.length) { // Each CoGroupPartition will have a dependency per contributing RDD array(i) = new CoGroupPartition(i, rdds.zipWithIndex.map { case (rdd, j) => // Assume each RDD contributed a single dependency, and get it @@ -120,7 +120,7 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part: val sparkConf = SparkEnv.get.conf val externalSorting = sparkConf.getBoolean("spark.shuffle.spill", true) val split = s.asInstanceOf[CoGroupPartition] - val numRdds = split.deps.size + val numRdds = split.deps.length // A list of (rdd iterator, dependency number) pairs val rddIterators = new ArrayBuffer[(Iterator[Product2[K, Any]], Int)] diff --git a/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala index 5117ccfabfcc2..0c1b02c07d09f 100644 --- a/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala @@ -166,7 +166,7 @@ private class PartitionCoalescer(maxPartitions: Int, prev: RDD[_], balanceSlack: // determines the tradeoff between load-balancing the partitions sizes and their locality // e.g. balanceSlack=0.10 means that it allows up to 10% imbalance in favor of locality - val slack = (balanceSlack * prev.partitions.size).toInt + val slack = (balanceSlack * prev.partitions.length).toInt var noLocality = true // if true if no preferredLocations exists for parent RDD diff --git a/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala index 03afc289736bb..29ca3e9c4bd04 100644 --- a/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala @@ -70,7 +70,7 @@ class DoubleRDDFunctions(self: RDD[Double]) extends Logging with Serializable { @Experimental def meanApprox(timeout: Long, confidence: Double = 0.95): PartialResult[BoundedDouble] = { val processPartition = (ctx: TaskContext, ns: Iterator[Double]) => StatCounter(ns) - val evaluator = new MeanEvaluator(self.partitions.size, confidence) + val evaluator = new MeanEvaluator(self.partitions.length, confidence) self.context.runApproximateJob(self, processPartition, evaluator, timeout) } @@ -81,7 +81,7 @@ class DoubleRDDFunctions(self: RDD[Double]) extends Logging with Serializable { @Experimental def sumApprox(timeout: Long, confidence: Double = 0.95): PartialResult[BoundedDouble] = { val processPartition = (ctx: TaskContext, ns: Iterator[Double]) => StatCounter(ns) - val evaluator = new SumEvaluator(self.partitions.size, confidence) + val evaluator = new SumEvaluator(self.partitions.length, confidence) self.context.runApproximateJob(self, processPartition, evaluator, timeout) } @@ -191,25 +191,23 @@ class DoubleRDDFunctions(self: RDD[Double]) extends Logging with Serializable { } } // Determine the bucket function in constant time. Requires that buckets are evenly spaced - def fastBucketFunction(min: Double, increment: Double, count: Int)(e: Double): Option[Int] = { + def fastBucketFunction(min: Double, max: Double, count: Int)(e: Double): Option[Int] = { // If our input is not a number unless the increment is also NaN then we fail fast - if (e.isNaN()) { - return None - } - val bucketNumber = (e - min)/(increment) - // We do this rather than buckets.lengthCompare(bucketNumber) - // because Array[Double] fails to override it (for now). - if (bucketNumber > count || bucketNumber < 0) { + if (e.isNaN || e < min || e > max) { None } else { - Some(bucketNumber.toInt.min(count - 1)) + // Compute ratio of e's distance along range to total range first, for better precision + val bucketNumber = (((e - min) / (max - min)) * count).toInt + // should be less than count, but will equal count if e == max, in which case + // it's part of the last end-range-inclusive bucket, so return count-1 + Some(math.min(bucketNumber, count - 1)) } } // Decide which bucket function to pass to histogramPartition. We decide here - // rather than having a general function so that the decission need only be made + // rather than having a general function so that the decision need only be made // once rather than once per shard val bucketFunction = if (evenBuckets) { - fastBucketFunction(buckets(0), buckets(1)-buckets(0), buckets.length-1) _ + fastBucketFunction(buckets.head, buckets.last, buckets.length - 1) _ } else { basicBucketFunction _ } diff --git a/core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala index 6fdfdb734d1b8..6afe50161dacd 100644 --- a/core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala @@ -56,7 +56,7 @@ class OrderedRDDFunctions[K : Ordering : ClassTag, * order of the keys). */ // TODO: this currently doesn't work on P other than Tuple2! - def sortByKey(ascending: Boolean = true, numPartitions: Int = self.partitions.size) + def sortByKey(ascending: Boolean = true, numPartitions: Int = self.partitions.length) : RDD[(K, V)] = { val part = new RangePartitioner(numPartitions, self, ascending) diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala index 6b4f097ea9ae5..05351ba4ff76b 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala @@ -823,7 +823,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * RDD will be <= us. */ def subtractByKey[W: ClassTag](other: RDD[(K, W)]): RDD[(K, V)] = - subtractByKey(other, self.partitioner.getOrElse(new HashPartitioner(self.partitions.size))) + subtractByKey(other, self.partitioner.getOrElse(new HashPartitioner(self.partitions.length))) /** Return an RDD with the pairs from `this` whose keys are not in `other`. */ def subtractByKey[W: ClassTag](other: RDD[(K, W)], numPartitions: Int): RDD[(K, V)] = @@ -995,7 +995,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) val writer = format.getRecordWriter(hadoopContext).asInstanceOf[NewRecordWriter[K,V]] require(writer != null, "Unable to obtain RecordWriter") var recordsWritten = 0L - try { + Utils.tryWithSafeFinally { while (iter.hasNext) { val pair = iter.next() writer.write(pair._1, pair._2) @@ -1004,7 +1004,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) maybeUpdateOutputMetrics(bytesWrittenCallback, outputMetrics, recordsWritten) recordsWritten += 1 } - } finally { + } { writer.close(hadoopContext) } committer.commitTask(hadoopContext) @@ -1068,7 +1068,8 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) writer.setup(context.stageId, context.partitionId, taskAttemptId) writer.open() var recordsWritten = 0L - try { + + Utils.tryWithSafeFinally { while (iter.hasNext) { val record = iter.next() writer.write(record._1.asInstanceOf[AnyRef], record._2.asInstanceOf[AnyRef]) @@ -1077,7 +1078,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) maybeUpdateOutputMetrics(bytesWrittenCallback, outputMetrics, recordsWritten) recordsWritten += 1 } - } finally { + } { writer.close() } writer.commit() diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index ddbfd5624e741..d80d94a588346 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -316,7 +316,7 @@ abstract class RDD[T: ClassTag]( /** * Return a new RDD containing the distinct elements in this RDD. */ - def distinct(): RDD[T] = distinct(partitions.size) + def distinct(): RDD[T] = distinct(partitions.length) /** * Return a new RDD that has exactly numPartitions partitions. @@ -488,7 +488,7 @@ abstract class RDD[T: ClassTag]( def sortBy[K]( f: (T) => K, ascending: Boolean = true, - numPartitions: Int = this.partitions.size) + numPartitions: Int = this.partitions.length) (implicit ord: Ordering[K], ctag: ClassTag[K]): RDD[T] = this.keyBy[K](f) .sortByKey(ascending, numPartitions) @@ -852,7 +852,7 @@ abstract class RDD[T: ClassTag]( * RDD will be <= us. */ def subtract(other: RDD[T]): RDD[T] = - subtract(other, partitioner.getOrElse(new HashPartitioner(partitions.size))) + subtract(other, partitioner.getOrElse(new HashPartitioner(partitions.length))) /** * Return an RDD with the elements from `this` that are not in `other`. @@ -986,14 +986,14 @@ abstract class RDD[T: ClassTag]( combOp: (U, U) => U, depth: Int = 2): U = { require(depth >= 1, s"Depth must be greater than or equal to 1 but got $depth.") - if (partitions.size == 0) { + if (partitions.length == 0) { return Utils.clone(zeroValue, context.env.closureSerializer.newInstance()) } val cleanSeqOp = context.clean(seqOp) val cleanCombOp = context.clean(combOp) val aggregatePartition = (it: Iterator[T]) => it.aggregate(zeroValue)(cleanSeqOp, cleanCombOp) var partiallyAggregated = mapPartitions(it => Iterator(aggregatePartition(it))) - var numPartitions = partiallyAggregated.partitions.size + var numPartitions = partiallyAggregated.partitions.length val scale = math.max(math.ceil(math.pow(numPartitions, 1.0 / depth)).toInt, 2) // If creating an extra level doesn't help reduce the wall-clock time, we stop tree aggregation. while (numPartitions > scale + numPartitions / scale) { @@ -1026,7 +1026,7 @@ abstract class RDD[T: ClassTag]( } result } - val evaluator = new CountEvaluator(partitions.size, confidence) + val evaluator = new CountEvaluator(partitions.length, confidence) sc.runApproximateJob(this, countElements, evaluator, timeout) } @@ -1061,7 +1061,7 @@ abstract class RDD[T: ClassTag]( } map } - val evaluator = new GroupedCountEvaluator[T](partitions.size, confidence) + val evaluator = new GroupedCountEvaluator[T](partitions.length, confidence) sc.runApproximateJob(this, countPartition, evaluator, timeout) } @@ -1140,7 +1140,7 @@ abstract class RDD[T: ClassTag]( * the same index assignments, you should sort the RDD with sortByKey() or save it to a file. */ def zipWithUniqueId(): RDD[(T, Long)] = { - val n = this.partitions.size.toLong + val n = this.partitions.length.toLong this.mapPartitionsWithIndex { case (k, iter) => iter.zipWithIndex.map { case (item, i) => (item, i * n + k) @@ -1243,7 +1243,7 @@ abstract class RDD[T: ClassTag]( queue ++= util.collection.Utils.takeOrdered(items, num)(ord) Iterator.single(queue) } - if (mapRDDs.partitions.size == 0) { + if (mapRDDs.partitions.length == 0) { Array.empty } else { mapRDDs.reduce { (queue1, queue2) => @@ -1489,7 +1489,7 @@ abstract class RDD[T: ClassTag]( } // The first RDD in the dependency stack has no parents, so no need for a +- def firstDebugString(rdd: RDD[_]): Seq[String] = { - val partitionStr = "(" + rdd.partitions.size + ")" + val partitionStr = "(" + rdd.partitions.length + ")" val leftOffset = (partitionStr.length - 1) / 2 val nextPrefix = (" " * leftOffset) + "|" + (" " * (partitionStr.length - leftOffset)) @@ -1499,7 +1499,7 @@ abstract class RDD[T: ClassTag]( } ++ debugChildren(rdd, nextPrefix) } def shuffleDebugString(rdd: RDD[_], prefix: String = "", isLastChild: Boolean): Seq[String] = { - val partitionStr = "(" + rdd.partitions.size + ")" + val partitionStr = "(" + rdd.partitions.length + ")" val leftOffset = (partitionStr.length - 1) / 2 val thisPrefix = prefix.replaceAll("\\|\\s+$", "") val nextPrefix = ( diff --git a/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala b/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala index f67e5f1857979..6afd63d537d75 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala @@ -94,10 +94,10 @@ private[spark] class RDDCheckpointData[T: ClassTag](@transient rdd: RDD[T]) new SerializableWritable(rdd.context.hadoopConfiguration)) rdd.context.runJob(rdd, CheckpointRDD.writeToFile[T](path.toString, broadcastedConf) _) val newRDD = new CheckpointRDD[T](rdd.context, path.toString) - if (newRDD.partitions.size != rdd.partitions.size) { + if (newRDD.partitions.length != rdd.partitions.length) { throw new SparkException( - "Checkpoint RDD " + newRDD + "(" + newRDD.partitions.size + ") has different " + - "number of partitions than original RDD " + rdd + "(" + rdd.partitions.size + ")") + "Checkpoint RDD " + newRDD + "(" + newRDD.partitions.length + ") has different " + + "number of partitions than original RDD " + rdd + "(" + rdd.partitions.length + ")") } // Change the dependencies and partitions of the RDD diff --git a/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala index c27f435eb9c5a..e9d745588ee9a 100644 --- a/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala @@ -76,7 +76,7 @@ private[spark] class SubtractedRDD[K: ClassTag, V: ClassTag, W: ClassTag]( override def getPartitions: Array[Partition] = { val array = new Array[Partition](part.numPartitions) - for (i <- 0 until array.size) { + for (i <- 0 until array.length) { // Each CoGroupPartition will depend on rdd1 and rdd2 array(i) = new CoGroupPartition(i, Seq(rdd1, rdd2).zipWithIndex.map { case (rdd, j) => dependencies(j) match { diff --git a/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala b/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala index 4239e7e22af89..3986645350a82 100644 --- a/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala @@ -63,7 +63,7 @@ class UnionRDD[T: ClassTag]( extends RDD[T](sc, Nil) { // Nil since we implement getDependencies override def getPartitions: Array[Partition] = { - val array = new Array[Partition](rdds.map(_.partitions.size).sum) + val array = new Array[Partition](rdds.map(_.partitions.length).sum) var pos = 0 for ((rdd, rddIndex) <- rdds.zipWithIndex; split <- rdd.partitions) { array(pos) = new UnionPartition(pos, rdd, rddIndex, split.index) @@ -76,8 +76,8 @@ class UnionRDD[T: ClassTag]( val deps = new ArrayBuffer[Dependency[_]] var pos = 0 for (rdd <- rdds) { - deps += new RangeDependency(rdd, 0, pos, rdd.partitions.size) - pos += rdd.partitions.size + deps += new RangeDependency(rdd, 0, pos, rdd.partitions.length) + pos += rdd.partitions.length } deps } diff --git a/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala index d0be304762e1f..a96b6c3d23454 100644 --- a/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala @@ -52,8 +52,8 @@ private[spark] abstract class ZippedPartitionsBaseRDD[V: ClassTag]( if (preservesPartitioning) firstParent[Any].partitioner else None override def getPartitions: Array[Partition] = { - val numParts = rdds.head.partitions.size - if (!rdds.forall(rdd => rdd.partitions.size == numParts)) { + val numParts = rdds.head.partitions.length + if (!rdds.forall(rdd => rdd.partitions.length == numParts)) { throw new IllegalArgumentException("Can't zip RDDs with unequal numbers of partitions") } Array.tabulate[Partition](numParts) { i => diff --git a/core/src/main/scala/org/apache/spark/rdd/ZippedWithIndexRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ZippedWithIndexRDD.scala index 8c43a559409f2..523aaf2b860b5 100644 --- a/core/src/main/scala/org/apache/spark/rdd/ZippedWithIndexRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/ZippedWithIndexRDD.scala @@ -41,7 +41,7 @@ class ZippedWithIndexRDD[T: ClassTag](@transient prev: RDD[T]) extends RDD[(T, L /** The start index of each partition. */ @transient private val startIndices: Array[Long] = { - val n = prev.partitions.size + val n = prev.partitions.length if (n == 0) { Array[Long]() } else if (n == 1) { diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala new file mode 100644 index 0000000000000..e259867c14040 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala @@ -0,0 +1,428 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.rpc + +import java.net.URI + +import scala.concurrent.{Await, Future} +import scala.concurrent.duration._ +import scala.language.postfixOps +import scala.reflect.ClassTag + +import org.apache.spark.{Logging, SparkException, SecurityManager, SparkConf} +import org.apache.spark.util.{AkkaUtils, Utils} + +/** + * An RPC environment. [[RpcEndpoint]]s need to register itself with a name to [[RpcEnv]] to + * receives messages. Then [[RpcEnv]] will process messages sent from [[RpcEndpointRef]] or remote + * nodes, and deliver them to corresponding [[RpcEndpoint]]s. For uncaught exceptions caught by + * [[RpcEnv]], [[RpcEnv]] will use [[RpcCallContext.sendFailure]] to send exceptions back to the + * sender, or logging them if no such sender or `NotSerializableException`. + * + * [[RpcEnv]] also provides some methods to retrieve [[RpcEndpointRef]]s given name or uri. + */ +private[spark] abstract class RpcEnv(conf: SparkConf) { + + private[spark] val defaultLookupTimeout = AkkaUtils.lookupTimeout(conf) + + /** + * Return RpcEndpointRef of the registered [[RpcEndpoint]]. Will be used to implement + * [[RpcEndpoint.self]]. Return `null` if the corresponding [[RpcEndpointRef]] does not exist. + */ + private[rpc] def endpointRef(endpoint: RpcEndpoint): RpcEndpointRef + + /** + * Return the address that [[RpcEnv]] is listening to. + */ + def address: RpcAddress + + /** + * Register a [[RpcEndpoint]] with a name and return its [[RpcEndpointRef]]. [[RpcEnv]] does not + * guarantee thread-safety. + */ + def setupEndpoint(name: String, endpoint: RpcEndpoint): RpcEndpointRef + + /** + * Retrieve the [[RpcEndpointRef]] represented by `uri` asynchronously. + */ + def asyncSetupEndpointRefByURI(uri: String): Future[RpcEndpointRef] + + /** + * Retrieve the [[RpcEndpointRef]] represented by `uri`. This is a blocking action. + */ + def setupEndpointRefByURI(uri: String): RpcEndpointRef = { + Await.result(asyncSetupEndpointRefByURI(uri), defaultLookupTimeout) + } + + /** + * Retrieve the [[RpcEndpointRef]] represented by `systemName`, `address` and `endpointName` + * asynchronously. + */ + def asyncSetupEndpointRef( + systemName: String, address: RpcAddress, endpointName: String): Future[RpcEndpointRef] = { + asyncSetupEndpointRefByURI(uriOf(systemName, address, endpointName)) + } + + /** + * Retrieve the [[RpcEndpointRef]] represented by `systemName`, `address` and `endpointName`. + * This is a blocking action. + */ + def setupEndpointRef( + systemName: String, address: RpcAddress, endpointName: String): RpcEndpointRef = { + setupEndpointRefByURI(uriOf(systemName, address, endpointName)) + } + + /** + * Stop [[RpcEndpoint]] specified by `endpoint`. + */ + def stop(endpoint: RpcEndpointRef): Unit + + /** + * Shutdown this [[RpcEnv]] asynchronously. If need to make sure [[RpcEnv]] exits successfully, + * call [[awaitTermination()]] straight after [[shutdown()]]. + */ + def shutdown(): Unit + + /** + * Wait until [[RpcEnv]] exits. + * + * TODO do we need a timeout parameter? + */ + def awaitTermination(): Unit + + /** + * Create a URI used to create a [[RpcEndpointRef]]. Use this one to create the URI instead of + * creating it manually because different [[RpcEnv]] may have different formats. + */ + def uriOf(systemName: String, address: RpcAddress, endpointName: String): String +} + +private[spark] case class RpcEnvConfig( + conf: SparkConf, + name: String, + host: String, + port: Int, + securityManager: SecurityManager) + +/** + * A RpcEnv implementation must have a [[RpcEnvFactory]] implementation with an empty constructor + * so that it can be created via Reflection. + */ +private[spark] object RpcEnv { + + private def getRpcEnvFactory(conf: SparkConf): RpcEnvFactory = { + // Add more RpcEnv implementations here + val rpcEnvNames = Map("akka" -> "org.apache.spark.rpc.akka.AkkaRpcEnvFactory") + val rpcEnvName = conf.get("spark.rpc", "akka") + val rpcEnvFactoryClassName = rpcEnvNames.getOrElse(rpcEnvName.toLowerCase, rpcEnvName) + Class.forName(rpcEnvFactoryClassName, true, Utils.getContextOrSparkClassLoader). + newInstance().asInstanceOf[RpcEnvFactory] + } + + def create( + name: String, + host: String, + port: Int, + conf: SparkConf, + securityManager: SecurityManager): RpcEnv = { + // Using Reflection to create the RpcEnv to avoid to depend on Akka directly + val config = RpcEnvConfig(conf, name, host, port, securityManager) + getRpcEnvFactory(conf).create(config) + } + +} + +/** + * A factory class to create the [[RpcEnv]]. It must have an empty constructor so that it can be + * created using Reflection. + */ +private[spark] trait RpcEnvFactory { + + def create(config: RpcEnvConfig): RpcEnv +} + +/** + * An end point for the RPC that defines what functions to trigger given a message. + * + * It is guaranteed that `onStart`, `receive` and `onStop` will be called in sequence. + * + * The lift-cycle will be: + * + * constructor onStart receive* onStop + * + * Note: `receive` can be called concurrently. If you want `receive` is thread-safe, please use + * [[ThreadSafeRpcEndpoint]] + * + * If any error is thrown from one of [[RpcEndpoint]] methods except `onError`, `onError` will be + * invoked with the cause. If `onError` throws an error, [[RpcEnv]] will ignore it. + */ +private[spark] trait RpcEndpoint { + + /** + * The [[RpcEnv]] that this [[RpcEndpoint]] is registered to. + */ + val rpcEnv: RpcEnv + + /** + * The [[RpcEndpointRef]] of this [[RpcEndpoint]]. `self` will become valid when `onStart` is + * called. And `self` will become `null` when `onStop` is called. + * + * Note: Because before `onStart`, [[RpcEndpoint]] has not yet been registered and there is not + * valid [[RpcEndpointRef]] for it. So don't call `self` before `onStart` is called. + */ + final def self: RpcEndpointRef = { + require(rpcEnv != null, "rpcEnv has not been initialized") + rpcEnv.endpointRef(this) + } + + /** + * Process messages from [[RpcEndpointRef.send]] or [[RpcCallContext.reply)]]. If receiving a + * unmatched message, [[SparkException]] will be thrown and sent to `onError`. + */ + def receive: PartialFunction[Any, Unit] = { + case _ => throw new SparkException(self + " does not implement 'receive'") + } + + /** + * Process messages from [[RpcEndpointRef.sendWithReply]]. If receiving a unmatched message, + * [[SparkException]] will be thrown and sent to `onError`. + */ + def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { + case _ => context.sendFailure(new SparkException(self + " won't reply anything")) + } + + /** + * Call onError when any exception is thrown during handling messages. + * + * @param cause + */ + def onError(cause: Throwable): Unit = { + // By default, throw e and let RpcEnv handle it + throw cause + } + + /** + * Invoked before [[RpcEndpoint]] starts to handle any message. + */ + def onStart(): Unit = { + // By default, do nothing. + } + + /** + * Invoked when [[RpcEndpoint]] is stopping. + */ + def onStop(): Unit = { + // By default, do nothing. + } + + /** + * Invoked when `remoteAddress` is connected to the current node. + */ + def onConnected(remoteAddress: RpcAddress): Unit = { + // By default, do nothing. + } + + /** + * Invoked when `remoteAddress` is lost. + */ + def onDisconnected(remoteAddress: RpcAddress): Unit = { + // By default, do nothing. + } + + /** + * Invoked when some network error happens in the connection between the current node and + * `remoteAddress`. + */ + def onNetworkError(cause: Throwable, remoteAddress: RpcAddress): Unit = { + // By default, do nothing. + } + + /** + * A convenient method to stop [[RpcEndpoint]]. + */ + final def stop(): Unit = { + val _self = self + if (_self != null) { + rpcEnv.stop(self) + } + } +} + +/** + * A trait that requires RpcEnv thread-safely sending messages to it. + * + * Thread-safety means processing of one message happens before processing of the next message by + * the same [[ThreadSafeRpcEndpoint]]. In the other words, changes to internal fields of a + * [[ThreadSafeRpcEndpoint]] are visible when processing the next message, and fields in the + * [[ThreadSafeRpcEndpoint]] need not be volatile or equivalent. + * + * However, there is no guarantee that the same thread will be executing the same + * [[ThreadSafeRpcEndpoint]] for different messages. + */ +trait ThreadSafeRpcEndpoint extends RpcEndpoint + +/** + * A reference for a remote [[RpcEndpoint]]. [[RpcEndpointRef]] is thread-safe. + */ +private[spark] abstract class RpcEndpointRef(@transient conf: SparkConf) + extends Serializable with Logging { + + private[this] val maxRetries = conf.getInt("spark.akka.num.retries", 3) + private[this] val retryWaitMs = conf.getLong("spark.akka.retry.wait", 3000) + private[this] val defaultTimeout = conf.getLong("spark.akka.lookupTimeout", 30) seconds + + /** + * return the address for the [[RpcEndpointRef]] + */ + def address: RpcAddress + + def name: String + + /** + * Sends a one-way asynchronous message. Fire-and-forget semantics. + */ + def send(message: Any): Unit + + /** + * Send a message to the corresponding [[RpcEndpoint.receiveAndReply)]] and return a `Future` to + * receive the reply within a default timeout. + * + * This method only sends the message once and never retries. + */ + def sendWithReply[T: ClassTag](message: Any): Future[T] = sendWithReply(message, defaultTimeout) + + /** + * Send a message to the corresponding [[RpcEndpoint.receiveAndReply)]] and return a `Future` to + * receive the reply within the specified timeout. + * + * This method only sends the message once and never retries. + */ + def sendWithReply[T: ClassTag](message: Any, timeout: FiniteDuration): Future[T] + + /** + * Send a message to the corresponding [[RpcEndpoint]] and get its result within a default + * timeout, or throw a SparkException if this fails even after the default number of retries. + * The default `timeout` will be used in every trial of calling `sendWithReply`. Because this + * method retries, the message handling in the receiver side should be idempotent. + * + * Note: this is a blocking action which may cost a lot of time, so don't call it in an message + * loop of [[RpcEndpoint]]. + * + * @param message the message to send + * @tparam T type of the reply message + * @return the reply message from the corresponding [[RpcEndpoint]] + */ + def askWithReply[T: ClassTag](message: Any): T = askWithReply(message, defaultTimeout) + + /** + * Send a message to the corresponding [[RpcEndpoint.receive]] and get its result within a + * specified timeout, throw a SparkException if this fails even after the specified number of + * retries. `timeout` will be used in every trial of calling `sendWithReply`. Because this method + * retries, the message handling in the receiver side should be idempotent. + * + * Note: this is a blocking action which may cost a lot of time, so don't call it in an message + * loop of [[RpcEndpoint]]. + * + * @param message the message to send + * @param timeout the timeout duration + * @tparam T type of the reply message + * @return the reply message from the corresponding [[RpcEndpoint]] + */ + def askWithReply[T: ClassTag](message: Any, timeout: FiniteDuration): T = { + // TODO: Consider removing multiple attempts + var attempts = 0 + var lastException: Exception = null + while (attempts < maxRetries) { + attempts += 1 + try { + val future = sendWithReply[T](message, timeout) + val result = Await.result(future, timeout) + if (result == null) { + throw new SparkException("Actor returned null") + } + return result + } catch { + case ie: InterruptedException => throw ie + case e: Exception => + lastException = e + logWarning(s"Error sending message [message = $message] in $attempts attempts", e) + } + Thread.sleep(retryWaitMs) + } + + throw new SparkException( + s"Error sending message [message = $message]", lastException) + } + +} + +/** + * Represent a host with a port + */ +private[spark] case class RpcAddress(host: String, port: Int) { + // TODO do we need to add the type of RpcEnv in the address? + + val hostPort: String = host + ":" + port + + override val toString: String = hostPort +} + +private[spark] object RpcAddress { + + /** + * Return the [[RpcAddress]] represented by `uri`. + */ + def fromURI(uri: URI): RpcAddress = { + RpcAddress(uri.getHost, uri.getPort) + } + + /** + * Return the [[RpcAddress]] represented by `uri`. + */ + def fromURIString(uri: String): RpcAddress = { + fromURI(new java.net.URI(uri)) + } + + def fromSparkURL(sparkUrl: String): RpcAddress = { + val (host, port) = Utils.extractHostPortFromSparkUrl(sparkUrl) + RpcAddress(host, port) + } +} + +/** + * A callback that [[RpcEndpoint]] can use it to send back a message or failure. It's thread-safe + * and can be called in any thread. + */ +private[spark] trait RpcCallContext { + + /** + * Reply a message to the sender. If the sender is [[RpcEndpoint]], its [[RpcEndpoint.receive]] + * will be called. + */ + def reply(response: Any): Unit + + /** + * Report a failure to the sender. + */ + def sendFailure(e: Throwable): Unit + + /** + * The sender of this message. + */ + def sender: RpcEndpointRef +} diff --git a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala new file mode 100644 index 0000000000000..652e52f2b2e73 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala @@ -0,0 +1,325 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.rpc.akka + +import java.util.concurrent.ConcurrentHashMap + +import scala.concurrent.Future +import scala.concurrent.duration._ +import scala.language.postfixOps +import scala.reflect.ClassTag +import scala.util.control.NonFatal + +import akka.actor.{ActorSystem, ExtendedActorSystem, Actor, ActorRef, Props, Address} +import akka.event.Logging.Error +import akka.pattern.{ask => akkaAsk} +import akka.remote.{AssociationEvent, AssociatedEvent, DisassociatedEvent, AssociationErrorEvent} +import org.apache.spark.{SparkException, Logging, SparkConf} +import org.apache.spark.rpc._ +import org.apache.spark.util.{ActorLogReceive, AkkaUtils} + +/** + * A RpcEnv implementation based on Akka. + * + * TODO Once we remove all usages of Akka in other place, we can move this file to a new project and + * remove Akka from the dependencies. + * + * @param actorSystem + * @param conf + * @param boundPort + */ +private[spark] class AkkaRpcEnv private[akka] ( + val actorSystem: ActorSystem, conf: SparkConf, boundPort: Int) + extends RpcEnv(conf) with Logging { + + private val defaultAddress: RpcAddress = { + val address = actorSystem.asInstanceOf[ExtendedActorSystem].provider.getDefaultAddress + // In some test case, ActorSystem doesn't bind to any address. + // So just use some default value since they are only some unit tests + RpcAddress(address.host.getOrElse("localhost"), address.port.getOrElse(boundPort)) + } + + override val address: RpcAddress = defaultAddress + + /** + * A lookup table to search a [[RpcEndpointRef]] for a [[RpcEndpoint]]. We need it to make + * [[RpcEndpoint.self]] work. + */ + private val endpointToRef = new ConcurrentHashMap[RpcEndpoint, RpcEndpointRef]() + + /** + * Need this map to remove `RpcEndpoint` from `endpointToRef` via a `RpcEndpointRef` + */ + private val refToEndpoint = new ConcurrentHashMap[RpcEndpointRef, RpcEndpoint]() + + private def registerEndpoint(endpoint: RpcEndpoint, endpointRef: RpcEndpointRef): Unit = { + endpointToRef.put(endpoint, endpointRef) + refToEndpoint.put(endpointRef, endpoint) + } + + private def unregisterEndpoint(endpointRef: RpcEndpointRef): Unit = { + val endpoint = refToEndpoint.remove(endpointRef) + if (endpoint != null) { + endpointToRef.remove(endpoint) + } + } + + /** + * Retrieve the [[RpcEndpointRef]] of `endpoint`. + */ + override def endpointRef(endpoint: RpcEndpoint): RpcEndpointRef = endpointToRef.get(endpoint) + + override def setupEndpoint(name: String, endpoint: RpcEndpoint): RpcEndpointRef = { + @volatile var endpointRef: AkkaRpcEndpointRef = null + // Use lazy because the Actor needs to use `endpointRef`. + // So `actorRef` should be created after assigning `endpointRef`. + lazy val actorRef = actorSystem.actorOf(Props(new Actor with ActorLogReceive with Logging { + + assert(endpointRef != null) + + override def preStart(): Unit = { + // Listen for remote client network events + context.system.eventStream.subscribe(self, classOf[AssociationEvent]) + safelyCall(endpoint) { + endpoint.onStart() + } + } + + override def receiveWithLogging: Receive = { + case AssociatedEvent(_, remoteAddress, _) => + safelyCall(endpoint) { + endpoint.onConnected(akkaAddressToRpcAddress(remoteAddress)) + } + + case DisassociatedEvent(_, remoteAddress, _) => + safelyCall(endpoint) { + endpoint.onDisconnected(akkaAddressToRpcAddress(remoteAddress)) + } + + case AssociationErrorEvent(cause, localAddress, remoteAddress, inbound, _) => + safelyCall(endpoint) { + endpoint.onNetworkError(cause, akkaAddressToRpcAddress(remoteAddress)) + } + + case e: AssociationEvent => + // TODO ignore? + + case m: AkkaMessage => + logDebug(s"Received RPC message: $m") + safelyCall(endpoint) { + processMessage(endpoint, m, sender) + } + + case AkkaFailure(e) => + safelyCall(endpoint) { + throw e + } + + case message: Any => { + logWarning(s"Unknown message: $message") + } + + } + + override def postStop(): Unit = { + unregisterEndpoint(endpoint.self) + safelyCall(endpoint) { + endpoint.onStop() + } + } + + }), name = name) + endpointRef = new AkkaRpcEndpointRef(defaultAddress, actorRef, conf, initInConstructor = false) + registerEndpoint(endpoint, endpointRef) + // Now actorRef can be created safely + endpointRef.init() + endpointRef + } + + private def processMessage(endpoint: RpcEndpoint, m: AkkaMessage, _sender: ActorRef): Unit = { + val message = m.message + val needReply = m.needReply + val pf: PartialFunction[Any, Unit] = + if (needReply) { + endpoint.receiveAndReply(new RpcCallContext { + override def sendFailure(e: Throwable): Unit = { + _sender ! AkkaFailure(e) + } + + override def reply(response: Any): Unit = { + _sender ! AkkaMessage(response, false) + } + + // Some RpcEndpoints need to know the sender's address + override val sender: RpcEndpointRef = + new AkkaRpcEndpointRef(defaultAddress, _sender, conf) + }) + } else { + endpoint.receive + } + try { + pf.applyOrElse[Any, Unit](message, { message => + throw new SparkException(s"Unmatched message $message from ${_sender}") + }) + } catch { + case NonFatal(e) => + if (needReply) { + // If the sender asks a reply, we should send the error back to the sender + _sender ! AkkaFailure(e) + } else { + throw e + } + } + } + + /** + * Run `action` safely to avoid to crash the thread. If any non-fatal exception happens, it will + * call `endpoint.onError`. If `endpoint.onError` throws any non-fatal exception, just log it. + */ + private def safelyCall(endpoint: RpcEndpoint)(action: => Unit): Unit = { + try { + action + } catch { + case NonFatal(e) => { + try { + endpoint.onError(e) + } catch { + case NonFatal(e) => logError(s"Ignore error: ${e.getMessage}", e) + } + } + } + } + + private def akkaAddressToRpcAddress(address: Address): RpcAddress = { + RpcAddress(address.host.getOrElse(defaultAddress.host), + address.port.getOrElse(defaultAddress.port)) + } + + override def asyncSetupEndpointRefByURI(uri: String): Future[RpcEndpointRef] = { + import actorSystem.dispatcher + actorSystem.actorSelection(uri).resolveOne(defaultLookupTimeout). + map(new AkkaRpcEndpointRef(defaultAddress, _, conf)) + } + + override def uriOf(systemName: String, address: RpcAddress, endpointName: String): String = { + AkkaUtils.address( + AkkaUtils.protocol(actorSystem), systemName, address.host, address.port, endpointName) + } + + override def shutdown(): Unit = { + actorSystem.shutdown() + } + + override def stop(endpoint: RpcEndpointRef): Unit = { + require(endpoint.isInstanceOf[AkkaRpcEndpointRef]) + actorSystem.stop(endpoint.asInstanceOf[AkkaRpcEndpointRef].actorRef) + } + + override def awaitTermination(): Unit = { + actorSystem.awaitTermination() + } + + override def toString: String = s"${getClass.getSimpleName}($actorSystem)" +} + +private[spark] class AkkaRpcEnvFactory extends RpcEnvFactory { + + def create(config: RpcEnvConfig): RpcEnv = { + val (actorSystem, boundPort) = AkkaUtils.createActorSystem( + config.name, config.host, config.port, config.conf, config.securityManager) + actorSystem.actorOf(Props(classOf[ErrorMonitor]), "ErrorMonitor") + new AkkaRpcEnv(actorSystem, config.conf, boundPort) + } +} + +/** + * Monitor errors reported by Akka and log them. + */ +private[akka] class ErrorMonitor extends Actor with ActorLogReceive with Logging { + + override def preStart(): Unit = { + context.system.eventStream.subscribe(self, classOf[Error]) + } + + override def receiveWithLogging: Actor.Receive = { + case Error(cause: Throwable, _, _, message: String) => logError(message, cause) + } +} + +private[akka] class AkkaRpcEndpointRef( + @transient defaultAddress: RpcAddress, + @transient _actorRef: => ActorRef, + @transient conf: SparkConf, + @transient initInConstructor: Boolean = true) + extends RpcEndpointRef(conf) with Logging { + + lazy val actorRef = _actorRef + + override lazy val address: RpcAddress = { + val akkaAddress = actorRef.path.address + RpcAddress(akkaAddress.host.getOrElse(defaultAddress.host), + akkaAddress.port.getOrElse(defaultAddress.port)) + } + + override lazy val name: String = actorRef.path.name + + private[akka] def init(): Unit = { + // Initialize the lazy vals + actorRef + address + name + } + + if (initInConstructor) { + init() + } + + override def send(message: Any): Unit = { + actorRef ! AkkaMessage(message, false) + } + + override def sendWithReply[T: ClassTag](message: Any, timeout: FiniteDuration): Future[T] = { + import scala.concurrent.ExecutionContext.Implicits.global + actorRef.ask(AkkaMessage(message, true))(timeout).flatMap { + case msg @ AkkaMessage(message, reply) => + if (reply) { + logError(s"Receive $msg but the sender cannot reply") + Future.failed(new SparkException(s"Receive $msg but the sender cannot reply")) + } else { + Future.successful(message) + } + case AkkaFailure(e) => + Future.failed(e) + }.mapTo[T] + } + + override def toString: String = s"${getClass.getSimpleName}($actorRef)" + +} + +/** + * A wrapper to `message` so that the receiver knows if the sender expects a reply. + * @param message + * @param needReply if the sender expects a reply message + */ +private[akka] case class AkkaMessage(message: Any, needReply: Boolean) + +/** + * A reply with the failure error from the receiver to the sender + */ +private[akka] case class AkkaFailure(e: Throwable) diff --git a/core/src/main/scala/org/apache/spark/scheduler/ActiveJob.scala b/core/src/main/scala/org/apache/spark/scheduler/ActiveJob.scala index b755d8fb15757..50a69379412d2 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ActiveJob.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ActiveJob.scala @@ -27,7 +27,7 @@ import org.apache.spark.util.CallSite */ private[spark] class ActiveJob( val jobId: Int, - val finalStage: Stage, + val finalStage: ResultStage, val func: (TaskContext, Iterator[_]) => _, val partitions: Array[Int], val callSite: CallSite, diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index b405bd3338e7c..508fe7b3303ca 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -23,14 +23,11 @@ import java.util.concurrent.{TimeUnit, Executors} import java.util.concurrent.atomic.AtomicInteger import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Map, Stack} -import scala.concurrent.Await import scala.concurrent.duration._ +import scala.language.existentials import scala.language.postfixOps import scala.util.control.NonFatal -import akka.pattern.ask -import akka.util.Timeout - import org.apache.spark._ import org.apache.spark.broadcast.Broadcast import org.apache.spark.executor.TaskMetrics @@ -53,6 +50,10 @@ import org.apache.spark.storage.BlockManagerMessages.BlockManagerHeartbeat * not caused by shuffle file loss are handled by the TaskScheduler, which will retry each task * a small number of times before cancelling the whole stage. * + * Here's a checklist to use when making or reviewing changes to this class: + * + * - When adding a new data structure, update `DAGSchedulerSuite.assertDataStructuresEmpty` to + * include the new structure. This will help to catch memory leaks. */ private[spark] class DAGScheduler( @@ -83,7 +84,7 @@ class DAGScheduler( private[scheduler] val jobIdToStageIds = new HashMap[Int, HashSet[Int]] private[scheduler] val stageIdToStage = new HashMap[Int, Stage] - private[scheduler] val shuffleToMapStage = new HashMap[Int, Stage] + private[scheduler] val shuffleToMapStage = new HashMap[Int, ShuffleMapStage] private[scheduler] val jobIdToActiveJob = new HashMap[Int, ActiveJob] // Stages we need to run whose parents aren't done @@ -114,6 +115,8 @@ class DAGScheduler( // stray messages to detect. private val failedEpoch = new HashMap[String, Long] + private [scheduler] val outputCommitCoordinator = env.outputCommitCoordinator + // A closure serializer that we reuse. // This is only safe because DAGScheduler runs in a single thread. private val closureSerializer = SparkEnv.get.closureSerializer.newInstance() @@ -131,8 +134,6 @@ class DAGScheduler( private[scheduler] val eventProcessLoop = new DAGSchedulerEventProcessLoop(this) taskScheduler.setDAGScheduler(this) - private val outputCommitCoordinator = env.outputCommitCoordinator - // Called by TaskScheduler to report task's starting. def taskStarted(task: Task[_], taskInfo: TaskInfo) { eventProcessLoop.post(BeginEvent(task, taskInfo)) @@ -150,7 +151,7 @@ class DAGScheduler( result: Any, accumUpdates: Map[Long, Any], taskInfo: TaskInfo, - taskMetrics: TaskMetrics) { + taskMetrics: TaskMetrics): Unit = { eventProcessLoop.post( CompletionEvent(task, reason, result, accumUpdates, taskInfo, taskMetrics)) } @@ -165,26 +166,23 @@ class DAGScheduler( taskMetrics: Array[(Long, Int, Int, TaskMetrics)], // (taskId, stageId, stateAttempt, metrics) blockManagerId: BlockManagerId): Boolean = { listenerBus.post(SparkListenerExecutorMetricsUpdate(execId, taskMetrics)) - implicit val timeout = Timeout(600 seconds) - - Await.result( - blockManagerMaster.driverActor ? BlockManagerHeartbeat(blockManagerId), - timeout.duration).asInstanceOf[Boolean] + blockManagerMaster.driverEndpoint.askWithReply[Boolean]( + BlockManagerHeartbeat(blockManagerId), 600 seconds) } // Called by TaskScheduler when an executor fails. - def executorLost(execId: String) { + def executorLost(execId: String): Unit = { eventProcessLoop.post(ExecutorLost(execId)) } // Called by TaskScheduler when a host is added - def executorAdded(execId: String, host: String) { + def executorAdded(execId: String, host: String): Unit = { eventProcessLoop.post(ExecutorAdded(execId, host)) } // Called by TaskScheduler to cancel an entire TaskSet due to either repeated failures or // cancellation of the job itself. - def taskSetFailed(taskSet: TaskSet, reason: String) { + def taskSetFailed(taskSet: TaskSet, reason: String): Unit = { eventProcessLoop.post(TaskSetFailed(taskSet, reason)) } @@ -210,40 +208,65 @@ class DAGScheduler( * The jobId value passed in will be used if the stage doesn't already exist with * a lower jobId (jobId always increases across jobs.) */ - private def getShuffleMapStage(shuffleDep: ShuffleDependency[_, _, _], jobId: Int): Stage = { + private def getShuffleMapStage( + shuffleDep: ShuffleDependency[_, _, _], + jobId: Int): ShuffleMapStage = { shuffleToMapStage.get(shuffleDep.shuffleId) match { case Some(stage) => stage case None => // We are going to register ancestor shuffle dependencies registerShuffleDependencies(shuffleDep, jobId) // Then register current shuffleDep - val stage = - newOrUsedStage( - shuffleDep.rdd, shuffleDep.rdd.partitions.size, shuffleDep, jobId, - shuffleDep.rdd.creationSite) + val stage = newOrUsedShuffleStage(shuffleDep, jobId) shuffleToMapStage(shuffleDep.shuffleId) = stage - + stage } } /** - * Create a Stage -- either directly for use as a result stage, or as part of the (re)-creation - * of a shuffle map stage in newOrUsedStage. The stage will be associated with the provided - * jobId. Production of shuffle map stages should always use newOrUsedStage, not newStage - * directly. + * Helper function to eliminate some code re-use when creating new stages. */ - private def newStage( + private def getParentStagesAndId(rdd: RDD[_], jobId: Int): (List[Stage], Int) = { + val parentStages = getParentStages(rdd, jobId) + val id = nextStageId.getAndIncrement() + (parentStages, id) + } + + /** + * Create a ShuffleMapStage as part of the (re)-creation of a shuffle map stage in + * newOrUsedShuffleStage. The stage will be associated with the provided jobId. + * Production of shuffle map stages should always use newOrUsedShuffleStage, not + * newShuffleMapStage directly. + */ + private def newShuffleMapStage( rdd: RDD[_], numTasks: Int, - shuffleDep: Option[ShuffleDependency[_, _, _]], + shuffleDep: ShuffleDependency[_, _, _], jobId: Int, - callSite: CallSite) - : Stage = - { - val parentStages = getParentStages(rdd, jobId) - val id = nextStageId.getAndIncrement() - val stage = new Stage(id, rdd, numTasks, shuffleDep, parentStages, jobId, callSite) + callSite: CallSite): ShuffleMapStage = { + val (parentStages: List[Stage], id: Int) = getParentStagesAndId(rdd, jobId) + val stage: ShuffleMapStage = new ShuffleMapStage(id, rdd, numTasks, parentStages, + jobId, callSite, shuffleDep) + + stageIdToStage(id) = stage + updateJobIdStageIdMaps(jobId, stage) + stage + } + + /** + * Create a ResultStage -- either directly for use as a result stage, or as part of the + * (re)-creation of a shuffle map stage in newOrUsedShuffleStage. The stage will be associated + * with the provided jobId. + */ + private def newResultStage( + rdd: RDD[_], + numTasks: Int, + jobId: Int, + callSite: CallSite): ResultStage = { + val (parentStages: List[Stage], id: Int) = getParentStagesAndId(rdd, jobId) + val stage: ResultStage = new ResultStage(id, rdd, numTasks, parentStages, jobId, callSite) + stageIdToStage(id) = stage updateJobIdStageIdMaps(jobId, stage) stage @@ -255,20 +278,17 @@ class DAGScheduler( * present in the MapOutputTracker, then the number and location of available outputs are * recovered from the MapOutputTracker */ - private def newOrUsedStage( - rdd: RDD[_], - numTasks: Int, + private def newOrUsedShuffleStage( shuffleDep: ShuffleDependency[_, _, _], - jobId: Int, - callSite: CallSite) - : Stage = - { - val stage = newStage(rdd, numTasks, Some(shuffleDep), jobId, callSite) + jobId: Int): ShuffleMapStage = { + val rdd = shuffleDep.rdd + val numTasks = rdd.partitions.size + val stage = newShuffleMapStage(rdd, numTasks, shuffleDep, jobId, rdd.creationSite) if (mapOutputTracker.containsShuffle(shuffleDep.shuffleId)) { val serLocs = mapOutputTracker.getSerializedMapOutputStatuses(shuffleDep.shuffleId) val locs = MapOutputTracker.deserializeMapStatuses(serLocs) for (i <- 0 until locs.size) { - stage.outputLocs(i) = Option(locs(i)).toList // locs(i) will be null if missing + stage.outputLocs(i) = Option(locs(i)).toList // locs(i) will be null if missing } stage.numAvailableOutputs = locs.count(_ != null) } else { @@ -306,26 +326,23 @@ class DAGScheduler( } } waitingForVisit.push(rdd) - while (!waitingForVisit.isEmpty) { + while (waitingForVisit.nonEmpty) { visit(waitingForVisit.pop()) } parents.toList } - // Find ancestor missing shuffle dependencies and register into shuffleToMapStage - private def registerShuffleDependencies(shuffleDep: ShuffleDependency[_, _, _], jobId: Int) = { + /** Find ancestor missing shuffle dependencies and register into shuffleToMapStage */ + private def registerShuffleDependencies(shuffleDep: ShuffleDependency[_, _, _], jobId: Int) { val parentsWithNoMapStage = getAncestorShuffleDependencies(shuffleDep.rdd) - while (!parentsWithNoMapStage.isEmpty) { + while (parentsWithNoMapStage.nonEmpty) { val currentShufDep = parentsWithNoMapStage.pop() - val stage = - newOrUsedStage( - currentShufDep.rdd, currentShufDep.rdd.partitions.size, currentShufDep, jobId, - currentShufDep.rdd.creationSite) + val stage = newOrUsedShuffleStage(currentShufDep, jobId) shuffleToMapStage(currentShufDep.shuffleId) = stage } } - // Find ancestor shuffle dependencies that are not registered in shuffleToMapStage yet + /** Find ancestor shuffle dependencies that are not registered in shuffleToMapStage yet */ private def getAncestorShuffleDependencies(rdd: RDD[_]): Stack[ShuffleDependency[_, _, _]] = { val parents = new Stack[ShuffleDependency[_, _, _]] val visited = new HashSet[RDD[_]] @@ -351,7 +368,7 @@ class DAGScheduler( } waitingForVisit.push(rdd) - while (!waitingForVisit.isEmpty) { + while (waitingForVisit.nonEmpty) { visit(waitingForVisit.pop()) } parents @@ -382,7 +399,7 @@ class DAGScheduler( } } waitingForVisit.push(stage.rdd) - while (!waitingForVisit.isEmpty) { + while (waitingForVisit.nonEmpty) { visit(waitingForVisit.pop()) } missing.toList @@ -392,7 +409,7 @@ class DAGScheduler( * Registers the given jobId among the jobs that need the given stage and * all of that stage's ancestors. */ - private def updateJobIdStageIdMaps(jobId: Int, stage: Stage) { + private def updateJobIdStageIdMaps(jobId: Int, stage: Stage): Unit = { def updateJobIdStageIdMapsList(stages: List[Stage]) { if (stages.nonEmpty) { val s = stages.head @@ -412,7 +429,7 @@ class DAGScheduler( * * @param job The job whose state to cleanup. */ - private def cleanupStateForJobAndIndependentStages(job: ActiveJob) { + private def cleanupStateForJobAndIndependentStages(job: ActiveJob): Unit = { val registeredStages = jobIdToStageIds.get(job.jobId) if (registeredStages.isEmpty || registeredStages.get.isEmpty) { logError("No stages registered for job " + job.jobId) @@ -474,8 +491,7 @@ class DAGScheduler( callSite: CallSite, allowLocal: Boolean, resultHandler: (Int, U) => Unit, - properties: Properties = null): JobWaiter[U] = - { + properties: Properties): JobWaiter[U] = { // Check to make sure we are not launching a task on a partition that does not exist. val maxPartitions = rdd.partitions.length partitions.find(p => p >= maxPartitions || p < 0).foreach { p => @@ -504,15 +520,13 @@ class DAGScheduler( callSite: CallSite, allowLocal: Boolean, resultHandler: (Int, U) => Unit, - properties: Properties = null) - { + properties: Properties): Unit = { val start = System.nanoTime val waiter = submitJob(rdd, func, partitions, callSite, allowLocal, resultHandler, properties) waiter.awaitResult() match { - case JobSucceeded => { + case JobSucceeded => logInfo("Job %d finished: %s, took %f s".format (waiter.jobId, callSite.shortForm, (System.nanoTime - start) / 1e9)) - } case JobFailed(exception: Exception) => logInfo("Job %d failed: %s, took %f s".format (waiter.jobId, callSite.shortForm, (System.nanoTime - start) / 1e9)) @@ -526,9 +540,7 @@ class DAGScheduler( evaluator: ApproximateEvaluator[U, R], callSite: CallSite, timeout: Long, - properties: Properties = null) - : PartialResult[R] = - { + properties: Properties): PartialResult[R] = { val listener = new ApproximateActionListener(rdd, func, evaluator, timeout) val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _] val partitions = (0 until rdd.partitions.size).toArray @@ -541,12 +553,12 @@ class DAGScheduler( /** * Cancel a job that is running or waiting in the queue. */ - def cancelJob(jobId: Int) { + def cancelJob(jobId: Int): Unit = { logInfo("Asked to cancel job " + jobId) eventProcessLoop.post(JobCancelled(jobId)) } - def cancelJobGroup(groupId: String) { + def cancelJobGroup(groupId: String): Unit = { logInfo("Asked to cancel job group " + groupId) eventProcessLoop.post(JobGroupCancelled(groupId)) } @@ -554,7 +566,7 @@ class DAGScheduler( /** * Cancel all jobs that are running or waiting in the queue. */ - def cancelAllJobs() { + def cancelAllJobs(): Unit = { eventProcessLoop.post(AllJobsCancelled) } @@ -633,13 +645,13 @@ class DAGScheduler( val split = rdd.partitions(job.partitions(0)) val taskContext = new TaskContextImpl(job.finalStage.id, job.partitions(0), taskAttemptId = 0, attemptNumber = 0, runningLocally = true) - TaskContextHelper.setTaskContext(taskContext) + TaskContext.setTaskContext(taskContext) try { val result = job.func(taskContext, rdd.iterator(split, taskContext)) job.listener.taskSucceeded(0, result) } finally { taskContext.markTaskCompleted() - TaskContextHelper.unset() + TaskContext.unset() } } catch { case e: Exception => @@ -675,7 +687,7 @@ class DAGScheduler( // Cancel all jobs belonging to this job group. // First finds all active jobs with this group id, and then kill stages for them. val activeInGroup = activeJobs.filter(activeJob => - groupId == activeJob.properties.get(SparkContext.SPARK_JOB_GROUP_ID)) + Option(activeJob.properties).exists(_.get(SparkContext.SPARK_JOB_GROUP_ID) == groupId)) val jobIds = activeInGroup.map(_.jobId) jobIds.foreach(handleJobCancellation(_, "part of cancelled job group %s".format(groupId))) submitWaitingStages() @@ -702,9 +714,10 @@ class DAGScheduler( // cancelling the stages because if the DAG scheduler is stopped, the entire application // is in the process of getting stopped. val stageFailedMessage = "Stage cancelled because SparkContext was shut down" - runningStages.foreach { stage => - stage.latestInfo.stageFailed(stageFailedMessage) - listenerBus.post(SparkListenerStageCompleted(stage.latestInfo)) + // The `toArray` here is necessary so that we don't iterate over `runningStages` while + // mutating it. + runningStages.toArray.foreach { stage => + markStageAsFinished(stage, Some(stageFailedMessage)) } listenerBus.post(SparkListenerJobEnd(job.jobId, clock.getTimeMillis(), JobFailed(error))) } @@ -722,13 +735,12 @@ class DAGScheduler( allowLocal: Boolean, callSite: CallSite, listener: JobListener, - properties: Properties = null) - { - var finalStage: Stage = null + properties: Properties) { + var finalStage: ResultStage = null try { // New stage creation may throw an exception if, for example, jobs are run on a // HadoopRDD whose underlying HDFS files have been deleted. - finalStage = newStage(finalRDD, partitions.size, None, jobId, callSite) + finalStage = newResultStage(finalRDD, partitions.size, jobId, callSite) } catch { case e: Exception => logWarning("Creating new stage failed due to exception - job: " + jobId, e) @@ -773,7 +785,7 @@ class DAGScheduler( if (!waitingStages(stage) && !runningStages(stage) && !failedStages(stage)) { val missing = getMissingParentStages(stage).sortBy(_.id) logDebug("missing: " + missing) - if (missing == Nil) { + if (missing.isEmpty) { logInfo("Submitting " + stage + " (" + stage.rdd + "), which has no missing parents") submitMissingTasks(stage, jobId.get) } else { @@ -794,13 +806,15 @@ class DAGScheduler( // Get our pending tasks and remember them in our pendingTasks entry stage.pendingTasks.clear() + // First figure out the indexes of partition ids to compute. val partitionsToCompute: Seq[Int] = { - if (stage.isShuffleMap) { - (0 until stage.numPartitions).filter(id => stage.outputLocs(id) == Nil) - } else { - val job = stage.resultOfJob.get - (0 until job.numPartitions).filter(id => !job.finished(id)) + stage match { + case stage: ShuffleMapStage => + (0 until stage.numPartitions).filter(id => stage.outputLocs(id).isEmpty) + case stage: ResultStage => + val job = stage.resultOfJob.get + (0 until job.numPartitions).filter(id => !job.finished(id)) } } @@ -830,18 +844,21 @@ class DAGScheduler( try { // For ShuffleMapTask, serialize and broadcast (rdd, shuffleDep). // For ResultTask, serialize and broadcast (rdd, func). - val taskBinaryBytes: Array[Byte] = - if (stage.isShuffleMap) { - closureSerializer.serialize((stage.rdd, stage.shuffleDep.get) : AnyRef).array() - } else { - closureSerializer.serialize((stage.rdd, stage.resultOfJob.get.func) : AnyRef).array() - } + val taskBinaryBytes: Array[Byte] = stage match { + case stage: ShuffleMapStage => + closureSerializer.serialize((stage.rdd, stage.shuffleDep): AnyRef).array() + case stage: ResultStage => + closureSerializer.serialize((stage.rdd, stage.resultOfJob.get.func): AnyRef).array() + } + taskBinary = sc.broadcast(taskBinaryBytes) } catch { // In the case of a failure during serialization, abort the stage. case e: NotSerializableException => abortStage(stage, "Task not serializable: " + e.toString) runningStages -= stage + + // Abort execution return case NonFatal(e) => abortStage(stage, s"Task serialization failed: $e\n${e.getStackTraceString}") @@ -849,20 +866,22 @@ class DAGScheduler( return } - val tasks: Seq[Task[_]] = if (stage.isShuffleMap) { - partitionsToCompute.map { id => - val locs = getPreferredLocs(stage.rdd, id) - val part = stage.rdd.partitions(id) - new ShuffleMapTask(stage.id, taskBinary, part, locs) - } - } else { - val job = stage.resultOfJob.get - partitionsToCompute.map { id => - val p: Int = job.partitions(id) - val part = stage.rdd.partitions(p) - val locs = getPreferredLocs(stage.rdd, p) - new ResultTask(stage.id, taskBinary, part, locs, id) - } + val tasks: Seq[Task[_]] = stage match { + case stage: ShuffleMapStage => + partitionsToCompute.map { id => + val locs = getPreferredLocs(stage.rdd, id) + val part = stage.rdd.partitions(id) + new ShuffleMapTask(stage.id, taskBinary, part, locs) + } + + case stage: ResultStage => + val job = stage.resultOfJob.get + partitionsToCompute.map { id => + val p: Int = job.partitions(id) + val part = stage.rdd.partitions(p) + val locs = getPreferredLocs(stage.rdd, p) + new ResultTask(stage.id, taskBinary, part, locs, id) + } } if (tasks.size > 0) { @@ -873,13 +892,20 @@ class DAGScheduler( new TaskSet(tasks.toArray, stage.id, stage.newAttemptId(), stage.jobId, properties)) stage.latestInfo.submissionTime = Some(clock.getTimeMillis()) } else { - // Because we posted SparkListenerStageSubmitted earlier, we should post - // SparkListenerStageCompleted here in case there are no tasks to run. - outputCommitCoordinator.stageEnd(stage.id) - listenerBus.post(SparkListenerStageCompleted(stage.latestInfo)) - logDebug("Stage " + stage + " is actually done; %b %d %d".format( - stage.isAvailable, stage.numAvailableOutputs, stage.numPartitions)) - runningStages -= stage + // Because we posted SparkListenerStageSubmitted earlier, we should mark + // the stage as completed here in case there are no tasks to run + markStageAsFinished(stage, None) + + val debugString = stage match { + case stage: ShuffleMapStage => + s"Stage ${stage} is actually done; " + + s"(available: ${stage.isAvailable}," + + s"available outputs: ${stage.numAvailableOutputs}," + + s"partitions: ${stage.numPartitions})" + case stage : ResultStage => + s"Stage ${stage} is actually done; (partitions: ${stage.numPartitions})" + } + logDebug(debugString) } } @@ -945,22 +971,6 @@ class DAGScheduler( } val stage = stageIdToStage(task.stageId) - - def markStageAsFinished(stage: Stage, errorMessage: Option[String] = None): Unit = { - val serviceTime = stage.latestInfo.submissionTime match { - case Some(t) => "%.03f".format((clock.getTimeMillis() - t) / 1000.0) - case _ => "Unknown" - } - if (errorMessage.isEmpty) { - logInfo("%s (%s) finished in %s s".format(stage, stage.name, serviceTime)) - stage.latestInfo.completionTime = Some(clock.getTimeMillis()) - } else { - stage.latestInfo.stageFailed(errorMessage.get) - logInfo("%s (%s) failed in %s s".format(stage, stage.name, serviceTime)) - } - listenerBus.post(SparkListenerStageCompleted(stage.latestInfo)) - runningStages -= stage - } event.reason match { case Success => listenerBus.post(SparkListenerTaskEnd(stageId, stage.latestInfo.attemptId, taskType, @@ -968,7 +978,10 @@ class DAGScheduler( stage.pendingTasks -= task task match { case rt: ResultTask[_, _] => - stage.resultOfJob match { + // Cast to ResultStage here because it's part of the ResultTask + // TODO Refactor this out to a function that accepts a ResultStage + val resultStage = stage.asInstanceOf[ResultStage] + resultStage.resultOfJob match { case Some(job) => if (!job.finished(rt.outputId)) { updateAccumulators(event) @@ -976,7 +989,7 @@ class DAGScheduler( job.numFinished += 1 // If the whole job has finished, remove it if (job.numFinished == job.numPartitions) { - markStageAsFinished(stage) + markStageAsFinished(resultStage) cleanupStateForJobAndIndependentStages(job) listenerBus.post( SparkListenerJobEnd(job.jobId, clock.getTimeMillis(), JobSucceeded)) @@ -988,7 +1001,7 @@ class DAGScheduler( job.listener.taskSucceeded(rt.outputId, event.result) } catch { case e: Exception => - // TODO: Perhaps we want to mark the stage as failed? + // TODO: Perhaps we want to mark the resultStage as failed? job.listener.jobFailed(new SparkDriverExecutionException(e)) } } @@ -997,6 +1010,7 @@ class DAGScheduler( } case smt: ShuffleMapTask => + val shuffleStage = stage.asInstanceOf[ShuffleMapStage] updateAccumulators(event) val status = event.result.asInstanceOf[MapStatus] val execId = status.location.executorId @@ -1004,50 +1018,54 @@ class DAGScheduler( if (failedEpoch.contains(execId) && smt.epoch <= failedEpoch(execId)) { logInfo("Ignoring possibly bogus ShuffleMapTask completion from " + execId) } else { - stage.addOutputLoc(smt.partitionId, status) + shuffleStage.addOutputLoc(smt.partitionId, status) } - if (runningStages.contains(stage) && stage.pendingTasks.isEmpty) { - markStageAsFinished(stage) + if (runningStages.contains(shuffleStage) && shuffleStage.pendingTasks.isEmpty) { + markStageAsFinished(shuffleStage) logInfo("looking for newly runnable stages") logInfo("running: " + runningStages) logInfo("waiting: " + waitingStages) logInfo("failed: " + failedStages) - if (stage.shuffleDep.isDefined) { - // We supply true to increment the epoch number here in case this is a - // recomputation of the map outputs. In that case, some nodes may have cached - // locations with holes (from when we detected the error) and will need the - // epoch incremented to refetch them. - // TODO: Only increment the epoch number if this is not the first time - // we registered these map outputs. - mapOutputTracker.registerMapOutputs( - stage.shuffleDep.get.shuffleId, - stage.outputLocs.map(list => if (list.isEmpty) null else list.head).toArray, - changeEpoch = true) - } + + // We supply true to increment the epoch number here in case this is a + // recomputation of the map outputs. In that case, some nodes may have cached + // locations with holes (from when we detected the error) and will need the + // epoch incremented to refetch them. + // TODO: Only increment the epoch number if this is not the first time + // we registered these map outputs. + mapOutputTracker.registerMapOutputs( + shuffleStage.shuffleDep.shuffleId, + shuffleStage.outputLocs.map(list => if (list.isEmpty) null else list.head).toArray, + changeEpoch = true) + clearCacheLocs() - if (stage.outputLocs.exists(_ == Nil)) { - // Some tasks had failed; let's resubmit this stage + if (shuffleStage.outputLocs.contains(Nil)) { + // Some tasks had failed; let's resubmit this shuffleStage // TODO: Lower-level scheduler should also deal with this - logInfo("Resubmitting " + stage + " (" + stage.name + + logInfo("Resubmitting " + shuffleStage + " (" + shuffleStage.name + ") because some of its tasks had failed: " + - stage.outputLocs.zipWithIndex.filter(_._1 == Nil).map(_._2).mkString(", ")) - submitStage(stage) + shuffleStage.outputLocs.zipWithIndex.filter(_._1.isEmpty) + .map(_._2).mkString(", ")) + submitStage(shuffleStage) } else { val newlyRunnable = new ArrayBuffer[Stage] - for (stage <- waitingStages) { - logInfo("Missing parents for " + stage + ": " + getMissingParentStages(stage)) + for (shuffleStage <- waitingStages) { + logInfo("Missing parents for " + shuffleStage + ": " + + getMissingParentStages(shuffleStage)) } - for (stage <- waitingStages if getMissingParentStages(stage) == Nil) { - newlyRunnable += stage + for (shuffleStage <- waitingStages if getMissingParentStages(shuffleStage).isEmpty) + { + newlyRunnable += shuffleStage } waitingStages --= newlyRunnable runningStages ++= newlyRunnable for { - stage <- newlyRunnable.sortBy(_.id) - jobId <- activeJobForStage(stage) + shuffleStage <- newlyRunnable.sortBy(_.id) + jobId <- activeJobForStage(shuffleStage) } { - logInfo("Submitting " + stage + " (" + stage.rdd + "), which is now runnable") - submitMissingTasks(stage, jobId) + logInfo("Submitting " + shuffleStage + " (" + + shuffleStage.rdd + "), which is now runnable") + submitMissingTasks(shuffleStage, jobId) } } } @@ -1068,7 +1086,6 @@ class DAGScheduler( logInfo(s"Marking $failedStage (${failedStage.name}) as failed " + s"due to a fetch failure from $mapStage (${mapStage.name})") markStageAsFinished(failedStage, Some(failureMessage)) - runningStages -= failedStage } if (disallowStageRetryForTest) { @@ -1184,6 +1201,26 @@ class DAGScheduler( submitWaitingStages() } + /** + * Marks a stage as finished and removes it from the list of running stages. + */ + private def markStageAsFinished(stage: Stage, errorMessage: Option[String] = None): Unit = { + val serviceTime = stage.latestInfo.submissionTime match { + case Some(t) => "%.03f".format((clock.getTimeMillis() - t) / 1000.0) + case _ => "Unknown" + } + if (errorMessage.isEmpty) { + logInfo("%s (%s) finished in %s s".format(stage, stage.name, serviceTime)) + stage.latestInfo.completionTime = Some(clock.getTimeMillis()) + } else { + stage.latestInfo.stageFailed(errorMessage.get) + logInfo("%s (%s) failed in %s s".format(stage, stage.name, serviceTime)) + } + outputCommitCoordinator.stageEnd(stage.id) + listenerBus.post(SparkListenerStageCompleted(stage.latestInfo)) + runningStages -= stage + } + /** * Aborts all jobs depending on a particular Stage. This is called in response to a task set * being canceled by the TaskScheduler. Use taskSetFailed() to inject this event from outside. @@ -1204,9 +1241,7 @@ class DAGScheduler( } } - /** - * Fails a job and all stages that are only used by that job, and cleans up relevant state. - */ + /** Fails a job and all stages that are only used by that job, and cleans up relevant state. */ private def failJobAndIndependentStages(job: ActiveJob, failureReason: String) { val error = new SparkException(failureReason) var ableToCancelStages = true @@ -1235,8 +1270,7 @@ class DAGScheduler( if (runningStages.contains(stage)) { try { // cancelTasks will fail if a SchedulerBackend does not implement killTask taskScheduler.cancelTasks(stageId, shouldInterruptThread) - stage.latestInfo.stageFailed(failureReason) - listenerBus.post(SparkListenerStageCompleted(stage.latestInfo)) + markStageAsFinished(stage, Some(failureReason)) } catch { case e: UnsupportedOperationException => logInfo(s"Could not cancel tasks for stage $stageId", e) @@ -1254,9 +1288,7 @@ class DAGScheduler( } } - /** - * Return true if one of stage's ancestors is target. - */ + /** Return true if one of stage's ancestors is target. */ private def stageDependsOn(stage: Stage, target: Stage): Boolean = { if (stage == target) { return true @@ -1282,7 +1314,7 @@ class DAGScheduler( } } waitingForVisit.push(stage.rdd) - while (!waitingForVisit.isEmpty) { + while (waitingForVisit.nonEmpty) { visit(waitingForVisit.pop()) } visitedRdds.contains(target.rdd) @@ -1312,9 +1344,7 @@ class DAGScheduler( private def getPreferredLocsInternal( rdd: RDD[_], partition: Int, - visited: HashSet[(RDD[_],Int)]) - : Seq[TaskLocation] = - { + visited: HashSet[(RDD[_],Int)]): Seq[TaskLocation] = { // If the partition has already been visited, no need to re-visit. // This avoids exponential path exploration. SPARK-695 if (!visited.add((rdd,partition))) { @@ -1323,12 +1353,12 @@ class DAGScheduler( } // If the partition is cached, return the cache locations val cached = getCacheLocs(rdd)(partition) - if (!cached.isEmpty) { + if (cached.nonEmpty) { return cached } // If the RDD has some placement preferences (as is the case for input RDDs), get those val rddPrefs = rdd.preferredLocations(rdd.partitions(partition)).toList - if (!rddPrefs.isEmpty) { + if (rddPrefs.nonEmpty) { return rddPrefs.map(TaskLocation(_)) } // If the RDD has narrow dependencies, pick the first partition of the first narrow dep @@ -1412,7 +1442,7 @@ private[scheduler] class DAGSchedulerEventProcessLoop(dagScheduler: DAGScheduler dagScheduler.sc.stop() } - override def onStop() { + override def onStop(): Unit = { // Cancel any active jobs in postStop hook dagScheduler.cleanUpAfterSchedulerStop() } diff --git a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala index c0d889360ae99..08e7727db2fde 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala @@ -47,21 +47,21 @@ import org.apache.spark.util.{JsonProtocol, Utils} */ private[spark] class EventLoggingListener( appId: String, - logBaseDir: String, + logBaseDir: URI, sparkConf: SparkConf, hadoopConf: Configuration) extends SparkListener with Logging { import EventLoggingListener._ - def this(appId: String, logBaseDir: String, sparkConf: SparkConf) = + def this(appId: String, logBaseDir: URI, sparkConf: SparkConf) = this(appId, logBaseDir, sparkConf, SparkHadoopUtil.get.newConfiguration(sparkConf)) private val shouldCompress = sparkConf.getBoolean("spark.eventLog.compress", false) private val shouldOverwrite = sparkConf.getBoolean("spark.eventLog.overwrite", false) private val testing = sparkConf.getBoolean("spark.eventLog.testing", false) private val outputBufferSize = sparkConf.getInt("spark.eventLog.buffer.kb", 100) * 1024 - private val fileSystem = Utils.getHadoopFileSystem(new URI(logBaseDir), hadoopConf) + private val fileSystem = Utils.getHadoopFileSystem(logBaseDir, hadoopConf) private val compressionCodec = if (shouldCompress) { Some(CompressionCodec.createCodec(sparkConf)) @@ -259,13 +259,13 @@ private[spark] object EventLoggingListener extends Logging { * @return A path which consists of file-system-safe characters. */ def getLogPath( - logBaseDir: String, + logBaseDir: URI, appId: String, compressionCodecName: Option[String] = None): String = { val sanitizedAppId = appId.replaceAll("[ :/]", "-").replaceAll("[.${}'\"]", "_").toLowerCase // e.g. app_123, app_123.lzf val logName = sanitizedAppId + compressionCodecName.map { "." + _ }.getOrElse("") - Utils.resolveURI(logBaseDir).toString.stripSuffix("/") + "/" + logName + logBaseDir.toString.stripSuffix("/") + "/" + logName } /** diff --git a/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala b/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala index a3caa9f000c89..7c184b1dcb308 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala @@ -19,10 +19,8 @@ package org.apache.spark.scheduler import scala.collection.mutable -import akka.actor.{ActorRef, Actor} - import org.apache.spark._ -import org.apache.spark.util.{AkkaUtils, ActorLogReceive} +import org.apache.spark.rpc.{RpcCallContext, RpcEndpointRef, RpcEnv, RpcEndpoint} private sealed trait OutputCommitCoordinationMessage extends Serializable @@ -34,8 +32,8 @@ private case class AskPermissionToCommitOutput(stage: Int, task: Long, taskAttem * policy. * * OutputCommitCoordinator is instantiated in both the drivers and executors. On executors, it is - * configured with a reference to the driver's OutputCommitCoordinatorActor, so requests to commit - * output will be forwarded to the driver's OutputCommitCoordinator. + * configured with a reference to the driver's OutputCommitCoordinatorEndpoint, so requests to + * commit output will be forwarded to the driver's OutputCommitCoordinator. * * This class was introduced in SPARK-4879; see that JIRA issue (and the associated pull requests) * for an extensive design discussion. @@ -43,10 +41,7 @@ private case class AskPermissionToCommitOutput(stage: Int, task: Long, taskAttem private[spark] class OutputCommitCoordinator(conf: SparkConf) extends Logging { // Initialized by SparkEnv - var coordinatorActor: Option[ActorRef] = None - private val timeout = AkkaUtils.askTimeout(conf) - private val maxAttempts = AkkaUtils.numRetries(conf) - private val retryInterval = AkkaUtils.retryWaitMs(conf) + var coordinatorRef: Option[RpcEndpointRef] = None private type StageId = Int private type PartitionId = Long @@ -64,6 +59,13 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf) extends Logging { private val authorizedCommittersByStage: CommittersByStageMap = mutable.Map() private type CommittersByStageMap = mutable.Map[StageId, mutable.Map[PartitionId, TaskAttemptId]] + /** + * Returns whether the OutputCommitCoordinator's internal data structures are all empty. + */ + def isEmpty: Boolean = { + authorizedCommittersByStage.isEmpty + } + /** * Called by tasks to ask whether they can commit their output to HDFS. * @@ -81,9 +83,9 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf) extends Logging { partition: PartitionId, attempt: TaskAttemptId): Boolean = { val msg = AskPermissionToCommitOutput(stage, partition, attempt) - coordinatorActor match { - case Some(actor) => - AkkaUtils.askWithReply[Boolean](msg, actor, maxAttempts, retryInterval, timeout) + coordinatorRef match { + case Some(endpointRef) => + endpointRef.askWithReply[Boolean](msg) case None => logError( "canCommit called after coordinator was stopped (is SparkEnv shutdown in progress)?") @@ -118,15 +120,17 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf) extends Logging { logInfo( s"Task was denied committing, stage: $stage, partition: $partition, attempt: $attempt") case otherReason => - logDebug(s"Authorized committer $attempt (stage=$stage, partition=$partition) failed;" + - s" clearing lock") - authorizedCommitters.remove(partition) + if (authorizedCommitters.get(partition).exists(_ == attempt)) { + logDebug(s"Authorized committer $attempt (stage=$stage, partition=$partition) failed;" + + s" clearing lock") + authorizedCommitters.remove(partition) + } } } def stop(): Unit = synchronized { - coordinatorActor.foreach(_ ! StopCoordinator) - coordinatorActor = None + coordinatorRef.foreach(_ send StopCoordinator) + coordinatorRef = None authorizedCommittersByStage.clear() } @@ -157,16 +161,20 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf) extends Logging { private[spark] object OutputCommitCoordinator { // This actor is used only for RPC - class OutputCommitCoordinatorActor(outputCommitCoordinator: OutputCommitCoordinator) - extends Actor with ActorLogReceive with Logging { + private[spark] class OutputCommitCoordinatorEndpoint( + override val rpcEnv: RpcEnv, outputCommitCoordinator: OutputCommitCoordinator) + extends RpcEndpoint with Logging { - override def receiveWithLogging: PartialFunction[Any, Unit] = { - case AskPermissionToCommitOutput(stage, partition, taskAttempt) => - sender ! outputCommitCoordinator.handleAskPermissionToCommit(stage, partition, taskAttempt) + override def receive: PartialFunction[Any, Unit] = { case StopCoordinator => logInfo("OutputCommitCoordinator stopped!") - context.stop(self) - sender ! true + stop() + } + + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { + case AskPermissionToCommitOutput(stage, partition, taskAttempt) => + context.reply( + outputCommitCoordinator.handleAskPermissionToCommit(stage, partition, taskAttempt)) } } } diff --git a/core/src/main/scala/org/apache/spark/TaskContextHelper.scala b/core/src/main/scala/org/apache/spark/scheduler/ResultStage.scala similarity index 57% rename from core/src/main/scala/org/apache/spark/TaskContextHelper.scala rename to core/src/main/scala/org/apache/spark/scheduler/ResultStage.scala index 4636c4600a01a..c0f3d5a13d623 100644 --- a/core/src/main/scala/org/apache/spark/TaskContextHelper.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ResultStage.scala @@ -15,15 +15,26 @@ * limitations under the License. */ -package org.apache.spark +package org.apache.spark.scheduler + +import org.apache.spark.rdd.RDD +import org.apache.spark.util.CallSite /** - * This class exists to restrict the visibility of TaskContext setters. + * The ResultStage represents the final stage in a job. */ -private [spark] object TaskContextHelper { +private[spark] class ResultStage( + id: Int, + rdd: RDD[_], + numTasks: Int, + parents: List[Stage], + jobId: Int, + callSite: CallSite) + extends Stage(id, rdd, numTasks, parents, jobId, callSite) { - def setTaskContext(tc: TaskContext): Unit = TaskContext.setTaskContext(tc) + // The active job for this result stage. Will be empty if the job has already finished + // (e.g., because the job was cancelled). + var resultOfJob: Option[ActiveJob] = None - def unset(): Unit = TaskContext.unset() - + override def toString: String = "ResultStage " + id } diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapStage.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapStage.scala new file mode 100644 index 0000000000000..d02210743484c --- /dev/null +++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapStage.scala @@ -0,0 +1,84 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler + +import org.apache.spark.ShuffleDependency +import org.apache.spark.rdd.RDD +import org.apache.spark.storage.BlockManagerId +import org.apache.spark.util.CallSite + +/** + * The ShuffleMapStage represents the intermediate stages in a job. + */ +private[spark] class ShuffleMapStage( + id: Int, + rdd: RDD[_], + numTasks: Int, + parents: List[Stage], + jobId: Int, + callSite: CallSite, + val shuffleDep: ShuffleDependency[_, _, _]) + extends Stage(id, rdd, numTasks, parents, jobId, callSite) { + + override def toString: String = "ShuffleMapStage " + id + + var numAvailableOutputs: Long = 0 + + def isAvailable: Boolean = numAvailableOutputs == numPartitions + + val outputLocs = Array.fill[List[MapStatus]](numPartitions)(Nil) + + def addOutputLoc(partition: Int, status: MapStatus): Unit = { + val prevList = outputLocs(partition) + outputLocs(partition) = status :: prevList + if (prevList == Nil) { + numAvailableOutputs += 1 + } + } + + def removeOutputLoc(partition: Int, bmAddress: BlockManagerId): Unit = { + val prevList = outputLocs(partition) + val newList = prevList.filterNot(_.location == bmAddress) + outputLocs(partition) = newList + if (prevList != Nil && newList == Nil) { + numAvailableOutputs -= 1 + } + } + + /** + * Removes all shuffle outputs associated with this executor. Note that this will also remove + * outputs which are served by an external shuffle server (if one exists), as they are still + * registered with this execId. + */ + def removeOutputsOnExecutor(execId: String): Unit = { + var becameUnavailable = false + for (partition <- 0 until numPartitions) { + val prevList = outputLocs(partition) + val newList = prevList.filterNot(_.location.executorId == execId) + outputLocs(partition) = newList + if (prevList != Nil && newList == Nil) { + becameUnavailable = true + numAvailableOutputs -= 1 + } + } + if (becameUnavailable) { + logInfo("%s is now unavailable on executor %s (%d/%d, %s)".format( + this, execId, numAvailableOutputs, numPartitions, isAvailable)) + } + } +} diff --git a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala index 4cbc6e84a6bdd..5d0ddb8377c33 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala @@ -21,7 +21,6 @@ import scala.collection.mutable.HashSet import org.apache.spark._ import org.apache.spark.rdd.RDD -import org.apache.spark.storage.BlockManagerId import org.apache.spark.util.CallSite /** @@ -47,29 +46,23 @@ import org.apache.spark.util.CallSite * be updated for each attempt. * */ -private[spark] class Stage( +private[spark] abstract class Stage( val id: Int, val rdd: RDD[_], val numTasks: Int, - val shuffleDep: Option[ShuffleDependency[_, _, _]], // Output shuffle if stage is a map stage val parents: List[Stage], val jobId: Int, val callSite: CallSite) extends Logging { - val isShuffleMap = shuffleDep.isDefined val numPartitions = rdd.partitions.size - val outputLocs = Array.fill[List[MapStatus]](numPartitions)(Nil) - var numAvailableOutputs = 0 /** Set of jobs that this stage belongs to. */ val jobIds = new HashSet[Int] - /** For stages that are the final (consists of only ResultTasks), link to the ActiveJob. */ - var resultOfJob: Option[ActiveJob] = None var pendingTasks = new HashSet[Task[_]] - private var nextAttemptId = 0 + private var nextAttemptId: Int = 0 val name = callSite.shortForm val details = callSite.longForm @@ -77,53 +70,6 @@ private[spark] class Stage( /** Pointer to the latest [StageInfo] object, set by DAGScheduler. */ var latestInfo: StageInfo = StageInfo.fromStage(this) - def isAvailable: Boolean = { - if (!isShuffleMap) { - true - } else { - numAvailableOutputs == numPartitions - } - } - - def addOutputLoc(partition: Int, status: MapStatus) { - val prevList = outputLocs(partition) - outputLocs(partition) = status :: prevList - if (prevList == Nil) { - numAvailableOutputs += 1 - } - } - - def removeOutputLoc(partition: Int, bmAddress: BlockManagerId) { - val prevList = outputLocs(partition) - val newList = prevList.filterNot(_.location == bmAddress) - outputLocs(partition) = newList - if (prevList != Nil && newList == Nil) { - numAvailableOutputs -= 1 - } - } - - /** - * Removes all shuffle outputs associated with this executor. Note that this will also remove - * outputs which are served by an external shuffle server (if one exists), as they are still - * registered with this execId. - */ - def removeOutputsOnExecutor(execId: String) { - var becameUnavailable = false - for (partition <- 0 until numPartitions) { - val prevList = outputLocs(partition) - val newList = prevList.filterNot(_.location.executorId == execId) - outputLocs(partition) = newList - if (prevList != Nil && newList == Nil) { - becameUnavailable = true - numAvailableOutputs -= 1 - } - } - if (becameUnavailable) { - logInfo("%s is now unavailable on executor %s (%d/%d, %s)".format( - this, execId, numAvailableOutputs, numPartitions, isAvailable)) - } - } - /** Return a new attempt id, starting with 0. */ def newAttemptId(): Int = { val id = nextAttemptId @@ -133,11 +79,8 @@ private[spark] class Stage( def attemptId: Int = nextAttemptId - override def toString: String = "Stage " + id - - override def hashCode(): Int = id - - override def equals(other: Any): Boolean = other match { + override final def hashCode(): Int = id + override final def equals(other: Any): Boolean = other match { case stage: Stage => stage != null && stage.id == id case _ => false } diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala index 4d9f940813b8e..8b592867ee31d 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala @@ -22,7 +22,7 @@ import java.nio.ByteBuffer import scala.collection.mutable.HashMap -import org.apache.spark.{TaskContextHelper, TaskContextImpl, TaskContext} +import org.apache.spark.{TaskContextImpl, TaskContext} import org.apache.spark.executor.TaskMetrics import org.apache.spark.serializer.SerializerInstance import org.apache.spark.util.ByteBufferInputStream @@ -54,7 +54,7 @@ private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) ex final def run(taskAttemptId: Long, attemptNumber: Int): T = { context = new TaskContextImpl(stageId = stageId, partitionId = partitionId, taskAttemptId = taskAttemptId, attemptNumber = attemptNumber, runningLocally = false) - TaskContextHelper.setTaskContext(context) + TaskContext.setTaskContext(context) context.taskMetrics.setHostname(Utils.localHostName()) taskThread = Thread.currentThread() if (_killed) { @@ -64,7 +64,7 @@ private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) ex runTask(context) } finally { context.markTaskCompleted() - TaskContextHelper.unset() + TaskContext.unset() } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala index 9bf74f4be198d..70364cea62a80 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala @@ -20,6 +20,7 @@ package org.apache.spark.scheduler.cluster import java.nio.ByteBuffer import org.apache.spark.TaskState.TaskState +import org.apache.spark.rpc.RpcEndpointRef import org.apache.spark.util.{SerializableBuffer, Utils} private[spark] sealed trait CoarseGrainedClusterMessage extends Serializable @@ -41,6 +42,7 @@ private[spark] object CoarseGrainedClusterMessages { // Executors to driver case class RegisterExecutor( executorId: String, + executorRef: RpcEndpointRef, hostPort: String, cores: Int, logUrls: Map[String, String]) @@ -70,6 +72,8 @@ private[spark] object CoarseGrainedClusterMessages { case class RemoveExecutor(executorId: String, reason: String) extends CoarseGrainedClusterMessage + case class SetupDriver(driver: RpcEndpointRef) extends CoarseGrainedClusterMessage + // Exchanged between the driver and the AM in Yarn client mode case class AddWebUIFilter(filterName:String, filterParams: Map[String, String], proxyBase: String) extends CoarseGrainedClusterMessage @@ -77,7 +81,7 @@ private[spark] object CoarseGrainedClusterMessages { // Messages exchanged between the driver and the cluster manager for executor allocation // In Yarn mode, these are exchanged between the driver and the AM - case object RegisterClusterManager extends CoarseGrainedClusterMessage + case class RegisterClusterManager(am: RpcEndpointRef) extends CoarseGrainedClusterMessage // Request executors by specifying the new total number of executors desired // This includes executors already pending or running diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index 5d258d9da4d1a..4c49da87af9dc 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -17,20 +17,16 @@ package org.apache.spark.scheduler.cluster +import java.util.concurrent.{TimeUnit, Executors} import java.util.concurrent.atomic.AtomicInteger import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} -import scala.concurrent.Await -import scala.concurrent.duration._ - -import akka.actor._ -import akka.pattern.ask -import akka.remote.{DisassociatedEvent, RemotingLifecycleEvent} +import org.apache.spark.rpc._ import org.apache.spark.{ExecutorAllocationClient, Logging, SparkEnv, SparkException, TaskState} import org.apache.spark.scheduler._ import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._ -import org.apache.spark.util.{ActorLogReceive, SerializableBuffer, AkkaUtils, Utils} +import org.apache.spark.util.{SerializableBuffer, AkkaUtils, Utils} /** * A scheduler backend that waits for coarse grained executors to connect to it through Akka. @@ -41,7 +37,7 @@ import org.apache.spark.util.{ActorLogReceive, SerializableBuffer, AkkaUtils, Ut * (spark.deploy.*). */ private[spark] -class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSystem: ActorSystem) +class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: RpcEnv) extends ExecutorAllocationClient with SchedulerBackend with Logging { // Use an atomic variable to track total number of cores in the cluster for simplicity and speed @@ -49,7 +45,6 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSyste // Total number of executors that are currently registered var totalRegisteredExecutors = new AtomicInteger(0) val conf = scheduler.sc.conf - private val timeout = AkkaUtils.askTimeout(conf) private val akkaFrameSize = AkkaUtils.maxFrameSizeBytes(conf) // Submit tasks only after (registered resources / total expected resources) // is equal to at least this value, that is double between 0 and 1. @@ -71,48 +66,26 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSyste // Executors we have requested the cluster manager to kill that have not died yet private val executorsPendingToRemove = new HashSet[String] - class DriverActor(sparkProperties: Seq[(String, String)]) extends Actor with ActorLogReceive { + class DriverEndpoint(override val rpcEnv: RpcEnv, sparkProperties: Seq[(String, String)]) + extends ThreadSafeRpcEndpoint with Logging { override protected def log = CoarseGrainedSchedulerBackend.this.log - private val addressToExecutorId = new HashMap[Address, String] - override def preStart() { - // Listen for remote client disconnection events, since they don't go through Akka's watch() - context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent]) + private val addressToExecutorId = new HashMap[RpcAddress, String] + + private val reviveThread = + Executors.newSingleThreadScheduledExecutor(Utils.namedThreadFactory("driver-revive-thread")) + override def onStart() { // Periodically revive offers to allow delay scheduling to work val reviveInterval = conf.getLong("spark.scheduler.revive.interval", 1000) - import context.dispatcher - context.system.scheduler.schedule(0.millis, reviveInterval.millis, self, ReviveOffers) - } - - def receiveWithLogging: PartialFunction[Any, Unit] = { - case RegisterExecutor(executorId, hostPort, cores, logUrls) => - Utils.checkHostPort(hostPort, "Host port expected " + hostPort) - if (executorDataMap.contains(executorId)) { - sender ! RegisterExecutorFailed("Duplicate executor ID: " + executorId) - } else { - logInfo("Registered executor: " + sender + " with ID " + executorId) - sender ! RegisteredExecutor - - addressToExecutorId(sender.path.address) = executorId - totalCoreCount.addAndGet(cores) - totalRegisteredExecutors.addAndGet(1) - val (host, _) = Utils.parseHostPort(hostPort) - val data = new ExecutorData(sender, sender.path.address, host, cores, cores, logUrls) - // This must be synchronized because variables mutated - // in this block are read when requesting executors - CoarseGrainedSchedulerBackend.this.synchronized { - executorDataMap.put(executorId, data) - if (numPendingExecutors > 0) { - numPendingExecutors -= 1 - logDebug(s"Decremented number of pending executors ($numPendingExecutors left)") - } - } - listenerBus.post( - SparkListenerExecutorAdded(System.currentTimeMillis(), executorId, data)) - makeOffers() + reviveThread.scheduleAtFixedRate(new Runnable { + override def run(): Unit = Utils.tryLogNonFatalError { + Option(self).foreach(_.send(ReviveOffers)) } + }, 0, reviveInterval, TimeUnit.MILLISECONDS) + } + override def receive: PartialFunction[Any, Unit] = { case StatusUpdate(executorId, taskId, state, data) => scheduler.statusUpdate(taskId, state, data.value) if (TaskState.isFinished(state)) { @@ -133,33 +106,58 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSyste case KillTask(taskId, executorId, interruptThread) => executorDataMap.get(executorId) match { case Some(executorInfo) => - executorInfo.executorActor ! KillTask(taskId, executorId, interruptThread) + executorInfo.executorEndpoint.send(KillTask(taskId, executorId, interruptThread)) case None => // Ignoring the task kill since the executor is not registered. logWarning(s"Attempted to kill task $taskId for unknown executor $executorId.") } + } + + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { + case RegisterExecutor(executorId, executorRef, hostPort, cores, logUrls) => + Utils.checkHostPort(hostPort, "Host port expected " + hostPort) + if (executorDataMap.contains(executorId)) { + context.reply(RegisterExecutorFailed("Duplicate executor ID: " + executorId)) + } else { + logInfo("Registered executor: " + executorRef + " with ID " + executorId) + context.reply(RegisteredExecutor) + + addressToExecutorId(executorRef.address) = executorId + totalCoreCount.addAndGet(cores) + totalRegisteredExecutors.addAndGet(1) + val (host, _) = Utils.parseHostPort(hostPort) + val data = new ExecutorData(executorRef, executorRef.address, host, cores, cores, logUrls) + // This must be synchronized because variables mutated + // in this block are read when requesting executors + CoarseGrainedSchedulerBackend.this.synchronized { + executorDataMap.put(executorId, data) + if (numPendingExecutors > 0) { + numPendingExecutors -= 1 + logDebug(s"Decremented number of pending executors ($numPendingExecutors left)") + } + } + listenerBus.post( + SparkListenerExecutorAdded(System.currentTimeMillis(), executorId, data)) + makeOffers() + } case StopDriver => - sender ! true - context.stop(self) + context.reply(true) + stop() case StopExecutors => logInfo("Asking each executor to shut down") for ((_, executorData) <- executorDataMap) { - executorData.executorActor ! StopExecutor + executorData.executorEndpoint.send(StopExecutor) } - sender ! true + context.reply(true) case RemoveExecutor(executorId, reason) => removeExecutor(executorId, reason) - sender ! true - - case DisassociatedEvent(_, address, _) => - addressToExecutorId.get(address).foreach(removeExecutor(_, - "remote Akka client disassociated")) + context.reply(true) case RetrieveSparkProps => - sender ! sparkProperties + context.reply(sparkProperties) } // Make fake resource offers on all executors @@ -169,6 +167,11 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSyste }.toSeq)) } + override def onDisconnected(remoteAddress: RpcAddress): Unit = { + addressToExecutorId.get(remoteAddress).foreach(removeExecutor(_, + "remote Rpc client disassociated")) + } + // Make fake resource offers on just one executor def makeOffers(executorId: String) { val executorData = executorDataMap(executorId) @@ -199,7 +202,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSyste else { val executorData = executorDataMap(task.executorId) executorData.freeCores -= scheduler.CPUS_PER_TASK - executorData.executorActor ! LaunchTask(new SerializableBuffer(serializedTask)) + executorData.executorEndpoint.send(LaunchTask(new SerializableBuffer(serializedTask))) } } } @@ -223,9 +226,13 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSyste case None => logError(s"Asked to remove non-existent executor $executorId") } } + + override def onStop() { + reviveThread.shutdownNow() + } } - var driverActor: ActorRef = null + var driverEndpoint: RpcEndpointRef = null val taskIdsOnSlave = new HashMap[String, HashSet[String]] override def start() { @@ -236,16 +243,15 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSyste } } // TODO (prashant) send conf instead of properties - driverActor = actorSystem.actorOf( - Props(new DriverActor(properties)), name = CoarseGrainedSchedulerBackend.ACTOR_NAME) + driverEndpoint = rpcEnv.setupEndpoint( + CoarseGrainedSchedulerBackend.ENDPOINT_NAME, new DriverEndpoint(rpcEnv, properties)) } def stopExecutors() { try { - if (driverActor != null) { + if (driverEndpoint != null) { logInfo("Shutting down all executors") - val future = driverActor.ask(StopExecutors)(timeout) - Await.ready(future, timeout) + driverEndpoint.askWithReply[Boolean](StopExecutors) } } catch { case e: Exception => @@ -256,22 +262,21 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSyste override def stop() { stopExecutors() try { - if (driverActor != null) { - val future = driverActor.ask(StopDriver)(timeout) - Await.ready(future, timeout) + if (driverEndpoint != null) { + driverEndpoint.askWithReply[Boolean](StopDriver) } } catch { case e: Exception => - throw new SparkException("Error stopping standalone scheduler's driver actor", e) + throw new SparkException("Error stopping standalone scheduler's driver endpoint", e) } } override def reviveOffers() { - driverActor ! ReviveOffers + driverEndpoint.send(ReviveOffers) } override def killTask(taskId: Long, executorId: String, interruptThread: Boolean) { - driverActor ! KillTask(taskId, executorId, interruptThread) + driverEndpoint.send(KillTask(taskId, executorId, interruptThread)) } override def defaultParallelism(): Int = { @@ -281,11 +286,10 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSyste // Called by subclasses when notified of a lost worker def removeExecutor(executorId: String, reason: String) { try { - val future = driverActor.ask(RemoveExecutor(executorId, reason))(timeout) - Await.ready(future, timeout) + driverEndpoint.askWithReply[Boolean](RemoveExecutor(executorId, reason)) } catch { case e: Exception => - throw new SparkException("Error notifying standalone scheduler's driver actor", e) + throw new SparkException("Error notifying standalone scheduler's driver endpoint", e) } } @@ -391,5 +395,5 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSyste } private[spark] object CoarseGrainedSchedulerBackend { - val ACTOR_NAME = "CoarseGrainedScheduler" + val ENDPOINT_NAME = "CoarseGrainedScheduler" } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorData.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorData.scala index 5e571efe76720..26e72c0bff38d 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorData.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorData.scala @@ -17,20 +17,20 @@ package org.apache.spark.scheduler.cluster -import akka.actor.{Address, ActorRef} +import org.apache.spark.rpc.{RpcEndpointRef, RpcAddress} /** * Grouping of data for an executor used by CoarseGrainedSchedulerBackend. * - * @param executorActor The ActorRef representing this executor + * @param executorEndpoint The ActorRef representing this executor * @param executorAddress The network address of this executor * @param executorHost The hostname that this executor is running on * @param freeCores The current number of cores available for work on the executor * @param totalCores The total number of cores available to the executor */ private[cluster] class ExecutorData( - val executorActor: ActorRef, - val executorAddress: Address, + val executorEndpoint: RpcEndpointRef, + val executorAddress: RpcAddress, override val executorHost: String, var freeCores: Int, override val totalCores: Int, diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala index 06786a59524e7..0324c9dab910b 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala @@ -19,16 +19,16 @@ package org.apache.spark.scheduler.cluster import org.apache.hadoop.fs.{Path, FileSystem} +import org.apache.spark.rpc.RpcAddress import org.apache.spark.{Logging, SparkContext, SparkEnv} import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.scheduler.TaskSchedulerImpl -import org.apache.spark.util.AkkaUtils private[spark] class SimrSchedulerBackend( scheduler: TaskSchedulerImpl, sc: SparkContext, driverFilePath: String) - extends CoarseGrainedSchedulerBackend(scheduler, sc.env.actorSystem) + extends CoarseGrainedSchedulerBackend(scheduler, sc.env.rpcEnv) with Logging { val tmpPath = new Path(driverFilePath + "_tmp") @@ -39,12 +39,9 @@ private[spark] class SimrSchedulerBackend( override def start() { super.start() - val driverUrl = AkkaUtils.address( - AkkaUtils.protocol(actorSystem), - SparkEnv.driverActorSystemName, - sc.conf.get("spark.driver.host"), - sc.conf.get("spark.driver.port"), - CoarseGrainedSchedulerBackend.ACTOR_NAME) + val driverUrl = rpcEnv.uriOf(SparkEnv.driverActorSystemName, + RpcAddress(sc.conf.get("spark.driver.host"), sc.conf.get("spark.driver.port").toInt), + CoarseGrainedSchedulerBackend.ENDPOINT_NAME) val conf = SparkHadoopUtil.get.newConfiguration(sc.conf) val fs = FileSystem.get(conf) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala index ffd4825705755..7eb3fdc19b5b8 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala @@ -19,17 +19,18 @@ package org.apache.spark.scheduler.cluster import java.util.concurrent.Semaphore +import org.apache.spark.rpc.RpcAddress import org.apache.spark.{Logging, SparkConf, SparkContext, SparkEnv} import org.apache.spark.deploy.{ApplicationDescription, Command} import org.apache.spark.deploy.client.{AppClient, AppClientListener} import org.apache.spark.scheduler.{ExecutorExited, ExecutorLossReason, SlaveLost, TaskSchedulerImpl} -import org.apache.spark.util.{AkkaUtils, Utils} +import org.apache.spark.util.Utils private[spark] class SparkDeploySchedulerBackend( scheduler: TaskSchedulerImpl, sc: SparkContext, masters: Array[String]) - extends CoarseGrainedSchedulerBackend(scheduler, sc.env.actorSystem) + extends CoarseGrainedSchedulerBackend(scheduler, sc.env.rpcEnv) with AppClientListener with Logging { @@ -48,12 +49,9 @@ private[spark] class SparkDeploySchedulerBackend( super.start() // The endpoint for executors to talk to us - val driverUrl = AkkaUtils.address( - AkkaUtils.protocol(actorSystem), - SparkEnv.driverActorSystemName, - conf.get("spark.driver.host"), - conf.get("spark.driver.port"), - CoarseGrainedSchedulerBackend.ACTOR_NAME) + val driverUrl = rpcEnv.uriOf(SparkEnv.driverActorSystemName, + RpcAddress(sc.conf.get("spark.driver.host"), sc.conf.get("spark.driver.port").toInt), + CoarseGrainedSchedulerBackend.ENDPOINT_NAME) val args = Seq( "--driver-url", driverUrl, "--executor-id", "{{EXECUTOR_ID}}", diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala index 5a38ad9f2b12c..f72566c370a6f 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala @@ -19,10 +19,8 @@ package org.apache.spark.scheduler.cluster import scala.concurrent.{Future, ExecutionContext} -import akka.actor.{Actor, ActorRef, Props} -import akka.remote.{DisassociatedEvent, RemotingLifecycleEvent} - -import org.apache.spark.SparkContext +import org.apache.spark.{Logging, SparkContext} +import org.apache.spark.rpc._ import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._ import org.apache.spark.scheduler.TaskSchedulerImpl import org.apache.spark.ui.JettyUtils @@ -37,7 +35,7 @@ import scala.util.control.NonFatal private[spark] abstract class YarnSchedulerBackend( scheduler: TaskSchedulerImpl, sc: SparkContext) - extends CoarseGrainedSchedulerBackend(scheduler, sc.env.actorSystem) { + extends CoarseGrainedSchedulerBackend(scheduler, sc.env.rpcEnv) { if (conf.getOption("spark.scheduler.minRegisteredResourcesRatio").isEmpty) { minRegisteredRatio = 0.8 @@ -45,10 +43,8 @@ private[spark] abstract class YarnSchedulerBackend( protected var totalExpectedExecutors = 0 - private val yarnSchedulerActor: ActorRef = - actorSystem.actorOf( - Props(new YarnSchedulerActor), - name = YarnSchedulerBackend.ACTOR_NAME) + private val yarnSchedulerEndpoint = rpcEnv.setupEndpoint( + YarnSchedulerBackend.ENDPOINT_NAME, new YarnSchedulerEndpoint(rpcEnv)) private implicit val askTimeout = AkkaUtils.askTimeout(sc.conf) @@ -57,16 +53,14 @@ private[spark] abstract class YarnSchedulerBackend( * This includes executors already pending or running. */ override def doRequestTotalExecutors(requestedTotal: Int): Boolean = { - AkkaUtils.askWithReply[Boolean]( - RequestExecutors(requestedTotal), yarnSchedulerActor, askTimeout) + yarnSchedulerEndpoint.askWithReply[Boolean](RequestExecutors(requestedTotal)) } /** * Request that the ApplicationMaster kill the specified executors. */ override def doKillExecutors(executorIds: Seq[String]): Boolean = { - AkkaUtils.askWithReply[Boolean]( - KillExecutors(executorIds), yarnSchedulerActor, askTimeout) + yarnSchedulerEndpoint.askWithReply[Boolean](KillExecutors(executorIds)) } override def sufficientResourcesRegistered(): Boolean = { @@ -96,64 +90,71 @@ private[spark] abstract class YarnSchedulerBackend( } /** - * An actor that communicates with the ApplicationMaster. + * An [[RpcEndpoint]] that communicates with the ApplicationMaster. */ - private class YarnSchedulerActor extends Actor { - private var amActor: Option[ActorRef] = None - - implicit val askAmActorExecutor = ExecutionContext.fromExecutor( - Utils.newDaemonCachedThreadPool("yarn-scheduler-ask-am-executor")) + private class YarnSchedulerEndpoint(override val rpcEnv: RpcEnv) + extends ThreadSafeRpcEndpoint with Logging { + private var amEndpoint: Option[RpcEndpointRef] = None - override def preStart(): Unit = { - // Listen for disassociation events - context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent]) - } + private val askAmThreadPool = + Utils.newDaemonCachedThreadPool("yarn-scheduler-ask-am-thread-pool") + implicit val askAmExecutor = ExecutionContext.fromExecutor(askAmThreadPool) override def receive: PartialFunction[Any, Unit] = { - case RegisterClusterManager => - logInfo(s"ApplicationMaster registered as $sender") - amActor = Some(sender) + case RegisterClusterManager(am) => + logInfo(s"ApplicationMaster registered as $am") + amEndpoint = Some(am) + + case AddWebUIFilter(filterName, filterParams, proxyBase) => + addWebUIFilter(filterName, filterParams, proxyBase) + + } + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { case r: RequestExecutors => - amActor match { - case Some(actor) => - val driverActor = sender + amEndpoint match { + case Some(am) => Future { - driverActor ! AkkaUtils.askWithReply[Boolean](r, actor, askTimeout) + context.reply(am.askWithReply[Boolean](r)) } onFailure { - case NonFatal(e) => logError(s"Sending $r to AM was unsuccessful", e) + case NonFatal(e) => + logError(s"Sending $r to AM was unsuccessful", e) + context.sendFailure(e) } case None => logWarning("Attempted to request executors before the AM has registered!") - sender ! false + context.reply(false) } case k: KillExecutors => - amActor match { - case Some(actor) => - val driverActor = sender + amEndpoint match { + case Some(am) => Future { - driverActor ! AkkaUtils.askWithReply[Boolean](k, actor, askTimeout) + context.reply(am.askWithReply[Boolean](k)) } onFailure { - case NonFatal(e) => logError(s"Sending $k to AM was unsuccessful", e) + case NonFatal(e) => + logError(s"Sending $k to AM was unsuccessful", e) + context.sendFailure(e) } case None => logWarning("Attempted to kill executors before the AM has registered!") - sender ! false + context.reply(false) } - case AddWebUIFilter(filterName, filterParams, proxyBase) => - addWebUIFilter(filterName, filterParams, proxyBase) - sender ! true + } - case d: DisassociatedEvent => - if (amActor.isDefined && sender == amActor.get) { - logWarning(s"ApplicationMaster has disassociated: $d") - } + override def onDisconnected(remoteAddress: RpcAddress): Unit = { + if (amEndpoint.exists(_.address == remoteAddress)) { + logWarning(s"ApplicationMaster has disassociated: $remoteAddress") + } + } + + override def onStop(): Unit ={ + askAmThreadPool.shutdownNow() } } } private[spark] object YarnSchedulerBackend { - val ACTOR_NAME = "YarnScheduler" + val ENDPOINT_NAME = "YarnScheduler" } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala index e13de0f46ef89..b037a4966ced0 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala @@ -47,7 +47,7 @@ private[spark] class CoarseMesosSchedulerBackend( scheduler: TaskSchedulerImpl, sc: SparkContext, master: String) - extends CoarseGrainedSchedulerBackend(scheduler, sc.env.actorSystem) + extends CoarseGrainedSchedulerBackend(scheduler, sc.env.rpcEnv) with MScheduler with Logging { @@ -148,7 +148,7 @@ private[spark] class CoarseMesosSchedulerBackend( SparkEnv.driverActorSystemName, conf.get("spark.driver.host"), conf.get("spark.driver.port"), - CoarseGrainedSchedulerBackend.ACTOR_NAME) + CoarseGrainedSchedulerBackend.ENDPOINT_NAME) val uri = conf.get("spark.executor.uri", null) if (uri == null) { diff --git a/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala index eb3f999b5b375..70a477a6895cc 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala @@ -18,17 +18,14 @@ package org.apache.spark.scheduler.local import java.nio.ByteBuffer +import java.util.concurrent.{Executors, TimeUnit} -import scala.concurrent.duration._ -import scala.language.postfixOps - -import akka.actor.{Actor, ActorRef, Props} - +import org.apache.spark.rpc.{ThreadSafeRpcEndpoint, RpcEndpointRef, RpcEnv} +import org.apache.spark.util.Utils import org.apache.spark.{Logging, SparkContext, SparkEnv, TaskState} import org.apache.spark.TaskState.TaskState import org.apache.spark.executor.{Executor, ExecutorBackend} import org.apache.spark.scheduler.{SchedulerBackend, TaskSchedulerImpl, WorkerOffer} -import org.apache.spark.util.ActorLogReceive private case class ReviveOffers() @@ -39,17 +36,19 @@ private case class KillTask(taskId: Long, interruptThread: Boolean) private case class StopExecutor() /** - * Calls to LocalBackend are all serialized through LocalActor. Using an actor makes the calls on - * LocalBackend asynchronous, which is necessary to prevent deadlock between LocalBackend + * Calls to LocalBackend are all serialized through LocalEndpoint. Using an RpcEndpoint makes the + * calls on LocalBackend asynchronous, which is necessary to prevent deadlock between LocalBackend * and the TaskSchedulerImpl. */ -private[spark] class LocalActor( +private[spark] class LocalEndpoint( + override val rpcEnv: RpcEnv, scheduler: TaskSchedulerImpl, executorBackend: LocalBackend, private val totalCores: Int) - extends Actor with ActorLogReceive with Logging { + extends ThreadSafeRpcEndpoint with Logging { - import context.dispatcher // to use Akka's scheduler.scheduleOnce() + private val reviveThread = Executors.newSingleThreadScheduledExecutor( + Utils.namedThreadFactory("local-revive-thread")) private var freeCores = totalCores @@ -59,7 +58,7 @@ private[spark] class LocalActor( private val executor = new Executor( localExecutorId, localExecutorHostname, SparkEnv.get, isLocal = true) - override def receiveWithLogging: PartialFunction[Any, Unit] = { + override def receive: PartialFunction[Any, Unit] = { case ReviveOffers => reviveOffers() @@ -87,9 +86,17 @@ private[spark] class LocalActor( } if (tasks.isEmpty && scheduler.activeTaskSets.nonEmpty) { // Try to reviveOffer after 1 second, because scheduler may wait for locality timeout - context.system.scheduler.scheduleOnce(1000 millis, self, ReviveOffers) + reviveThread.schedule(new Runnable { + override def run(): Unit = Utils.tryLogNonFatalError { + Option(self).foreach(_.send(ReviveOffers)) + } + }, 1000, TimeUnit.MILLISECONDS) } } + + override def onStop(): Unit = { + reviveThread.shutdownNow() + } } /** @@ -101,31 +108,30 @@ private[spark] class LocalBackend(scheduler: TaskSchedulerImpl, val totalCores: extends SchedulerBackend with ExecutorBackend { private val appId = "local-" + System.currentTimeMillis - var localActor: ActorRef = null + var localEndpoint: RpcEndpointRef = null override def start() { - localActor = SparkEnv.get.actorSystem.actorOf( - Props(new LocalActor(scheduler, this, totalCores)), - "LocalBackendActor") + localEndpoint = SparkEnv.get.rpcEnv.setupEndpoint( + "LocalBackendEndpoint", new LocalEndpoint(SparkEnv.get.rpcEnv, scheduler, this, totalCores)) } override def stop() { - localActor ! StopExecutor + localEndpoint.send(StopExecutor) } override def reviveOffers() { - localActor ! ReviveOffers + localEndpoint.send(ReviveOffers) } override def defaultParallelism(): Int = scheduler.conf.getInt("spark.default.parallelism", totalCores) override def killTask(taskId: Long, executorId: String, interruptThread: Boolean) { - localActor ! KillTask(taskId, interruptThread) + localEndpoint.send(KillTask(taskId, interruptThread)) } override def statusUpdate(taskId: Long, state: TaskState, serializedData: ByteBuffer) { - localActor ! StatusUpdate(taskId, state, serializedData) + localEndpoint.send(StatusUpdate(taskId, state, serializedData)) } override def applicationId(): String = appId diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala index f83bcaa5cc09e..579fb6624e692 100644 --- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala @@ -49,10 +49,20 @@ class KryoSerializer(conf: SparkConf) with Logging with Serializable { - private val bufferSize = - (conf.getDouble("spark.kryoserializer.buffer.mb", 0.064) * 1024 * 1024).toInt + private val bufferSizeMb = conf.getDouble("spark.kryoserializer.buffer.mb", 0.064) + if (bufferSizeMb >= 2048) { + throw new IllegalArgumentException("spark.kryoserializer.buffer.mb must be less than " + + s"2048 mb, got: + $bufferSizeMb mb.") + } + private val bufferSize = (bufferSizeMb * 1024 * 1024).toInt + + val maxBufferSizeMb = conf.getInt("spark.kryoserializer.buffer.max.mb", 64) + if (maxBufferSizeMb >= 2048) { + throw new IllegalArgumentException("spark.kryoserializer.buffer.max.mb must be less than " + + s"2048 mb, got: + $maxBufferSizeMb mb.") + } + private val maxBufferSize = maxBufferSizeMb * 1024 * 1024 - private val maxBufferSize = conf.getInt("spark.kryoserializer.buffer.max.mb", 64) * 1024 * 1024 private val referenceTracking = conf.getBoolean("spark.kryo.referenceTracking", true) private val registrationRequired = conf.getBoolean("spark.kryo.registrationRequired", false) private val userRegistrator = conf.getOption("spark.kryo.registrator") diff --git a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala index d0178dfde6935..5be3ed771e534 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala @@ -67,7 +67,7 @@ private[spark] trait ShuffleWriterGroup { // org.apache.spark.network.shuffle.StandaloneShuffleBlockManager#getHashBasedShuffleBlockData(). private[spark] class FileShuffleBlockManager(conf: SparkConf) - extends ShuffleBlockManager with Logging { + extends ShuffleBlockResolver with Logging { private val transportConf = SparkTransportConf.fromSparkConf(conf) @@ -175,11 +175,6 @@ class FileShuffleBlockManager(conf: SparkConf) } } - override def getBytes(blockId: ShuffleBlockId): Option[ByteBuffer] = { - val segment = getBlockData(blockId) - Some(segment.nioByteBuffer()) - } - override def getBlockData(blockId: ShuffleBlockId): ManagedBuffer = { if (consolidateShuffleFiles) { // Search all file groups associated with this shuffle. diff --git a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockManager.scala b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockManager.scala index 87fd161e06c85..a1741e2875c16 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockManager.scala @@ -26,6 +26,9 @@ import org.apache.spark.{SparkConf, SparkEnv} import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer} import org.apache.spark.network.netty.SparkTransportConf import org.apache.spark.storage._ +import org.apache.spark.util.Utils + +import IndexShuffleBlockManager.NOOP_REDUCE_ID /** * Create and maintain the shuffle blocks' mapping between logic block and physical file location. @@ -39,25 +42,18 @@ import org.apache.spark.storage._ // Note: Changes to the format in this file should be kept in sync with // org.apache.spark.network.shuffle.StandaloneShuffleBlockManager#getSortBasedShuffleBlockData(). private[spark] -class IndexShuffleBlockManager(conf: SparkConf) extends ShuffleBlockManager { +class IndexShuffleBlockManager(conf: SparkConf) extends ShuffleBlockResolver { private lazy val blockManager = SparkEnv.get.blockManager private val transportConf = SparkTransportConf.fromSparkConf(conf) - /** - * Mapping to a single shuffleBlockId with reduce ID 0. - * */ - def consolidateId(shuffleId: Int, mapId: Int): ShuffleBlockId = { - ShuffleBlockId(shuffleId, mapId, 0) - } - def getDataFile(shuffleId: Int, mapId: Int): File = { - blockManager.diskBlockManager.getFile(ShuffleDataBlockId(shuffleId, mapId, 0)) + blockManager.diskBlockManager.getFile(ShuffleDataBlockId(shuffleId, mapId, NOOP_REDUCE_ID)) } private def getIndexFile(shuffleId: Int, mapId: Int): File = { - blockManager.diskBlockManager.getFile(ShuffleIndexBlockId(shuffleId, mapId, 0)) + blockManager.diskBlockManager.getFile(ShuffleIndexBlockId(shuffleId, mapId, NOOP_REDUCE_ID)) } /** @@ -83,24 +79,19 @@ class IndexShuffleBlockManager(conf: SparkConf) extends ShuffleBlockManager { def writeIndexFile(shuffleId: Int, mapId: Int, lengths: Array[Long]): Unit = { val indexFile = getIndexFile(shuffleId, mapId) val out = new DataOutputStream(new BufferedOutputStream(new FileOutputStream(indexFile))) - try { + Utils.tryWithSafeFinally { // We take in lengths of each block, need to convert it to offsets. var offset = 0L out.writeLong(offset) - for (length <- lengths) { offset += length out.writeLong(offset) } - } finally { + } { out.close() } } - override def getBytes(blockId: ShuffleBlockId): Option[ByteBuffer] = { - Some(getBlockData(blockId).nioByteBuffer()) - } - override def getBlockData(blockId: ShuffleBlockId): ManagedBuffer = { // The block is actually going to be a range of a single map output file for this map, so // find out the consolidated file, then the offset within that from our index @@ -123,3 +114,11 @@ class IndexShuffleBlockManager(conf: SparkConf) extends ShuffleBlockManager { override def stop(): Unit = {} } + +private[spark] object IndexShuffleBlockManager { + // No-op reduce ID used in interactions with disk store and BlockObjectWriter. + // The disk store currently expects puts to relate to a (map, reduce) pair, but in the sort + // shuffle outputs for several reduces are glommed into a single file. + // TODO: Avoid this entirely by having the DiskBlockObjectWriter not require a BlockId. + val NOOP_REDUCE_ID = 0 +} diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockManager.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockResolver.scala similarity index 68% rename from core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockManager.scala rename to core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockResolver.scala index b521f0c7fc77e..4342b0d598b16 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockResolver.scala @@ -22,15 +22,19 @@ import org.apache.spark.network.buffer.ManagedBuffer import org.apache.spark.storage.ShuffleBlockId private[spark] -trait ShuffleBlockManager { +/** + * Implementers of this trait understand how to retrieve block data for a logical shuffle block + * identifier (i.e. map, reduce, and shuffle). Implementations may use files or file segments to + * encapsulate shuffle data. This is used by the BlockStore to abstract over different shuffle + * implementations when shuffle data is retrieved. + */ +trait ShuffleBlockResolver { type ShuffleId = Int /** - * Get shuffle block data managed by the local ShuffleBlockManager. - * @return Some(ByteBuffer) if block found, otherwise None. + * Retrieve the data for the specified block. If the data for that block is not available, + * throws an unspecified exception. */ - def getBytes(blockId: ShuffleBlockId): Option[ByteBuffer] - def getBlockData(blockId: ShuffleBlockId): ManagedBuffer def stop(): Unit diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleManager.scala index a44a8e1249256..978366d1a1d1b 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/ShuffleManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleManager.scala @@ -55,7 +55,10 @@ private[spark] trait ShuffleManager { */ def unregisterShuffle(shuffleId: Int): Boolean - def shuffleBlockManager: ShuffleBlockManager + /** + * Return a resolver capable of retrieving shuffle block data based on block coordinates. + */ + def shuffleBlockResolver: ShuffleBlockResolver /** Shut down this ShuffleManager. */ def stop(): Unit diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleWriter.scala index b934480cfb9be..f6e6fe5defe09 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/ShuffleWriter.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleWriter.scala @@ -23,7 +23,7 @@ import org.apache.spark.scheduler.MapStatus * Obtained inside a map task to write out records to the shuffle system. */ private[spark] trait ShuffleWriter[K, V] { - /** Write a bunch of records to this task's output */ + /** Write a sequence of records to this task's output */ def write(records: Iterator[_ <: Product2[K, V]]): Unit /** Close this writer, passing along whether the map completed */ diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleManager.scala index 62e0629b34400..2a7df8dd5bd83 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleManager.scala @@ -53,20 +53,20 @@ private[spark] class HashShuffleManager(conf: SparkConf) extends ShuffleManager override def getWriter[K, V](handle: ShuffleHandle, mapId: Int, context: TaskContext) : ShuffleWriter[K, V] = { new HashShuffleWriter( - shuffleBlockManager, handle.asInstanceOf[BaseShuffleHandle[K, V, _]], mapId, context) + shuffleBlockResolver, handle.asInstanceOf[BaseShuffleHandle[K, V, _]], mapId, context) } /** Remove a shuffle's metadata from the ShuffleManager. */ override def unregisterShuffle(shuffleId: Int): Boolean = { - shuffleBlockManager.removeShuffle(shuffleId) + shuffleBlockResolver.removeShuffle(shuffleId) } - override def shuffleBlockManager: FileShuffleBlockManager = { + override def shuffleBlockResolver: FileShuffleBlockManager = { fileShuffleBlockManager } /** Shut down this ShuffleManager. */ override def stop(): Unit = { - shuffleBlockManager.stop() + shuffleBlockResolver.stop() } } diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala index bda30a56d808e..0497036192154 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala @@ -58,7 +58,7 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager val baseShuffleHandle = handle.asInstanceOf[BaseShuffleHandle[K, V, _]] shuffleMapNumber.putIfAbsent(baseShuffleHandle.shuffleId, baseShuffleHandle.numMaps) new SortShuffleWriter( - shuffleBlockManager, baseShuffleHandle, mapId, context) + shuffleBlockResolver, baseShuffleHandle, mapId, context) } /** Remove a shuffle's metadata from the ShuffleManager. */ @@ -66,18 +66,19 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager if (shuffleMapNumber.containsKey(shuffleId)) { val numMaps = shuffleMapNumber.remove(shuffleId) (0 until numMaps).map{ mapId => - shuffleBlockManager.removeDataByMap(shuffleId, mapId) + shuffleBlockResolver.removeDataByMap(shuffleId, mapId) } } true } - override def shuffleBlockManager: IndexShuffleBlockManager = { + override def shuffleBlockResolver: IndexShuffleBlockManager = { indexShuffleBlockManager } /** Shut down this ShuffleManager. */ override def stop(): Unit = { - shuffleBlockManager.stop() + shuffleBlockResolver.stop() } } + diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala index 55ea0f17b156a..a066435df6fb0 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala @@ -58,8 +58,7 @@ private[spark] class SortShuffleWriter[K, V, C]( // In this case we pass neither an aggregator nor an ordering to the sorter, because we don't // care whether the keys get sorted in each partition; that will be done on the reduce side // if the operation being run is sortByKey. - sorter = new ExternalSorter[K, V, V]( - None, Some(dep.partitioner), None, dep.serializer) + sorter = new ExternalSorter[K, V, V](None, Some(dep.partitioner), None, dep.serializer) sorter.insertAll(records) } @@ -67,7 +66,7 @@ private[spark] class SortShuffleWriter[K, V, C]( // because it just opens a single file, so is typically too fast to measure accurately // (see SPARK-3570). val outputFile = shuffleBlockManager.getDataFile(dep.shuffleId, mapId) - val blockId = shuffleBlockManager.consolidateId(dep.shuffleId, mapId) + val blockId = ShuffleBlockId(dep.shuffleId, mapId, IndexShuffleBlockManager.NOOP_REDUCE_ID) val partitionLengths = sorter.writePartitionedFile(blockId, context, outputFile) shuffleBlockManager.writeIndexFile(dep.shuffleId, mapId, partitionLengths) @@ -100,3 +99,4 @@ private[spark] class SortShuffleWriter[K, V, C]( } } } + diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index 1dff09a75d038..1aa0ef18de118 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -26,7 +26,6 @@ import scala.concurrent.ExecutionContext.Implicits.global import scala.concurrent.duration._ import scala.util.Random -import akka.actor.{ActorSystem, Props} import sun.nio.ch.DirectBuffer import org.apache.spark._ @@ -37,6 +36,7 @@ import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer} import org.apache.spark.network.netty.SparkTransportConf import org.apache.spark.network.shuffle.ExternalShuffleClient import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo +import org.apache.spark.rpc.RpcEnv import org.apache.spark.serializer.Serializer import org.apache.spark.shuffle.ShuffleManager import org.apache.spark.shuffle.hash.HashShuffleManager @@ -64,7 +64,7 @@ private[spark] class BlockResult( */ private[spark] class BlockManager( executorId: String, - actorSystem: ActorSystem, + rpcEnv: RpcEnv, val master: BlockManagerMaster, defaultSerializer: Serializer, maxMemory: Long, @@ -136,9 +136,9 @@ private[spark] class BlockManager( // Whether to compress shuffle output temporarily spilled to disk private val compressShuffleSpill = conf.getBoolean("spark.shuffle.spill.compress", true) - private val slaveActor = actorSystem.actorOf( - Props(new BlockManagerSlaveActor(this, mapOutputTracker)), - name = "BlockManagerActor" + BlockManager.ID_GENERATOR.next) + private val slaveEndpoint = rpcEnv.setupEndpoint( + "BlockManagerEndpoint" + BlockManager.ID_GENERATOR.next, + new BlockManagerSlaveEndpoint(rpcEnv, this, mapOutputTracker)) // Pending re-registration action being executed asynchronously or null if none is pending. // Accesses should synchronize on asyncReregisterLock. @@ -167,7 +167,7 @@ private[spark] class BlockManager( */ def this( execId: String, - actorSystem: ActorSystem, + rpcEnv: RpcEnv, master: BlockManagerMaster, serializer: Serializer, conf: SparkConf, @@ -176,7 +176,7 @@ private[spark] class BlockManager( blockTransferService: BlockTransferService, securityManager: SecurityManager, numUsableCores: Int) = { - this(execId, actorSystem, master, serializer, BlockManager.getMaxMemory(conf), + this(execId, rpcEnv, master, serializer, BlockManager.getMaxMemory(conf), conf, mapOutputTracker, shuffleManager, blockTransferService, securityManager, numUsableCores) } @@ -186,7 +186,7 @@ private[spark] class BlockManager( * where it is only learned after registration with the TaskScheduler). * * This method initializes the BlockTransferService and ShuffleClient, registers with the - * BlockManagerMaster, starts the BlockManagerWorker actor, and registers with a local shuffle + * BlockManagerMaster, starts the BlockManagerWorker endpoint, and registers with a local shuffle * service if configured. */ def initialize(appId: String): Unit = { @@ -202,7 +202,7 @@ private[spark] class BlockManager( blockManagerId } - master.registerBlockManager(blockManagerId, maxMemory, slaveActor) + master.registerBlockManager(blockManagerId, maxMemory, slaveEndpoint) // Register Executors' configuration with the local shuffle service, if one should exist. if (externalShuffleServiceEnabled && !blockManagerId.isDriver) { @@ -265,7 +265,7 @@ private[spark] class BlockManager( def reregister(): Unit = { // TODO: We might need to rate limit re-registering. logInfo("BlockManager re-registering with master") - master.registerBlockManager(blockManagerId, maxMemory, slaveActor) + master.registerBlockManager(blockManagerId, maxMemory, slaveEndpoint) reportAllBlocks() } @@ -301,7 +301,7 @@ private[spark] class BlockManager( */ override def getBlockData(blockId: BlockId): ManagedBuffer = { if (blockId.isShuffle) { - shuffleManager.shuffleBlockManager.getBlockData(blockId.asInstanceOf[ShuffleBlockId]) + shuffleManager.shuffleBlockResolver.getBlockData(blockId.asInstanceOf[ShuffleBlockId]) } else { val blockBytesOpt = doGetLocal(blockId, asBlockResult = false) .asInstanceOf[Option[ByteBuffer]] @@ -439,14 +439,10 @@ private[spark] class BlockManager( // As an optimization for map output fetches, if the block is for a shuffle, return it // without acquiring a lock; the disk store never deletes (recent) items so this should work if (blockId.isShuffle) { - val shuffleBlockManager = shuffleManager.shuffleBlockManager - shuffleBlockManager.getBytes(blockId.asInstanceOf[ShuffleBlockId]) match { - case Some(bytes) => - Some(bytes) - case None => - throw new BlockException( - blockId, s"Block $blockId not found on disk, though it should be") - } + val shuffleBlockManager = shuffleManager.shuffleBlockResolver + // TODO: This should gracefully handle case where local block is not available. Currently + // downstream code will throw an exception. + Option(shuffleBlockManager.getBlockData(blockId.asInstanceOf[ShuffleBlockId]).nioByteBuffer()) } else { doGetLocal(blockId, asBlockResult = false).asInstanceOf[Option[ByteBuffer]] } @@ -1219,7 +1215,7 @@ private[spark] class BlockManager( shuffleClient.close() } diskBlockManager.stop() - actorSystem.stop(slaveActor) + rpcEnv.stop(slaveEndpoint) blockInfo.clear() memoryStore.clear() diskStore.clear() diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala index a6f1ebf325a7c..69ac37511e730 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala @@ -60,7 +60,10 @@ class BlockManagerId private ( def port: Int = port_ - def isDriver: Boolean = { executorId == SparkContext.DRIVER_IDENTIFIER } + def isDriver: Boolean = { + executorId == SparkContext.DRIVER_IDENTIFIER || + executorId == SparkContext.LEGACY_DRIVER_IDENTIFIER + } override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException { out.writeUTF(executorId_) diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala index 061964826f08b..ceacf043029f3 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala @@ -20,35 +20,31 @@ package org.apache.spark.storage import scala.concurrent.{Await, Future} import scala.concurrent.ExecutionContext.Implicits.global -import akka.actor._ - +import org.apache.spark.rpc.RpcEndpointRef import org.apache.spark.{Logging, SparkConf, SparkException} import org.apache.spark.storage.BlockManagerMessages._ import org.apache.spark.util.AkkaUtils private[spark] class BlockManagerMaster( - var driverActor: ActorRef, + var driverEndpoint: RpcEndpointRef, conf: SparkConf, isDriver: Boolean) extends Logging { - private val AKKA_RETRY_ATTEMPTS: Int = AkkaUtils.numRetries(conf) - private val AKKA_RETRY_INTERVAL_MS: Int = AkkaUtils.retryWaitMs(conf) - - val DRIVER_AKKA_ACTOR_NAME = "BlockManagerMaster" val timeout = AkkaUtils.askTimeout(conf) - /** Remove a dead executor from the driver actor. This is only called on the driver side. */ + /** Remove a dead executor from the driver endpoint. This is only called on the driver side. */ def removeExecutor(execId: String) { tell(RemoveExecutor(execId)) logInfo("Removed " + execId + " successfully in removeExecutor") } /** Register the BlockManager's id with the driver. */ - def registerBlockManager(blockManagerId: BlockManagerId, maxMemSize: Long, slaveActor: ActorRef) { + def registerBlockManager( + blockManagerId: BlockManagerId, maxMemSize: Long, slaveEndpoint: RpcEndpointRef): Unit = { logInfo("Trying to register BlockManager") - tell(RegisterBlockManager(blockManagerId, maxMemSize, slaveActor)) + tell(RegisterBlockManager(blockManagerId, maxMemSize, slaveEndpoint)) logInfo("Registered BlockManager") } @@ -59,7 +55,7 @@ class BlockManagerMaster( memSize: Long, diskSize: Long, tachyonSize: Long): Boolean = { - val res = askDriverWithReply[Boolean]( + val res = driverEndpoint.askWithReply[Boolean]( UpdateBlockInfo(blockManagerId, blockId, storageLevel, memSize, diskSize, tachyonSize)) logDebug(s"Updated info of block $blockId") res @@ -67,12 +63,12 @@ class BlockManagerMaster( /** Get locations of the blockId from the driver */ def getLocations(blockId: BlockId): Seq[BlockManagerId] = { - askDriverWithReply[Seq[BlockManagerId]](GetLocations(blockId)) + driverEndpoint.askWithReply[Seq[BlockManagerId]](GetLocations(blockId)) } /** Get locations of multiple blockIds from the driver */ def getLocations(blockIds: Array[BlockId]): Seq[Seq[BlockManagerId]] = { - askDriverWithReply[Seq[Seq[BlockManagerId]]](GetLocationsMultipleBlockIds(blockIds)) + driverEndpoint.askWithReply[Seq[Seq[BlockManagerId]]](GetLocationsMultipleBlockIds(blockIds)) } /** @@ -85,11 +81,11 @@ class BlockManagerMaster( /** Get ids of other nodes in the cluster from the driver */ def getPeers(blockManagerId: BlockManagerId): Seq[BlockManagerId] = { - askDriverWithReply[Seq[BlockManagerId]](GetPeers(blockManagerId)) + driverEndpoint.askWithReply[Seq[BlockManagerId]](GetPeers(blockManagerId)) } - def getActorSystemHostPortForExecutor(executorId: String): Option[(String, Int)] = { - askDriverWithReply[Option[(String, Int)]](GetActorSystemHostPortForExecutor(executorId)) + def getRpcHostPortForExecutor(executorId: String): Option[(String, Int)] = { + driverEndpoint.askWithReply[Option[(String, Int)]](GetRpcHostPortForExecutor(executorId)) } /** @@ -97,12 +93,12 @@ class BlockManagerMaster( * blocks that the driver knows about. */ def removeBlock(blockId: BlockId) { - askDriverWithReply(RemoveBlock(blockId)) + driverEndpoint.askWithReply[Boolean](RemoveBlock(blockId)) } /** Remove all blocks belonging to the given RDD. */ def removeRdd(rddId: Int, blocking: Boolean) { - val future = askDriverWithReply[Future[Seq[Int]]](RemoveRdd(rddId)) + val future = driverEndpoint.askWithReply[Future[Seq[Int]]](RemoveRdd(rddId)) future.onFailure { case e: Exception => logWarning(s"Failed to remove RDD $rddId - ${e.getMessage}}") @@ -114,7 +110,7 @@ class BlockManagerMaster( /** Remove all blocks belonging to the given shuffle. */ def removeShuffle(shuffleId: Int, blocking: Boolean) { - val future = askDriverWithReply[Future[Seq[Boolean]]](RemoveShuffle(shuffleId)) + val future = driverEndpoint.askWithReply[Future[Seq[Boolean]]](RemoveShuffle(shuffleId)) future.onFailure { case e: Exception => logWarning(s"Failed to remove shuffle $shuffleId - ${e.getMessage}}") @@ -126,7 +122,7 @@ class BlockManagerMaster( /** Remove all blocks belonging to the given broadcast. */ def removeBroadcast(broadcastId: Long, removeFromMaster: Boolean, blocking: Boolean) { - val future = askDriverWithReply[Future[Seq[Int]]]( + val future = driverEndpoint.askWithReply[Future[Seq[Int]]]( RemoveBroadcast(broadcastId, removeFromMaster)) future.onFailure { case e: Exception => @@ -145,11 +141,11 @@ class BlockManagerMaster( * amount of remaining memory. */ def getMemoryStatus: Map[BlockManagerId, (Long, Long)] = { - askDriverWithReply[Map[BlockManagerId, (Long, Long)]](GetMemoryStatus) + driverEndpoint.askWithReply[Map[BlockManagerId, (Long, Long)]](GetMemoryStatus) } def getStorageStatus: Array[StorageStatus] = { - askDriverWithReply[Array[StorageStatus]](GetStorageStatus) + driverEndpoint.askWithReply[Array[StorageStatus]](GetStorageStatus) } /** @@ -165,11 +161,12 @@ class BlockManagerMaster( askSlaves: Boolean = true): Map[BlockManagerId, BlockStatus] = { val msg = GetBlockStatus(blockId, askSlaves) /* - * To avoid potential deadlocks, the use of Futures is necessary, because the master actor + * To avoid potential deadlocks, the use of Futures is necessary, because the master endpoint * should not block on waiting for a block manager, which can in turn be waiting for the - * master actor for a response to a prior message. + * master endpoint for a response to a prior message. */ - val response = askDriverWithReply[Map[BlockManagerId, Future[Option[BlockStatus]]]](msg) + val response = driverEndpoint. + askWithReply[Map[BlockManagerId, Future[Option[BlockStatus]]]](msg) val (blockManagerIds, futures) = response.unzip val result = Await.result(Future.sequence(futures), timeout) if (result == null) { @@ -193,33 +190,28 @@ class BlockManagerMaster( filter: BlockId => Boolean, askSlaves: Boolean): Seq[BlockId] = { val msg = GetMatchingBlockIds(filter, askSlaves) - val future = askDriverWithReply[Future[Seq[BlockId]]](msg) + val future = driverEndpoint.askWithReply[Future[Seq[BlockId]]](msg) Await.result(future, timeout) } - /** Stop the driver actor, called only on the Spark driver node */ + /** Stop the driver endpoint, called only on the Spark driver node */ def stop() { - if (driverActor != null && isDriver) { + if (driverEndpoint != null && isDriver) { tell(StopBlockManagerMaster) - driverActor = null + driverEndpoint = null logInfo("BlockManagerMaster stopped") } } - /** Send a one-way message to the master actor, to which we expect it to reply with true. */ + /** Send a one-way message to the master endpoint, to which we expect it to reply with true. */ private def tell(message: Any) { - if (!askDriverWithReply[Boolean](message)) { - throw new SparkException("BlockManagerMasterActor returned false, expected true.") + if (!driverEndpoint.askWithReply[Boolean](message)) { + throw new SparkException("BlockManagerMasterEndpoint returned false, expected true.") } } - /** - * Send a message to the driver actor and get its result within a default timeout, or - * throw a SparkException if this fails. - */ - private def askDriverWithReply[T](message: Any): T = { - AkkaUtils.askWithReply(message, driverActor, AKKA_RETRY_ATTEMPTS, AKKA_RETRY_INTERVAL_MS, - timeout) - } +} +private[spark] object BlockManagerMaster { + val DRIVER_ENDPOINT_NAME = "BlockManagerMaster" } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala similarity index 83% rename from core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala rename to core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala index 5b5328016124e..28c73a7d543ff 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala @@ -21,25 +21,26 @@ import java.util.{HashMap => JHashMap} import scala.collection.mutable import scala.collection.JavaConversions._ -import scala.concurrent.Future -import scala.concurrent.duration._ +import scala.concurrent.{ExecutionContext, Future} -import akka.actor.{Actor, ActorRef} -import akka.pattern.ask - -import org.apache.spark.{Logging, SparkConf, SparkException} +import org.apache.spark.rpc.{RpcEndpointRef, RpcEnv, RpcCallContext, ThreadSafeRpcEndpoint} +import org.apache.spark.{Logging, SparkConf} import org.apache.spark.annotation.DeveloperApi import org.apache.spark.scheduler._ import org.apache.spark.storage.BlockManagerMessages._ -import org.apache.spark.util.{ActorLogReceive, AkkaUtils, Utils} +import org.apache.spark.util.Utils /** - * BlockManagerMasterActor is an actor on the master node to track statuses of - * all slaves' block managers. + * BlockManagerMasterEndpoint is an [[ThreadSafeRpcEndpoint]] on the master node to track statuses + * of all slaves' block managers. */ private[spark] -class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus: LiveListenerBus) - extends Actor with ActorLogReceive with Logging { +class BlockManagerMasterEndpoint( + override val rpcEnv: RpcEnv, + val isLocal: Boolean, + conf: SparkConf, + listenerBus: LiveListenerBus) + extends ThreadSafeRpcEndpoint with Logging { // Mapping from block manager id to the block manager's information. private val blockManagerInfo = new mutable.HashMap[BlockManagerId, BlockManagerInfo] @@ -50,68 +51,67 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus // Mapping from block id to the set of block managers that have the block. private val blockLocations = new JHashMap[BlockId, mutable.HashSet[BlockManagerId]] - private val akkaTimeout = AkkaUtils.askTimeout(conf) + private val askThreadPool = Utils.newDaemonCachedThreadPool("block-manager-ask-thread-pool") + private implicit val askExecutionContext = ExecutionContext.fromExecutorService(askThreadPool) - override def receiveWithLogging: PartialFunction[Any, Unit] = { - case RegisterBlockManager(blockManagerId, maxMemSize, slaveActor) => - register(blockManagerId, maxMemSize, slaveActor) - sender ! true + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { + case RegisterBlockManager(blockManagerId, maxMemSize, slaveEndpoint) => + register(blockManagerId, maxMemSize, slaveEndpoint) + context.reply(true) case UpdateBlockInfo( blockManagerId, blockId, storageLevel, deserializedSize, size, tachyonSize) => - sender ! updateBlockInfo( - blockManagerId, blockId, storageLevel, deserializedSize, size, tachyonSize) + context.reply(updateBlockInfo( + blockManagerId, blockId, storageLevel, deserializedSize, size, tachyonSize)) case GetLocations(blockId) => - sender ! getLocations(blockId) + context.reply(getLocations(blockId)) case GetLocationsMultipleBlockIds(blockIds) => - sender ! getLocationsMultipleBlockIds(blockIds) + context.reply(getLocationsMultipleBlockIds(blockIds)) case GetPeers(blockManagerId) => - sender ! getPeers(blockManagerId) + context.reply(getPeers(blockManagerId)) - case GetActorSystemHostPortForExecutor(executorId) => - sender ! getActorSystemHostPortForExecutor(executorId) + case GetRpcHostPortForExecutor(executorId) => + context.reply(getRpcHostPortForExecutor(executorId)) case GetMemoryStatus => - sender ! memoryStatus + context.reply(memoryStatus) case GetStorageStatus => - sender ! storageStatus + context.reply(storageStatus) case GetBlockStatus(blockId, askSlaves) => - sender ! blockStatus(blockId, askSlaves) + context.reply(blockStatus(blockId, askSlaves)) case GetMatchingBlockIds(filter, askSlaves) => - sender ! getMatchingBlockIds(filter, askSlaves) + context.reply(getMatchingBlockIds(filter, askSlaves)) case RemoveRdd(rddId) => - sender ! removeRdd(rddId) + context.reply(removeRdd(rddId)) case RemoveShuffle(shuffleId) => - sender ! removeShuffle(shuffleId) + context.reply(removeShuffle(shuffleId)) case RemoveBroadcast(broadcastId, removeFromDriver) => - sender ! removeBroadcast(broadcastId, removeFromDriver) + context.reply(removeBroadcast(broadcastId, removeFromDriver)) case RemoveBlock(blockId) => removeBlockFromWorkers(blockId) - sender ! true + context.reply(true) case RemoveExecutor(execId) => removeExecutor(execId) - sender ! true + context.reply(true) case StopBlockManagerMaster => - sender ! true - context.stop(self) + context.reply(true) + stop() case BlockManagerHeartbeat(blockManagerId) => - sender ! heartbeatReceived(blockManagerId) + context.reply(heartbeatReceived(blockManagerId)) - case other => - logWarning("Got unknown message: " + other) } private def removeRdd(rddId: Int): Future[Seq[Int]] = { @@ -129,22 +129,20 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus // Ask the slaves to remove the RDD, and put the result in a sequence of Futures. // The dispatcher is used as an implicit argument into the Future sequence construction. - import context.dispatcher val removeMsg = RemoveRdd(rddId) Future.sequence( blockManagerInfo.values.map { bm => - bm.slaveActor.ask(removeMsg)(akkaTimeout).mapTo[Int] + bm.slaveEndpoint.sendWithReply[Int](removeMsg) }.toSeq ) } private def removeShuffle(shuffleId: Int): Future[Seq[Boolean]] = { - // Nothing to do in the BlockManagerMasterActor data structures - import context.dispatcher + // Nothing to do in the BlockManagerMasterEndpoint data structures val removeMsg = RemoveShuffle(shuffleId) Future.sequence( blockManagerInfo.values.map { bm => - bm.slaveActor.ask(removeMsg)(akkaTimeout).mapTo[Boolean] + bm.slaveEndpoint.sendWithReply[Boolean](removeMsg) }.toSeq ) } @@ -155,14 +153,13 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus * from the executors, but not from the driver. */ private def removeBroadcast(broadcastId: Long, removeFromDriver: Boolean): Future[Seq[Int]] = { - import context.dispatcher val removeMsg = RemoveBroadcast(broadcastId, removeFromDriver) val requiredBlockManagers = blockManagerInfo.values.filter { info => removeFromDriver || !info.blockManagerId.isDriver } Future.sequence( requiredBlockManagers.map { bm => - bm.slaveActor.ask(removeMsg)(akkaTimeout).mapTo[Int] + bm.slaveEndpoint.sendWithReply[Int](removeMsg) }.toSeq ) } @@ -217,7 +214,7 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus // Remove the block from the slave's BlockManager. // Doesn't actually wait for a confirmation and the message might get lost. // If message loss becomes frequent, we should add retry logic here. - blockManager.get.slaveActor.ask(RemoveBlock(blockId))(akkaTimeout) + blockManager.get.slaveEndpoint.sendWithReply[Boolean](RemoveBlock(blockId)) } } } @@ -247,17 +244,16 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus private def blockStatus( blockId: BlockId, askSlaves: Boolean): Map[BlockManagerId, Future[Option[BlockStatus]]] = { - import context.dispatcher val getBlockStatus = GetBlockStatus(blockId) /* - * Rather than blocking on the block status query, master actor should simply return + * Rather than blocking on the block status query, master endpoint should simply return * Futures to avoid potential deadlocks. This can arise if there exists a block manager - * that is also waiting for this master actor's response to a previous message. + * that is also waiting for this master endpoint's response to a previous message. */ blockManagerInfo.values.map { info => val blockStatusFuture = if (askSlaves) { - info.slaveActor.ask(getBlockStatus)(akkaTimeout).mapTo[Option[BlockStatus]] + info.slaveEndpoint.sendWithReply[Option[BlockStatus]](getBlockStatus) } else { Future { info.getStatus(blockId) } } @@ -276,13 +272,12 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus private def getMatchingBlockIds( filter: BlockId => Boolean, askSlaves: Boolean): Future[Seq[BlockId]] = { - import context.dispatcher val getMatchingBlockIds = GetMatchingBlockIds(filter) Future.sequence( blockManagerInfo.values.map { info => val future = if (askSlaves) { - info.slaveActor.ask(getMatchingBlockIds)(akkaTimeout).mapTo[Seq[BlockId]] + info.slaveEndpoint.sendWithReply[Seq[BlockId]](getMatchingBlockIds) } else { Future { info.blocks.keys.filter(filter).toSeq } } @@ -291,7 +286,7 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus ).map(_.flatten.toSeq) } - private def register(id: BlockManagerId, maxMemSize: Long, slaveActor: ActorRef) { + private def register(id: BlockManagerId, maxMemSize: Long, slaveEndpoint: RpcEndpointRef) { val time = System.currentTimeMillis() if (!blockManagerInfo.contains(id)) { blockManagerIdByExecutor.get(id.executorId) match { @@ -308,7 +303,7 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus blockManagerIdByExecutor(id.executorId) = id blockManagerInfo(id) = new BlockManagerInfo( - id, System.currentTimeMillis(), maxMemSize, slaveActor) + id, System.currentTimeMillis(), maxMemSize, slaveEndpoint) } listenerBus.post(SparkListenerBlockManagerAdded(time, id, maxMemSize)) } @@ -379,19 +374,21 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus } /** - * Returns the hostname and port of an executor's actor system, based on the Akka address of its - * BlockManagerSlaveActor. + * Returns the hostname and port of an executor, based on the [[RpcEnv]] address of its + * [[BlockManagerSlaveEndpoint]]. */ - private def getActorSystemHostPortForExecutor(executorId: String): Option[(String, Int)] = { + private def getRpcHostPortForExecutor(executorId: String): Option[(String, Int)] = { for ( blockManagerId <- blockManagerIdByExecutor.get(executorId); - info <- blockManagerInfo.get(blockManagerId); - host <- info.slaveActor.path.address.host; - port <- info.slaveActor.path.address.port + info <- blockManagerInfo.get(blockManagerId) ) yield { - (host, port) + (info.slaveEndpoint.address.host, info.slaveEndpoint.address.port) } } + + override def onStop(): Unit = { + askThreadPool.shutdownNow() + } } @DeveloperApi @@ -412,7 +409,7 @@ private[spark] class BlockManagerInfo( val blockManagerId: BlockManagerId, timeMs: Long, val maxMem: Long, - val slaveActor: ActorRef) + val slaveEndpoint: RpcEndpointRef) extends Logging { private var _lastSeenMs: Long = timeMs diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala index 48247453edef0..f89d8d7493f7c 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala @@ -19,8 +19,7 @@ package org.apache.spark.storage import java.io.{Externalizable, ObjectInput, ObjectOutput} -import akka.actor.ActorRef - +import org.apache.spark.rpc.RpcEndpointRef import org.apache.spark.util.Utils private[spark] object BlockManagerMessages { @@ -52,7 +51,7 @@ private[spark] object BlockManagerMessages { case class RegisterBlockManager( blockManagerId: BlockManagerId, maxMemSize: Long, - sender: ActorRef) + sender: RpcEndpointRef) extends ToBlockManagerMaster case class UpdateBlockInfo( @@ -92,7 +91,7 @@ private[spark] object BlockManagerMessages { case class GetPeers(blockManagerId: BlockManagerId) extends ToBlockManagerMaster - case class GetActorSystemHostPortForExecutor(executorId: String) extends ToBlockManagerMaster + case class GetRpcHostPortForExecutor(executorId: String) extends ToBlockManagerMaster case class RemoveExecutor(execId: String) extends ToBlockManagerMaster diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveEndpoint.scala similarity index 61% rename from core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala rename to core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveEndpoint.scala index 52fb896c4e21f..8980fa8eb70e2 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveEndpoint.scala @@ -17,41 +17,43 @@ package org.apache.spark.storage -import scala.concurrent.Future - -import akka.actor.{ActorRef, Actor} +import scala.concurrent.{ExecutionContext, Future} +import org.apache.spark.rpc.{RpcEnv, RpcCallContext, RpcEndpoint} +import org.apache.spark.util.Utils import org.apache.spark.{Logging, MapOutputTracker, SparkEnv} import org.apache.spark.storage.BlockManagerMessages._ -import org.apache.spark.util.ActorLogReceive /** - * An actor to take commands from the master to execute options. For example, + * An RpcEndpoint to take commands from the master to execute options. For example, * this is used to remove blocks from the slave's BlockManager. */ private[storage] -class BlockManagerSlaveActor( +class BlockManagerSlaveEndpoint( + override val rpcEnv: RpcEnv, blockManager: BlockManager, mapOutputTracker: MapOutputTracker) - extends Actor with ActorLogReceive with Logging { + extends RpcEndpoint with Logging { - import context.dispatcher + private val asyncThreadPool = + Utils.newDaemonCachedThreadPool("block-manager-slave-async-thread-pool") + private implicit val asyncExecutionContext = ExecutionContext.fromExecutorService(asyncThreadPool) // Operations that involve removing blocks may be slow and should be done asynchronously - override def receiveWithLogging: PartialFunction[Any, Unit] = { + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { case RemoveBlock(blockId) => - doAsync[Boolean]("removing block " + blockId, sender) { + doAsync[Boolean]("removing block " + blockId, context) { blockManager.removeBlock(blockId) true } case RemoveRdd(rddId) => - doAsync[Int]("removing RDD " + rddId, sender) { + doAsync[Int]("removing RDD " + rddId, context) { blockManager.removeRdd(rddId) } case RemoveShuffle(shuffleId) => - doAsync[Boolean]("removing shuffle " + shuffleId, sender) { + doAsync[Boolean]("removing shuffle " + shuffleId, context) { if (mapOutputTracker != null) { mapOutputTracker.unregisterShuffle(shuffleId) } @@ -59,30 +61,34 @@ class BlockManagerSlaveActor( } case RemoveBroadcast(broadcastId, _) => - doAsync[Int]("removing broadcast " + broadcastId, sender) { + doAsync[Int]("removing broadcast " + broadcastId, context) { blockManager.removeBroadcast(broadcastId, tellMaster = true) } case GetBlockStatus(blockId, _) => - sender ! blockManager.getStatus(blockId) + context.reply(blockManager.getStatus(blockId)) case GetMatchingBlockIds(filter, _) => - sender ! blockManager.getMatchingBlockIds(filter) + context.reply(blockManager.getMatchingBlockIds(filter)) } - private def doAsync[T](actionMessage: String, responseActor: ActorRef)(body: => T) { + private def doAsync[T](actionMessage: String, context: RpcCallContext)(body: => T) { val future = Future { logDebug(actionMessage) body } future.onSuccess { case response => logDebug("Done " + actionMessage + ", response is " + response) - responseActor ! response - logDebug("Sent response: " + response + " to " + responseActor) + context.reply(response) + logDebug("Sent response: " + response + " to " + context.sender) } future.onFailure { case t: Throwable => logError("Error in " + actionMessage, t) - responseActor ! null.asInstanceOf[T] + context.sendFailure(t) } } + + override def onStop(): Unit = { + asyncThreadPool.shutdownNow() + } } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala b/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala index f703e50b6b0ac..0dfc91dfaff85 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala @@ -23,6 +23,7 @@ import java.nio.channels.FileChannel import org.apache.spark.Logging import org.apache.spark.serializer.{SerializationStream, Serializer} import org.apache.spark.executor.ShuffleWriteMetrics +import org.apache.spark.util.Utils /** * An interface for writing JVM objects to some underlying storage. This interface allows @@ -140,14 +141,17 @@ private[spark] class DiskBlockObjectWriter( override def close() { if (initialized) { - if (syncWrites) { - // Force outstanding writes to disk and track how long it takes - objOut.flush() - callWithTiming { - fos.getFD.sync() + Utils.tryWithSafeFinally { + if (syncWrites) { + // Force outstanding writes to disk and track how long it takes + objOut.flush() + callWithTiming { + fos.getFD.sync() + } } + } { + objOut.close() } - objOut.close() channel = null bs = null diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala index 12cd8ea3bdf1f..2883137872600 100644 --- a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala @@ -47,6 +47,8 @@ private[spark] class DiskBlockManager(blockManager: BlockManager, conf: SparkCon logError("Failed to create any local dir.") System.exit(ExecutorExitCode.DISK_STORE_FAILED_TO_CREATE_DIR) } + // The content of subDirs is immutable but the content of subDirs(i) is mutable. And the content + // of subDirs(i) is protected by the lock of subDirs(i) private val subDirs = Array.fill(localDirs.length)(new Array[File](subDirsPerLocalDir)) private val shutdownHook = addShutdownHook() @@ -61,20 +63,17 @@ private[spark] class DiskBlockManager(blockManager: BlockManager, conf: SparkCon val subDirId = (hash / localDirs.length) % subDirsPerLocalDir // Create the subdirectory if it doesn't already exist - var subDir = subDirs(dirId)(subDirId) - if (subDir == null) { - subDir = subDirs(dirId).synchronized { - val old = subDirs(dirId)(subDirId) - if (old != null) { - old - } else { - val newDir = new File(localDirs(dirId), "%02x".format(subDirId)) - if (!newDir.exists() && !newDir.mkdir()) { - throw new IOException(s"Failed to create local dir in $newDir.") - } - subDirs(dirId)(subDirId) = newDir - newDir + val subDir = subDirs(dirId).synchronized { + val old = subDirs(dirId)(subDirId) + if (old != null) { + old + } else { + val newDir = new File(localDirs(dirId), "%02x".format(subDirId)) + if (!newDir.exists() && !newDir.mkdir()) { + throw new IOException(s"Failed to create local dir in $newDir.") } + subDirs(dirId)(subDirId) = newDir + newDir } } @@ -91,7 +90,12 @@ private[spark] class DiskBlockManager(blockManager: BlockManager, conf: SparkCon /** List all the files currently stored on disk by the disk manager. */ def getAllFiles(): Seq[File] = { // Get all the files inside the array of array of directories - subDirs.flatten.filter(_ != null).flatMap { dir => + subDirs.flatMap { dir => + dir.synchronized { + // Copy the content of dir because it may be modified in other threads + dir.clone() + } + }.filter(_ != null).flatMap { dir => val files = dir.listFiles() if (files != null) files else Seq.empty } diff --git a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala index 61ef5ff168791..4b232ae7d3180 100644 --- a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala @@ -46,10 +46,13 @@ private[spark] class DiskStore(blockManager: BlockManager, diskManager: DiskBloc val startTime = System.currentTimeMillis val file = diskManager.getFile(blockId) val channel = new FileOutputStream(file).getChannel - while (bytes.remaining > 0) { - channel.write(bytes) + Utils.tryWithSafeFinally { + while (bytes.remaining > 0) { + channel.write(bytes) + } + } { + channel.close() } - channel.close() val finishTime = System.currentTimeMillis logDebug("Block %s stored as %s file on disk in %d ms".format( file.getName, Utils.bytesToString(bytes.limit), finishTime - startTime)) @@ -75,9 +78,9 @@ private[spark] class DiskStore(blockManager: BlockManager, diskManager: DiskBloc val file = diskManager.getFile(blockId) val outputStream = new FileOutputStream(file) try { - try { + Utils.tryWithSafeFinally { blockManager.dataSerializeStream(blockId, outputStream, values) - } finally { + } { // Close outputStream here because it should be closed before file is deleted. outputStream.close() } @@ -106,8 +109,7 @@ private[spark] class DiskStore(blockManager: BlockManager, diskManager: DiskBloc private def getBytes(file: File, offset: Long, length: Long): Option[ByteBuffer] = { val channel = new RandomAccessFile(file, "r").getChannel - - try { + Utils.tryWithSafeFinally { // For small files, directly read rather than memory map if (length < minMemoryMapBytes) { val buf = ByteBuffer.allocate(length.toInt) @@ -123,7 +125,7 @@ private[spark] class DiskStore(blockManager: BlockManager, diskManager: DiskBloc } else { Some(channel.map(MapMode.READ_ONLY, offset, length)) } - } finally { + } { channel.close() } } diff --git a/core/src/main/scala/org/apache/spark/storage/RDDInfo.scala b/core/src/main/scala/org/apache/spark/storage/RDDInfo.scala index 0186eb30a1905..034525b56f59c 100644 --- a/core/src/main/scala/org/apache/spark/storage/RDDInfo.scala +++ b/core/src/main/scala/org/apache/spark/storage/RDDInfo.scala @@ -52,6 +52,6 @@ class RDDInfo( private[spark] object RDDInfo { def fromRdd(rdd: RDD[_]): RDDInfo = { val rddName = Option(rdd.name).getOrElse(rdd.id.toString) - new RDDInfo(rdd.id, rddName, rdd.partitions.size, rdd.getStorageLevel) + new RDDInfo(rdd.id, rddName, rdd.partitions.length, rdd.getStorageLevel) } } diff --git a/core/src/main/scala/org/apache/spark/ui/ConsoleProgressBar.scala b/core/src/main/scala/org/apache/spark/ui/ConsoleProgressBar.scala index 67f572e79314d..77c0bc8b5360a 100644 --- a/core/src/main/scala/org/apache/spark/ui/ConsoleProgressBar.scala +++ b/core/src/main/scala/org/apache/spark/ui/ConsoleProgressBar.scala @@ -65,7 +65,7 @@ private[spark] class ConsoleProgressBar(sc: SparkContext) extends Logging { val stageIds = sc.statusTracker.getActiveStageIds() val stages = stageIds.map(sc.statusTracker.getStageInfo).flatten.filter(_.numTasks() > 1) .filter(now - _.submissionTime() > FIRST_DELAY).sortBy(_.stageId()) - if (stages.size > 0) { + if (stages.length > 0) { show(now, stages.take(3)) // display at most 3 stages in same time } } @@ -81,7 +81,7 @@ private[spark] class ConsoleProgressBar(sc: SparkContext) extends Logging { val total = s.numTasks() val header = s"[Stage ${s.stageId()}:" val tailer = s"(${s.numCompletedTasks()} + ${s.numActiveTasks()}) / $total]" - val w = width - header.size - tailer.size + val w = width - header.length - tailer.length val bar = if (w > 0) { val percent = w * s.numCompletedTasks() / total (0 until w).map { i => diff --git a/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala b/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala index 48a6ede05e17b..6c2c5261306e7 100644 --- a/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala +++ b/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala @@ -179,7 +179,7 @@ private[spark] object AkkaUtils extends Logging { message: Any, actor: ActorRef, maxAttempts: Int, - retryInterval: Int, + retryInterval: Long, timeout: FiniteDuration): T = { // TODO: Consider removing multiple attempts if (actor == null) { diff --git a/core/src/main/scala/org/apache/spark/util/EventLoop.scala b/core/src/main/scala/org/apache/spark/util/EventLoop.scala index b0ed908b84424..e9b2b8d24b476 100644 --- a/core/src/main/scala/org/apache/spark/util/EventLoop.scala +++ b/core/src/main/scala/org/apache/spark/util/EventLoop.scala @@ -76,9 +76,21 @@ private[spark] abstract class EventLoop[E](name: String) extends Logging { def stop(): Unit = { if (stopped.compareAndSet(false, true)) { eventThread.interrupt() - eventThread.join() - // Call onStop after the event thread exits to make sure onReceive happens before onStop - onStop() + var onStopCalled = false + try { + eventThread.join() + // Call onStop after the event thread exits to make sure onReceive happens before onStop + onStopCalled = true + onStop() + } catch { + case ie: InterruptedException => + Thread.currentThread().interrupt() + if (!onStopCalled) { + // ie is thrown from `eventThread.join()`. Otherwise, we should not call `onStop` since + // it's already called. + onStop() + } + } } else { // Keep quiet to allow calling `stop` multiple times. } diff --git a/core/src/main/scala/org/apache/spark/util/RpcUtils.scala b/core/src/main/scala/org/apache/spark/util/RpcUtils.scala new file mode 100644 index 0000000000000..6665b17c3d5df --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/RpcUtils.scala @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util + +import org.apache.spark.{SparkEnv, SparkConf} +import org.apache.spark.rpc.{RpcAddress, RpcEndpointRef, RpcEnv} + +object RpcUtils { + + /** + * Retrieve a [[RpcEndpointRef]] which is located in the driver via its name. + */ + def makeDriverRef(name: String, conf: SparkConf, rpcEnv: RpcEnv): RpcEndpointRef = { + val driverActorSystemName = SparkEnv.driverActorSystemName + val driverHost: String = conf.get("spark.driver.host", "localhost") + val driverPort: Int = conf.getInt("spark.driver.port", 7077) + Utils.checkHost(driverHost, "Expected hostname") + rpcEnv.setupEndpointRef(driverActorSystemName, RpcAddress(driverHost, driverPort), name) + } +} diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 0b5a914e7dbbf..0fdfaf300e95d 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -313,7 +313,7 @@ private[spark] object Utils extends Logging { transferToEnabled: Boolean = false): Long = { var count = 0L - try { + tryWithSafeFinally { if (in.isInstanceOf[FileInputStream] && out.isInstanceOf[FileOutputStream] && transferToEnabled) { // When both streams are File stream, use transferTo to improve copy performance. @@ -353,7 +353,7 @@ private[spark] object Utils extends Logging { } } count - } finally { + } { if (closeStreams) { try { in.close() @@ -1214,6 +1214,54 @@ private[spark] object Utils extends Logging { } } + /** Executes the given block. Log non-fatal errors if any, and only throw fatal errors */ + def tryLogNonFatalError(block: => Unit) { + try { + block + } catch { + case NonFatal(t) => + logError(s"Uncaught exception in thread ${Thread.currentThread().getName}", t) + } + } + + /** + * Execute a block of code, then a finally block, but if exceptions happen in + * the finally block, do not suppress the original exception. + * + * This is primarily an issue with `finally { out.close() }` blocks, where + * close needs to be called to clean up `out`, but if an exception happened + * in `out.write`, it's likely `out` may be corrupted and `out.close` will + * fail as well. This would then suppress the original/likely more meaningful + * exception from the original `out.write` call. + */ + def tryWithSafeFinally[T](block: => T)(finallyBlock: => Unit): T = { + // It would be nice to find a method on Try that did this + var originalThrowable: Throwable = null + try { + block + } catch { + case t: Throwable => + // Purposefully not using NonFatal, because even fatal exceptions + // we don't want to have our finallyBlock suppress + originalThrowable = t + throw originalThrowable + } finally { + try { + finallyBlock + } catch { + case t: Throwable => + if (originalThrowable != null) { + // We could do originalThrowable.addSuppressed(t), but it's + // not available in JDK 1.6. + logWarning(s"Suppressing exception in finally: " + t.getMessage, t) + throw originalThrowable + } else { + throw t + } + } + } + } + /** Default filtering function for finding call sites using `getCallSite`. */ private def coreExclusionFunction(className: String): Boolean = { // A regular expression to match classes of the "core" Spark API that we want to skip when @@ -2055,7 +2103,7 @@ private[spark] object Utils extends Logging { */ def getCurrentUserName(): String = { Option(System.getenv("SPARK_USER")) - .getOrElse(UserGroupInformation.getCurrentUser().getUserName()) + .getOrElse(UserGroupInformation.getCurrentUser().getShortUserName()) } } @@ -2074,7 +2122,7 @@ private[spark] class RedirectThread( override def run() { scala.util.control.Exception.ignoring(classOf[IOException]) { // FIXME: We copy the stream on the level of bytes to avoid encoding problems. - try { + Utils.tryWithSafeFinally { val buf = new Array[Byte](1024) var len = in.read(buf) while (len != -1) { @@ -2082,7 +2130,7 @@ private[spark] class RedirectThread( out.flush() len = in.read(buf) } - } finally { + } { if (propagateEof) { out.close() } diff --git a/core/src/main/scala/org/apache/spark/util/collection/BitSet.scala b/core/src/main/scala/org/apache/spark/util/collection/BitSet.scala index f79e8e0491ea1..41cb8cfe2afa3 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/BitSet.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/BitSet.scala @@ -39,7 +39,7 @@ class BitSet(numBits: Int) extends Serializable { val wordIndex = bitIndex >> 6 // divide by 64 var i = 0 while(i < wordIndex) { words(i) = -1; i += 1 } - if(wordIndex < words.size) { + if(wordIndex < words.length) { // Set the remaining bits (note that the mask could still be zero) val mask = ~(-1L << (bitIndex & 0x3f)) words(wordIndex) |= mask diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala index b962c101c91da..035f3767ff554 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala @@ -664,6 +664,8 @@ private[spark] class ExternalSorter[K, V, C]( } /** + * Exposed for testing purposes. + * * Return an iterator over all the data written to this object, grouped by partition and * aggregated by the requested aggregator. For each partition we then have an iterator over its * contents, and these are expected to be accessed in order (you can't "skip ahead" to one @@ -673,7 +675,7 @@ private[spark] class ExternalSorter[K, V, C]( * For now, we just merge all the spilled files in once pass, but this can be modified to * support hierarchical merging. */ - def partitionedIterator: Iterator[(Int, Iterator[Product2[K, C]])] = { + def partitionedIterator: Iterator[(Int, Iterator[Product2[K, C]])] = { val usingMap = aggregator.isDefined val collection: SizeTrackingPairCollection[(Int, K), C] = if (usingMap) map else buffer if (spills.isEmpty && partitionWriters == null) { @@ -726,25 +728,19 @@ private[spark] class ExternalSorter[K, V, C]( // this simple we spill out the current in-memory collection so that everything is in files. spillToPartitionFiles(if (aggregator.isDefined) map else buffer) partitionWriters.foreach(_.commitAndClose()) - var out: FileOutputStream = null - var in: FileInputStream = null + val out = new FileOutputStream(outputFile, true) val writeStartTime = System.nanoTime - try { - out = new FileOutputStream(outputFile, true) + util.Utils.tryWithSafeFinally { for (i <- 0 until numPartitions) { - in = new FileInputStream(partitionWriters(i).fileSegment().file) - val size = org.apache.spark.util.Utils.copyStream(in, out, false, transferToEnabled) - in.close() - in = null - lengths(i) = size - } - } finally { - if (out != null) { - out.close() - } - if (in != null) { - in.close() + val in = new FileInputStream(partitionWriters(i).fileSegment().file) + util.Utils.tryWithSafeFinally { + lengths(i) = org.apache.spark.util.Utils.copyStream(in, out, false, transferToEnabled) + } { + in.close() + } } + } { + out.close() context.taskMetrics.shuffleWriteMetrics.foreach( _.incShuffleWriteTime(System.nanoTime - writeStartTime)) } @@ -781,7 +777,7 @@ private[spark] class ExternalSorter[K, V, C]( /** * Read a partition file back as an iterator (used in our iterator method) */ - def readPartitionFile(writer: BlockObjectWriter): Iterator[Product2[K, C]] = { + private def readPartitionFile(writer: BlockObjectWriter): Iterator[Product2[K, C]] = { if (writer.isOpen) { writer.commitAndClose() } diff --git a/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala b/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala index abfcee75728dc..3ded1e4af8742 100644 --- a/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark import scala.collection.mutable -import org.scalatest.{FunSuite, PrivateMethodTester} +import org.scalatest.{BeforeAndAfter, FunSuite, PrivateMethodTester} import org.apache.spark.executor.TaskMetrics import org.apache.spark.scheduler._ import org.apache.spark.scheduler.cluster.ExecutorInfo @@ -28,10 +28,20 @@ import org.apache.spark.util.ManualClock /** * Test add and remove behavior of ExecutorAllocationManager. */ -class ExecutorAllocationManagerSuite extends FunSuite with LocalSparkContext { +class ExecutorAllocationManagerSuite extends FunSuite with LocalSparkContext with BeforeAndAfter { import ExecutorAllocationManager._ import ExecutorAllocationManagerSuite._ + private val contexts = new mutable.ListBuffer[SparkContext]() + + before { + contexts.clear() + } + + after { + contexts.foreach(_.stop()) + } + test("verify min/max executors") { val conf = new SparkConf() .setMaster("local") @@ -39,18 +49,19 @@ class ExecutorAllocationManagerSuite extends FunSuite with LocalSparkContext { .set("spark.dynamicAllocation.enabled", "true") .set("spark.dynamicAllocation.testing", "true") val sc0 = new SparkContext(conf) + contexts += sc0 assert(sc0.executorAllocationManager.isDefined) sc0.stop() // Min < 0 val conf1 = conf.clone().set("spark.dynamicAllocation.minExecutors", "-1") - intercept[SparkException] { new SparkContext(conf1) } + intercept[SparkException] { contexts += new SparkContext(conf1) } SparkEnv.get.stop() SparkContext.clearActiveContext() // Max < 0 val conf2 = conf.clone().set("spark.dynamicAllocation.maxExecutors", "-1") - intercept[SparkException] { new SparkContext(conf2) } + intercept[SparkException] { contexts += new SparkContext(conf2) } SparkEnv.get.stop() SparkContext.clearActiveContext() @@ -665,16 +676,6 @@ class ExecutorAllocationManagerSuite extends FunSuite with LocalSparkContext { assert(removeTimes(manager).contains("executor-2")) assert(!removeTimes(manager).contains("executor-1")) } -} - -/** - * Helper methods for testing ExecutorAllocationManager. - * This includes methods to access private methods and fields in ExecutorAllocationManager. - */ -private object ExecutorAllocationManagerSuite extends PrivateMethodTester { - private val schedulerBacklogTimeout = 1L - private val sustainedSchedulerBacklogTimeout = 2L - private val executorIdleTimeout = 3L private def createSparkContext(minExecutors: Int = 1, maxExecutors: Int = 5): SparkContext = { val conf = new SparkConf() @@ -688,9 +689,22 @@ private object ExecutorAllocationManagerSuite extends PrivateMethodTester { sustainedSchedulerBacklogTimeout.toString) .set("spark.dynamicAllocation.executorIdleTimeout", executorIdleTimeout.toString) .set("spark.dynamicAllocation.testing", "true") - new SparkContext(conf) + val sc = new SparkContext(conf) + contexts += sc + sc } +} + +/** + * Helper methods for testing ExecutorAllocationManager. + * This includes methods to access private methods and fields in ExecutorAllocationManager. + */ +private object ExecutorAllocationManagerSuite extends PrivateMethodTester { + private val schedulerBacklogTimeout = 1L + private val sustainedSchedulerBacklogTimeout = 2L + private val executorIdleTimeout = 3L + private def createStageInfo(stageId: Int, numTasks: Int): StageInfo = { new StageInfo(stageId, 0, "name", numTasks, Seq.empty, "no details") } diff --git a/core/src/test/scala/org/apache/spark/FileServerSuite.scala b/core/src/test/scala/org/apache/spark/FileServerSuite.scala index 5fdf6bc2777e3..a69e9b761f9a7 100644 --- a/core/src/test/scala/org/apache/spark/FileServerSuite.scala +++ b/core/src/test/scala/org/apache/spark/FileServerSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark import java.io._ import java.net.URI import java.util.jar.{JarEntry, JarOutputStream} -import javax.net.ssl.SSLHandshakeException +import javax.net.ssl.SSLException import com.google.common.io.ByteStreams import org.apache.commons.io.{FileUtils, IOUtils} @@ -228,7 +228,7 @@ class FileServerSuite extends FunSuite with LocalSparkContext { try { server.initialize() - intercept[SSLHandshakeException] { + intercept[SSLException] { fileTransferTest(server) } } finally { diff --git a/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala b/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala new file mode 100644 index 0000000000000..0fd570e5297d9 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +import scala.concurrent.duration._ +import scala.language.postfixOps + +import org.apache.spark.executor.TaskMetrics +import org.apache.spark.storage.BlockManagerId +import org.scalatest.FunSuite +import org.mockito.Mockito.{mock, spy, verify, when} +import org.mockito.Matchers +import org.mockito.Matchers._ + +import org.apache.spark.scheduler.TaskScheduler +import org.apache.spark.util.RpcUtils +import org.scalatest.concurrent.Eventually._ + +class HeartbeatReceiverSuite extends FunSuite with LocalSparkContext { + + test("HeartbeatReceiver") { + sc = spy(new SparkContext("local[2]", "test")) + val scheduler = mock(classOf[TaskScheduler]) + when(scheduler.executorHeartbeatReceived(any(), any(), any())).thenReturn(true) + when(sc.taskScheduler).thenReturn(scheduler) + + val heartbeatReceiver = new HeartbeatReceiver(sc) + sc.env.rpcEnv.setupEndpoint("heartbeat", heartbeatReceiver).send(TaskSchedulerIsSet) + eventually(timeout(5 seconds), interval(5 millis)) { + assert(heartbeatReceiver.scheduler != null) + } + val receiverRef = RpcUtils.makeDriverRef("heartbeat", sc.conf, sc.env.rpcEnv) + + val metrics = new TaskMetrics + val blockManagerId = BlockManagerId("executor-1", "localhost", 12345) + val response = receiverRef.askWithReply[HeartbeatResponse]( + Heartbeat("executor-1", Array(1L -> metrics), blockManagerId)) + + verify(scheduler).executorHeartbeatReceived( + Matchers.eq("executor-1"), Matchers.eq(Array(1L -> metrics)), Matchers.eq(blockManagerId)) + assert(false === response.reregisterBlockManager) + } + + test("HeartbeatReceiver re-register") { + sc = spy(new SparkContext("local[2]", "test")) + val scheduler = mock(classOf[TaskScheduler]) + when(scheduler.executorHeartbeatReceived(any(), any(), any())).thenReturn(false) + when(sc.taskScheduler).thenReturn(scheduler) + + val heartbeatReceiver = new HeartbeatReceiver(sc) + sc.env.rpcEnv.setupEndpoint("heartbeat", heartbeatReceiver).send(TaskSchedulerIsSet) + eventually(timeout(5 seconds), interval(5 millis)) { + assert(heartbeatReceiver.scheduler != null) + } + val receiverRef = RpcUtils.makeDriverRef("heartbeat", sc.conf, sc.env.rpcEnv) + + val metrics = new TaskMetrics + val blockManagerId = BlockManagerId("executor-1", "localhost", 12345) + val response = receiverRef.askWithReply[HeartbeatResponse]( + Heartbeat("executor-1", Array(1L -> metrics), blockManagerId)) + + verify(scheduler).executorHeartbeatReceived( + Matchers.eq("executor-1"), Matchers.eq(Array(1L -> metrics)), Matchers.eq(blockManagerId)) + assert(true === response.reregisterBlockManager) + } +} diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala index ccfe0678cb1c3..6295d34be5ca9 100644 --- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala @@ -17,34 +17,37 @@ package org.apache.spark -import scala.concurrent.Await - -import akka.actor._ -import akka.testkit.TestActorRef +import org.mockito.Mockito._ +import org.mockito.Matchers.{any, isA} import org.scalatest.FunSuite +import org.apache.spark.rpc.{RpcAddress, RpcEndpointRef, RpcCallContext, RpcEnv} import org.apache.spark.scheduler.{CompressedMapStatus, MapStatus} import org.apache.spark.shuffle.FetchFailedException import org.apache.spark.storage.BlockManagerId -import org.apache.spark.util.AkkaUtils class MapOutputTrackerSuite extends FunSuite { private val conf = new SparkConf + def createRpcEnv(name: String, host: String = "localhost", port: Int = 0, + securityManager: SecurityManager = new SecurityManager(conf)): RpcEnv = { + RpcEnv.create(name, host, port, conf, securityManager) + } + test("master start and stop") { - val actorSystem = ActorSystem("test") + val rpcEnv = createRpcEnv("test") val tracker = new MapOutputTrackerMaster(conf) - tracker.trackerActor = - actorSystem.actorOf(Props(new MapOutputTrackerMasterActor(tracker, conf))) + tracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, + new MapOutputTrackerMasterEndpoint(rpcEnv, tracker, conf)) tracker.stop() - actorSystem.shutdown() + rpcEnv.shutdown() } test("master register shuffle and fetch") { - val actorSystem = ActorSystem("test") + val rpcEnv = createRpcEnv("test") val tracker = new MapOutputTrackerMaster(conf) - tracker.trackerActor = - actorSystem.actorOf(Props(new MapOutputTrackerMasterActor(tracker, conf))) + tracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, + new MapOutputTrackerMasterEndpoint(rpcEnv, tracker, conf)) tracker.registerShuffle(10, 2) assert(tracker.containsShuffle(10)) val size1000 = MapStatus.decompressSize(MapStatus.compressSize(1000L)) @@ -57,13 +60,14 @@ class MapOutputTrackerSuite extends FunSuite { assert(statuses.toSeq === Seq((BlockManagerId("a", "hostA", 1000), size1000), (BlockManagerId("b", "hostB", 1000), size10000))) tracker.stop() - actorSystem.shutdown() + rpcEnv.shutdown() } test("master register and unregister shuffle") { - val actorSystem = ActorSystem("test") + val rpcEnv = createRpcEnv("test") val tracker = new MapOutputTrackerMaster(conf) - tracker.trackerActor = actorSystem.actorOf(Props(new MapOutputTrackerMasterActor(tracker, conf))) + tracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, + new MapOutputTrackerMasterEndpoint(rpcEnv, tracker, conf)) tracker.registerShuffle(10, 2) val compressedSize1000 = MapStatus.compressSize(1000L) val compressedSize10000 = MapStatus.compressSize(10000L) @@ -78,14 +82,14 @@ class MapOutputTrackerSuite extends FunSuite { assert(tracker.getServerStatuses(10, 0).isEmpty) tracker.stop() - actorSystem.shutdown() + rpcEnv.shutdown() } test("master register shuffle and unregister map output and fetch") { - val actorSystem = ActorSystem("test") + val rpcEnv = createRpcEnv("test") val tracker = new MapOutputTrackerMaster(conf) - tracker.trackerActor = - actorSystem.actorOf(Props(new MapOutputTrackerMasterActor(tracker, conf))) + tracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, + new MapOutputTrackerMasterEndpoint(rpcEnv, tracker, conf)) tracker.registerShuffle(10, 2) val compressedSize1000 = MapStatus.compressSize(1000L) val compressedSize10000 = MapStatus.compressSize(10000L) @@ -104,25 +108,21 @@ class MapOutputTrackerSuite extends FunSuite { intercept[FetchFailedException] { tracker.getServerStatuses(10, 1) } tracker.stop() - actorSystem.shutdown() + rpcEnv.shutdown() } test("remote fetch") { val hostname = "localhost" - val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, 0, conf = conf, - securityManager = new SecurityManager(conf)) + val rpcEnv = createRpcEnv("spark", hostname, 0, new SecurityManager(conf)) val masterTracker = new MapOutputTrackerMaster(conf) - masterTracker.trackerActor = actorSystem.actorOf( - Props(new MapOutputTrackerMasterActor(masterTracker, conf)), "MapOutputTracker") + masterTracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, + new MapOutputTrackerMasterEndpoint(rpcEnv, masterTracker, conf)) - val (slaveSystem, _) = AkkaUtils.createActorSystem("spark-slave", hostname, 0, conf = conf, - securityManager = new SecurityManager(conf)) + val slaveRpcEnv = createRpcEnv("spark-slave", hostname, 0, new SecurityManager(conf)) val slaveTracker = new MapOutputTrackerWorker(conf) - val selection = slaveSystem.actorSelection( - AkkaUtils.address(AkkaUtils.protocol(slaveSystem), "spark", "localhost", boundPort, "MapOutputTracker")) - val timeout = AkkaUtils.lookupTimeout(conf) - slaveTracker.trackerActor = Await.result(selection.resolveOne(timeout), timeout) + slaveTracker.trackerEndpoint = + slaveRpcEnv.setupEndpointRef("spark", rpcEnv.address, MapOutputTracker.ENDPOINT_NAME) masterTracker.registerShuffle(10, 1) masterTracker.incrementEpoch() @@ -147,8 +147,8 @@ class MapOutputTrackerSuite extends FunSuite { masterTracker.stop() slaveTracker.stop() - actorSystem.shutdown() - slaveSystem.shutdown() + rpcEnv.shutdown() + slaveRpcEnv.shutdown() } test("remote fetch below akka frame size") { @@ -157,19 +157,24 @@ class MapOutputTrackerSuite extends FunSuite { newConf.set("spark.akka.askTimeout", "1") // Fail fast val masterTracker = new MapOutputTrackerMaster(conf) - val actorSystem = ActorSystem("test") - val actorRef = TestActorRef[MapOutputTrackerMasterActor]( - Props(new MapOutputTrackerMasterActor(masterTracker, newConf)))(actorSystem) - val masterActor = actorRef.underlyingActor + val rpcEnv = createRpcEnv("spark") + val masterEndpoint = new MapOutputTrackerMasterEndpoint(rpcEnv, masterTracker, newConf) + rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, masterEndpoint) // Frame size should be ~123B, and no exception should be thrown masterTracker.registerShuffle(10, 1) masterTracker.registerMapOutput(10, 0, MapStatus( BlockManagerId("88", "mph", 1000), Array.fill[Long](10)(0))) - masterActor.receive(GetMapOutputStatuses(10)) + val sender = mock(classOf[RpcEndpointRef]) + when(sender.address).thenReturn(RpcAddress("localhost", 12345)) + val rpcCallContext = mock(classOf[RpcCallContext]) + when(rpcCallContext.sender).thenReturn(sender) + masterEndpoint.receiveAndReply(rpcCallContext)(GetMapOutputStatuses(10)) + verify(rpcCallContext).reply(any()) + verify(rpcCallContext, never()).sendFailure(any()) // masterTracker.stop() // this throws an exception - actorSystem.shutdown() + rpcEnv.shutdown() } test("remote fetch exceeds akka frame size") { @@ -178,12 +183,11 @@ class MapOutputTrackerSuite extends FunSuite { newConf.set("spark.akka.askTimeout", "1") // Fail fast val masterTracker = new MapOutputTrackerMaster(conf) - val actorSystem = ActorSystem("test") - val actorRef = TestActorRef[MapOutputTrackerMasterActor]( - Props(new MapOutputTrackerMasterActor(masterTracker, newConf)))(actorSystem) - val masterActor = actorRef.underlyingActor + val rpcEnv = createRpcEnv("test") + val masterEndpoint = new MapOutputTrackerMasterEndpoint(rpcEnv, masterTracker, newConf) + rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, masterEndpoint) - // Frame size should be ~1.1MB, and MapOutputTrackerMasterActor should throw exception. + // Frame size should be ~1.1MB, and MapOutputTrackerMasterEndpoint should throw exception. // Note that the size is hand-selected here because map output statuses are compressed before // being sent. masterTracker.registerShuffle(20, 100) @@ -191,9 +195,15 @@ class MapOutputTrackerSuite extends FunSuite { masterTracker.registerMapOutput(20, i, new CompressedMapStatus( BlockManagerId("999", "mps", 1000), Array.fill[Long](4000000)(0))) } - intercept[SparkException] { masterActor.receive(GetMapOutputStatuses(20)) } + val sender = mock(classOf[RpcEndpointRef]) + when(sender.address).thenReturn(RpcAddress("localhost", 12345)) + val rpcCallContext = mock(classOf[RpcCallContext]) + when(rpcCallContext.sender).thenReturn(sender) + masterEndpoint.receiveAndReply(rpcCallContext)(GetMapOutputStatuses(20)) + verify(rpcCallContext, never()).reply(any()) + verify(rpcCallContext).sendFailure(isA(classOf[SparkException])) // masterTracker.stop() // this throws an exception - actorSystem.shutdown() + rpcEnv.shutdown() } } diff --git a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala index b07c4d93db4e6..c7301a30d8b11 100644 --- a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark import java.io.File +import java.util.concurrent.TimeUnit import com.google.common.base.Charsets._ import com.google.common.io.Files @@ -25,9 +26,11 @@ import com.google.common.io.Files import org.scalatest.FunSuite import org.apache.hadoop.io.BytesWritable - import org.apache.spark.util.Utils +import scala.concurrent.Await +import scala.concurrent.duration.Duration + class SparkContextSuite extends FunSuite with LocalSparkContext { test("Only one SparkContext may be active at a time") { @@ -173,4 +176,19 @@ class SparkContextSuite extends FunSuite with LocalSparkContext { sc.stop() } } + + test("Cancelling job group should not cause SparkContext to shutdown (SPARK-6414)") { + try { + sc = new SparkContext(new SparkConf().setAppName("test").setMaster("local")) + val future = sc.parallelize(Seq(0)).foreachAsync(_ => {Thread.sleep(1000L)}) + sc.cancelJobGroup("nonExistGroupId") + Await.ready(future, Duration(2, TimeUnit.SECONDS)) + + // In SPARK-6414, sc.cancelJobGroup will cause NullPointerException and cause + // SparkContext to shutdown, so the following assertion will fail. + assert(sc.parallelize(1 to 10).count() == 10L) + } finally { + sc.stop() + } + } } diff --git a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala index e908ba604ebed..fcae603c7d18e 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala @@ -50,7 +50,7 @@ class FsHistoryProviderSuite extends FunSuite with BeforeAndAfter with Matchers inProgress: Boolean, codec: Option[String] = None): File = { val ip = if (inProgress) EventLoggingListener.IN_PROGRESS else "" - val logUri = EventLoggingListener.getLogPath(testDir.getAbsolutePath, appId) + val logUri = EventLoggingListener.getLogPath(testDir.toURI, appId) val logPath = new URI(logUri).getPath + ip new File(logPath) } diff --git a/core/src/test/scala/org/apache/spark/deploy/worker/WorkerWatcherSuite.scala b/core/src/test/scala/org/apache/spark/deploy/worker/WorkerWatcherSuite.scala index 5e538d6fab2a1..6a6f29dd613cd 100644 --- a/core/src/test/scala/org/apache/spark/deploy/worker/WorkerWatcherSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/worker/WorkerWatcherSuite.scala @@ -17,32 +17,38 @@ package org.apache.spark.deploy.worker -import akka.actor.{ActorSystem, AddressFromURIString, Props} -import akka.testkit.TestActorRef -import akka.remote.DisassociatedEvent +import akka.actor.AddressFromURIString +import org.apache.spark.SparkConf +import org.apache.spark.SecurityManager +import org.apache.spark.rpc.{RpcAddress, RpcEnv} import org.scalatest.FunSuite class WorkerWatcherSuite extends FunSuite { test("WorkerWatcher shuts down on valid disassociation") { - val actorSystem = ActorSystem("test") - val targetWorkerUrl = "akka://1.2.3.4/user/Worker" + val conf = new SparkConf() + val rpcEnv = RpcEnv.create("test", "localhost", 12345, conf, new SecurityManager(conf)) + val targetWorkerUrl = "akka://test@1.2.3.4:1234/user/Worker" val targetWorkerAddress = AddressFromURIString(targetWorkerUrl) - val actorRef = TestActorRef[WorkerWatcher](Props(classOf[WorkerWatcher], targetWorkerUrl))(actorSystem) - val workerWatcher = actorRef.underlyingActor + val workerWatcher = new WorkerWatcher(rpcEnv, targetWorkerUrl) workerWatcher.setTesting(testing = true) - actorRef.underlyingActor.receive(new DisassociatedEvent(null, targetWorkerAddress, false)) - assert(actorRef.underlyingActor.isShutDown) + rpcEnv.setupEndpoint("worker-watcher", workerWatcher) + workerWatcher.onDisconnected( + RpcAddress(targetWorkerAddress.host.get, targetWorkerAddress.port.get)) + assert(workerWatcher.isShutDown) + rpcEnv.shutdown() } test("WorkerWatcher stays alive on invalid disassociation") { - val actorSystem = ActorSystem("test") - val targetWorkerUrl = "akka://1.2.3.4/user/Worker" - val otherAkkaURL = "akka://4.3.2.1/user/OtherActor" + val conf = new SparkConf() + val rpcEnv = RpcEnv.create("test", "localhost", 12345, conf, new SecurityManager(conf)) + val targetWorkerUrl = "akka://test@1.2.3.4:1234/user/Worker" + val otherAkkaURL = "akka://test@4.3.2.1:1234/user/OtherActor" val otherAkkaAddress = AddressFromURIString(otherAkkaURL) - val actorRef = TestActorRef[WorkerWatcher](Props(classOf[WorkerWatcher], targetWorkerUrl))(actorSystem) - val workerWatcher = actorRef.underlyingActor + val workerWatcher = new WorkerWatcher(rpcEnv, targetWorkerUrl) workerWatcher.setTesting(testing = true) - actorRef.underlyingActor.receive(new DisassociatedEvent(null, otherAkkaAddress, false)) - assert(!actorRef.underlyingActor.isShutDown) + rpcEnv.setupEndpoint("worker-watcher", workerWatcher) + workerWatcher.onDisconnected(RpcAddress(otherAkkaAddress.host.get, otherAkkaAddress.port.get)) + assert(!workerWatcher.isShutDown) + rpcEnv.shutdown() } } diff --git a/core/src/test/scala/org/apache/spark/rdd/DoubleRDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/DoubleRDDSuite.scala index 4cd0f97368ca3..97079382c716f 100644 --- a/core/src/test/scala/org/apache/spark/rdd/DoubleRDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/DoubleRDDSuite.scala @@ -235,6 +235,12 @@ class DoubleRDDSuite extends FunSuite with SharedSparkContext { assert(histogramBuckets === expectedHistogramBuckets) } + test("WorksWithDoubleValuesAtMinMax") { + val rdd = sc.parallelize(Seq(1, 1, 1, 2, 3, 3)) + assert(Array(3, 0, 1, 2) === rdd.map(_.toDouble).histogram(4)._2) + assert(Array(3, 1, 2) === rdd.map(_.toDouble).histogram(3)._2) + } + test("WorksWithoutBucketsWithMoreRequestedThanElements") { // Verify the basic case of one bucket and all elements in that bucket works val rdd = sc.parallelize(Seq(1, 2)) @@ -248,7 +254,7 @@ class DoubleRDDSuite extends FunSuite with SharedSparkContext { } test("WorksWithoutBucketsForLargerDatasets") { - // Verify the case of slighly larger datasets + // Verify the case of slightly larger datasets val rdd = sc.parallelize(6 to 99) val (histogramBuckets, histogramResults) = rdd.histogram(8) val expectedHistogramResults = @@ -259,17 +265,27 @@ class DoubleRDDSuite extends FunSuite with SharedSparkContext { assert(histogramBuckets === expectedHistogramBuckets) } - test("WorksWithoutBucketsWithIrrationalBucketEdges") { - // Verify the case of buckets with irrational edges. See #SPARK-2862. + test("WorksWithoutBucketsWithNonIntegralBucketEdges") { + // Verify the case of buckets with nonintegral edges. See #SPARK-2862. val rdd = sc.parallelize(6 to 99) val (histogramBuckets, histogramResults) = rdd.histogram(9) + // Buckets are 6.0, 16.333333333333336, 26.666666666666668, 37.0, 47.333333333333336 ... val expectedHistogramResults = - Array(11, 10, 11, 10, 10, 11, 10, 10, 11) + Array(11, 10, 10, 11, 10, 10, 11, 10, 11) assert(histogramResults === expectedHistogramResults) assert(histogramBuckets(0) === 6.0) assert(histogramBuckets(9) === 99.0) } + test("WorksWithHugeRange") { + val rdd = sc.parallelize(Array(0, 1.0e24, 1.0e30)) + val histogramResults = rdd.histogram(1000000)._2 + assert(histogramResults(0) === 1) + assert(histogramResults(1) === 1) + assert(histogramResults.last === 1) + assert((2 to histogramResults.length - 2).forall(i => histogramResults(i) == 0)) + } + // Test the failure mode with an invalid RDD test("ThrowsExceptionOnInvalidRDDs") { // infinity diff --git a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala new file mode 100644 index 0000000000000..5a734ec5ba5ec --- /dev/null +++ b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala @@ -0,0 +1,548 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.rpc + +import java.util.concurrent.{TimeUnit, CountDownLatch, TimeoutException} + +import scala.collection.mutable +import scala.concurrent.Await +import scala.concurrent.duration._ +import scala.language.postfixOps + +import org.scalatest.{BeforeAndAfterAll, FunSuite} +import org.scalatest.concurrent.Eventually._ + +import org.apache.spark.{SparkException, SparkConf} + +/** + * Common tests for an RpcEnv implementation. + */ +abstract class RpcEnvSuite extends FunSuite with BeforeAndAfterAll { + + var env: RpcEnv = _ + + override def beforeAll(): Unit = { + val conf = new SparkConf() + env = createRpcEnv(conf, "local", 12345) + } + + override def afterAll(): Unit = { + if(env != null) { + env.shutdown() + } + } + + def createRpcEnv(conf: SparkConf, name: String, port: Int): RpcEnv + + test("send a message locally") { + @volatile var message: String = null + val rpcEndpointRef = env.setupEndpoint("send-locally", new RpcEndpoint { + override val rpcEnv = env + + override def receive = { + case msg: String => message = msg + } + }) + rpcEndpointRef.send("hello") + eventually(timeout(5 seconds), interval(10 millis)) { + assert("hello" === message) + } + } + + test("send a message remotely") { + @volatile var message: String = null + // Set up a RpcEndpoint using env + env.setupEndpoint("send-remotely", new RpcEndpoint { + override val rpcEnv = env + + override def receive = { + case msg: String => message = msg + } + }) + + val anotherEnv = createRpcEnv(new SparkConf(), "remote" ,13345) + // Use anotherEnv to find out the RpcEndpointRef + val rpcEndpointRef = anotherEnv.setupEndpointRef("local", env.address, "send-remotely") + try { + rpcEndpointRef.send("hello") + eventually(timeout(5 seconds), interval(10 millis)) { + assert("hello" === message) + } + } finally { + anotherEnv.shutdown() + anotherEnv.awaitTermination() + } + } + + test("send a RpcEndpointRef") { + val endpoint = new RpcEndpoint { + override val rpcEnv = env + + override def receiveAndReply(context: RpcCallContext) = { + case "Hello" => context.reply(self) + case "Echo" => context.reply("Echo") + } + } + val rpcEndpointRef = env.setupEndpoint("send-ref", endpoint) + + val newRpcEndpointRef = rpcEndpointRef.askWithReply[RpcEndpointRef]("Hello") + val reply = newRpcEndpointRef.askWithReply[String]("Echo") + assert("Echo" === reply) + } + + test("ask a message locally") { + val rpcEndpointRef = env.setupEndpoint("ask-locally", new RpcEndpoint { + override val rpcEnv = env + + override def receiveAndReply(context: RpcCallContext) = { + case msg: String => { + context.reply(msg) + } + } + }) + val reply = rpcEndpointRef.askWithReply[String]("hello") + assert("hello" === reply) + } + + test("ask a message remotely") { + env.setupEndpoint("ask-remotely", new RpcEndpoint { + override val rpcEnv = env + + override def receiveAndReply(context: RpcCallContext) = { + case msg: String => { + context.reply(msg) + } + } + }) + + val anotherEnv = createRpcEnv(new SparkConf(), "remote", 13345) + // Use anotherEnv to find out the RpcEndpointRef + val rpcEndpointRef = anotherEnv.setupEndpointRef("local", env.address, "ask-remotely") + try { + val reply = rpcEndpointRef.askWithReply[String]("hello") + assert("hello" === reply) + } finally { + anotherEnv.shutdown() + anotherEnv.awaitTermination() + } + } + + test("ask a message timeout") { + env.setupEndpoint("ask-timeout", new RpcEndpoint { + override val rpcEnv = env + + override def receiveAndReply(context: RpcCallContext) = { + case msg: String => { + Thread.sleep(100) + context.reply(msg) + } + } + }) + + val conf = new SparkConf() + conf.set("spark.akka.retry.wait", "0") + conf.set("spark.akka.num.retries", "1") + val anotherEnv = createRpcEnv(conf, "remote", 13345) + // Use anotherEnv to find out the RpcEndpointRef + val rpcEndpointRef = anotherEnv.setupEndpointRef("local", env.address, "ask-timeout") + try { + val e = intercept[Exception] { + rpcEndpointRef.askWithReply[String]("hello", 1 millis) + } + assert(e.isInstanceOf[TimeoutException] || e.getCause.isInstanceOf[TimeoutException]) + } finally { + anotherEnv.shutdown() + anotherEnv.awaitTermination() + } + } + + test("onStart and onStop") { + val stopLatch = new CountDownLatch(1) + val calledMethods = mutable.ArrayBuffer[String]() + + val endpoint = new RpcEndpoint { + override val rpcEnv = env + + override def onStart(): Unit = { + calledMethods += "start" + } + + override def receive = { + case msg: String => + } + + override def onStop(): Unit = { + calledMethods += "stop" + stopLatch.countDown() + } + } + val rpcEndpointRef = env.setupEndpoint("start-stop-test", endpoint) + env.stop(rpcEndpointRef) + stopLatch.await(10, TimeUnit.SECONDS) + assert(List("start", "stop") === calledMethods) + } + + test("onError: error in onStart") { + @volatile var e: Throwable = null + env.setupEndpoint("onError-onStart", new RpcEndpoint { + override val rpcEnv = env + + override def onStart(): Unit = { + throw new RuntimeException("Oops!") + } + + override def receive = { + case m => + } + + override def onError(cause: Throwable): Unit = { + e = cause + } + }) + + eventually(timeout(5 seconds), interval(10 millis)) { + assert(e.getMessage === "Oops!") + } + } + + test("onError: error in onStop") { + @volatile var e: Throwable = null + val endpointRef = env.setupEndpoint("onError-onStop", new RpcEndpoint { + override val rpcEnv = env + + override def receive = { + case m => + } + + override def onError(cause: Throwable): Unit = { + e = cause + } + + override def onStop(): Unit = { + throw new RuntimeException("Oops!") + } + }) + + env.stop(endpointRef) + + eventually(timeout(5 seconds), interval(10 millis)) { + assert(e.getMessage === "Oops!") + } + } + + test("onError: error in receive") { + @volatile var e: Throwable = null + val endpointRef = env.setupEndpoint("onError-receive", new RpcEndpoint { + override val rpcEnv = env + + override def receive = { + case m => throw new RuntimeException("Oops!") + } + + override def onError(cause: Throwable): Unit = { + e = cause + } + }) + + endpointRef.send("Foo") + + eventually(timeout(5 seconds), interval(10 millis)) { + assert(e.getMessage === "Oops!") + } + } + + test("self: call in onStart") { + @volatile var callSelfSuccessfully = false + + env.setupEndpoint("self-onStart", new RpcEndpoint { + override val rpcEnv = env + + override def onStart(): Unit = { + self + callSelfSuccessfully = true + } + + override def receive = { + case m => + } + }) + + eventually(timeout(5 seconds), interval(10 millis)) { + // Calling `self` in `onStart` is fine + assert(callSelfSuccessfully === true) + } + } + + test("self: call in receive") { + @volatile var callSelfSuccessfully = false + + val endpointRef = env.setupEndpoint("self-receive", new RpcEndpoint { + override val rpcEnv = env + + override def receive = { + case m => { + self + callSelfSuccessfully = true + } + } + }) + + endpointRef.send("Foo") + + eventually(timeout(5 seconds), interval(10 millis)) { + // Calling `self` in `receive` is fine + assert(callSelfSuccessfully === true) + } + } + + test("self: call in onStop") { + @volatile var selfOption: Option[RpcEndpointRef] = null + + val endpointRef = env.setupEndpoint("self-onStop", new RpcEndpoint { + override val rpcEnv = env + + override def receive = { + case m => + } + + override def onStop(): Unit = { + selfOption = Option(self) + } + + override def onError(cause: Throwable): Unit = { + } + }) + + env.stop(endpointRef) + + eventually(timeout(5 seconds), interval(10 millis)) { + // Calling `self` in `onStop` will return null, so selfOption will be None + assert(selfOption == None) + } + } + + test("call receive in sequence") { + // If a RpcEnv implementation breaks the `receive` contract, hope this test can expose it + for(i <- 0 until 100) { + @volatile var result = 0 + val endpointRef = env.setupEndpoint(s"receive-in-sequence-$i", new ThreadSafeRpcEndpoint { + override val rpcEnv = env + + override def receive = { + case m => result += 1 + } + + }) + + (0 until 10) foreach { _ => + new Thread { + override def run() { + (0 until 100) foreach { _ => + endpointRef.send("Hello") + } + } + }.start() + } + + eventually(timeout(5 seconds), interval(5 millis)) { + assert(result == 1000) + } + + env.stop(endpointRef) + } + } + + test("stop(RpcEndpointRef) reentrant") { + @volatile var onStopCount = 0 + val endpointRef = env.setupEndpoint("stop-reentrant", new RpcEndpoint { + override val rpcEnv = env + + override def receive = { + case m => + } + + override def onStop(): Unit = { + onStopCount += 1 + } + }) + + env.stop(endpointRef) + env.stop(endpointRef) + + eventually(timeout(5 seconds), interval(5 millis)) { + // Calling stop twice should only trigger onStop once. + assert(onStopCount == 1) + } + } + + test("sendWithReply") { + val endpointRef = env.setupEndpoint("sendWithReply", new RpcEndpoint { + override val rpcEnv = env + + override def receiveAndReply(context: RpcCallContext) = { + case m => context.reply("ack") + } + }) + + val f = endpointRef.sendWithReply[String]("Hi") + val ack = Await.result(f, 5 seconds) + assert("ack" === ack) + + env.stop(endpointRef) + } + + test("sendWithReply: remotely") { + env.setupEndpoint("sendWithReply-remotely", new RpcEndpoint { + override val rpcEnv = env + + override def receiveAndReply(context: RpcCallContext) = { + case m => context.reply("ack") + } + }) + + val anotherEnv = createRpcEnv(new SparkConf(), "remote", 13345) + // Use anotherEnv to find out the RpcEndpointRef + val rpcEndpointRef = anotherEnv.setupEndpointRef("local", env.address, "sendWithReply-remotely") + try { + val f = rpcEndpointRef.sendWithReply[String]("hello") + val ack = Await.result(f, 5 seconds) + assert("ack" === ack) + } finally { + anotherEnv.shutdown() + anotherEnv.awaitTermination() + } + } + + test("sendWithReply: error") { + val endpointRef = env.setupEndpoint("sendWithReply-error", new RpcEndpoint { + override val rpcEnv = env + + override def receiveAndReply(context: RpcCallContext) = { + case m => context.sendFailure(new SparkException("Oops")) + } + }) + + val f = endpointRef.sendWithReply[String]("Hi") + val e = intercept[SparkException] { + Await.result(f, 5 seconds) + } + assert("Oops" === e.getMessage) + + env.stop(endpointRef) + } + + test("sendWithReply: remotely error") { + env.setupEndpoint("sendWithReply-remotely-error", new RpcEndpoint { + override val rpcEnv = env + + override def receiveAndReply(context: RpcCallContext) = { + case msg: String => context.sendFailure(new SparkException("Oops")) + } + }) + + val anotherEnv = createRpcEnv(new SparkConf(), "remote", 13345) + // Use anotherEnv to find out the RpcEndpointRef + val rpcEndpointRef = anotherEnv.setupEndpointRef( + "local", env.address, "sendWithReply-remotely-error") + try { + val f = rpcEndpointRef.sendWithReply[String]("hello") + val e = intercept[SparkException] { + Await.result(f, 5 seconds) + } + assert("Oops" === e.getMessage) + } finally { + anotherEnv.shutdown() + anotherEnv.awaitTermination() + } + } + + test("network events") { + val events = new mutable.ArrayBuffer[(Any, Any)] with mutable.SynchronizedBuffer[(Any, Any)] + env.setupEndpoint("network-events", new ThreadSafeRpcEndpoint { + override val rpcEnv = env + + override def receive = { + case "hello" => + case m => events += "receive" -> m + } + + override def onConnected(remoteAddress: RpcAddress): Unit = { + events += "onConnected" -> remoteAddress + } + + override def onDisconnected(remoteAddress: RpcAddress): Unit = { + events += "onDisconnected" -> remoteAddress + } + + override def onNetworkError(cause: Throwable, remoteAddress: RpcAddress): Unit = { + events += "onNetworkError" -> remoteAddress + } + + }) + + val anotherEnv = createRpcEnv(new SparkConf(), "remote", 13345) + // Use anotherEnv to find out the RpcEndpointRef + val rpcEndpointRef = anotherEnv.setupEndpointRef( + "local", env.address, "network-events") + val remoteAddress = anotherEnv.address + rpcEndpointRef.send("hello") + eventually(timeout(5 seconds), interval(5 millis)) { + assert(events === List(("onConnected", remoteAddress))) + } + + anotherEnv.shutdown() + anotherEnv.awaitTermination() + eventually(timeout(5 seconds), interval(5 millis)) { + assert(events === List( + ("onConnected", remoteAddress), + ("onNetworkError", remoteAddress), + ("onDisconnected", remoteAddress))) + } + } + + test("sendWithReply: unserializable error") { + env.setupEndpoint("sendWithReply-unserializable-error", new RpcEndpoint { + override val rpcEnv = env + + override def receiveAndReply(context: RpcCallContext) = { + case msg: String => context.sendFailure(new UnserializableException) + } + }) + + val anotherEnv = createRpcEnv(new SparkConf(), "remote", 13345) + // Use anotherEnv to find out the RpcEndpointRef + val rpcEndpointRef = anotherEnv.setupEndpointRef( + "local", env.address, "sendWithReply-unserializable-error") + try { + val f = rpcEndpointRef.sendWithReply[String]("hello") + intercept[TimeoutException] { + Await.result(f, 1 seconds) + } + } finally { + anotherEnv.shutdown() + anotherEnv.awaitTermination() + } + } + +} + +class UnserializableClass + +class UnserializableException extends Exception { + private val unserializableField = new UnserializableClass +} diff --git a/core/src/test/scala/org/apache/spark/rpc/akka/AkkaRpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/akka/AkkaRpcEnvSuite.scala new file mode 100644 index 0000000000000..58214c0637235 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/rpc/akka/AkkaRpcEnvSuite.scala @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.rpc.akka + +import org.apache.spark.rpc._ +import org.apache.spark.{SecurityManager, SparkConf} + +class AkkaRpcEnvSuite extends RpcEnvSuite { + + override def createRpcEnv(conf: SparkConf, name: String, port: Int): RpcEnv = { + new AkkaRpcEnvFactory().create( + RpcEnvConfig(conf, name, "localhost", port, new SecurityManager(conf))) + } + + test("setupEndpointRef: systemName, address, endpointName") { + val ref = env.setupEndpoint("test_endpoint", new RpcEndpoint { + override val rpcEnv = env + + override def receive = { + case _ => + } + }) + val conf = new SparkConf() + val newRpcEnv = new AkkaRpcEnvFactory().create( + RpcEnvConfig(conf, "test", "localhost", 12346, new SecurityManager(conf))) + try { + val newRef = newRpcEnv.setupEndpointRef("local", ref.address, "test_endpoint") + assert("akka.tcp://local@localhost:12345/user/test_endpoint" === + newRef.asInstanceOf[AkkaRpcEndpointRef].actorRef.path.toString) + } finally { + newRpcEnv.shutdown() + } + } + +} diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index 63360a0f189a3..eb759f0807a17 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -783,6 +783,7 @@ class DAGSchedulerSuite extends FunSuiteLike with BeforeAndAfter with LocalSpar assert(scheduler.runningStages.isEmpty) assert(scheduler.shuffleToMapStage.isEmpty) assert(scheduler.waitingStages.isEmpty) + assert(scheduler.outputCommitCoordinator.isEmpty) } // Nothing in this test should break if the task info's fields are null, but diff --git a/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala index 448258a754153..30ee63e78d9d8 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala @@ -61,7 +61,7 @@ class EventLoggingListenerSuite extends FunSuite with LocalSparkContext with Bef test("Verify log file exist") { // Verify logging directory exists val conf = getLoggingConf(testDirPath) - val eventLogger = new EventLoggingListener("test", testDirPath.toUri().toString(), conf) + val eventLogger = new EventLoggingListener("test", testDirPath.toUri(), conf) eventLogger.start() val logPath = new Path(eventLogger.logPath + EventLoggingListener.IN_PROGRESS) @@ -95,7 +95,7 @@ class EventLoggingListenerSuite extends FunSuite with LocalSparkContext with Bef } test("Log overwriting") { - val logUri = EventLoggingListener.getLogPath(testDir.getAbsolutePath, "test") + val logUri = EventLoggingListener.getLogPath(testDir.toURI, "test") val logPath = new URI(logUri).getPath // Create file before writing the event log new FileOutputStream(new File(logPath)).close() @@ -107,16 +107,19 @@ class EventLoggingListenerSuite extends FunSuite with LocalSparkContext with Bef test("Event log name") { // without compression - assert(s"file:/base-dir/app1" === EventLoggingListener.getLogPath("/base-dir", "app1")) + assert(s"file:/base-dir/app1" === EventLoggingListener.getLogPath( + Utils.resolveURI("/base-dir"), "app1")) // with compression assert(s"file:/base-dir/app1.lzf" === - EventLoggingListener.getLogPath("/base-dir", "app1", Some("lzf"))) + EventLoggingListener.getLogPath(Utils.resolveURI("/base-dir"), "app1", Some("lzf"))) // illegal characters in app ID assert(s"file:/base-dir/a-fine-mind_dollar_bills__1" === - EventLoggingListener.getLogPath("/base-dir", "a fine:mind$dollar{bills}.1")) + EventLoggingListener.getLogPath(Utils.resolveURI("/base-dir"), + "a fine:mind$dollar{bills}.1")) // illegal characters in app ID with compression assert(s"file:/base-dir/a-fine-mind_dollar_bills__1.lz4" === - EventLoggingListener.getLogPath("/base-dir", "a fine:mind$dollar{bills}.1", Some("lz4"))) + EventLoggingListener.getLogPath(Utils.resolveURI("/base-dir"), + "a fine:mind$dollar{bills}.1", Some("lz4"))) } /* ----------------- * @@ -137,7 +140,7 @@ class EventLoggingListenerSuite extends FunSuite with LocalSparkContext with Bef val conf = getLoggingConf(testDirPath, compressionCodec) extraConf.foreach { case (k, v) => conf.set(k, v) } val logName = compressionCodec.map("test-" + _).getOrElse("test") - val eventLogger = new EventLoggingListener(logName, testDirPath.toUri().toString(), conf) + val eventLogger = new EventLoggingListener(logName, testDirPath.toUri(), conf) val listenerBus = new LiveListenerBus val applicationStart = SparkListenerApplicationStart("Greatest App (N)ever", None, 125L, "Mickey") @@ -173,12 +176,15 @@ class EventLoggingListenerSuite extends FunSuite with LocalSparkContext with Bef * This runs a simple Spark job and asserts that the expected events are logged when expected. */ private def testApplicationEventLogging(compressionCodec: Option[String] = None) { + // Set defaultFS to something that would cause an exception, to make sure we don't run + // into SPARK-6688. val conf = getLoggingConf(testDirPath, compressionCodec) + .set("spark.hadoop.fs.defaultFS", "unsupported://example.com") val sc = new SparkContext("local-cluster[2,2,512]", "test", conf) assert(sc.eventLogger.isDefined) val eventLogger = sc.eventLogger.get val eventLogPath = eventLogger.logPath - val expectedLogDir = testDir.toURI().toString() + val expectedLogDir = testDir.toURI() assert(eventLogPath === EventLoggingListener.getLogPath( expectedLogDir, sc.applicationId, compressionCodec.map(CompressionCodec.getShortName))) diff --git a/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala index c8c957856247a..cf97707946706 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala @@ -161,6 +161,31 @@ class OutputCommitCoordinatorSuite extends FunSuite with BeforeAndAfter { } assert(tempDir.list().size === 0) } + + test("Only authorized committer failures can clear the authorized committer lock (SPARK-6614)") { + val stage: Int = 1 + val partition: Long = 2 + val authorizedCommitter: Long = 3 + val nonAuthorizedCommitter: Long = 100 + outputCommitCoordinator.stageStart(stage) + assert(outputCommitCoordinator.canCommit(stage, partition, attempt = authorizedCommitter)) + assert(!outputCommitCoordinator.canCommit(stage, partition, attempt = nonAuthorizedCommitter)) + // The non-authorized committer fails + outputCommitCoordinator.taskCompleted( + stage, partition, attempt = nonAuthorizedCommitter, reason = TaskKilled) + // New tasks should still not be able to commit because the authorized committer has not failed + assert( + !outputCommitCoordinator.canCommit(stage, partition, attempt = nonAuthorizedCommitter + 1)) + // The authorized committer now fails, clearing the lock + outputCommitCoordinator.taskCompleted( + stage, partition, attempt = authorizedCommitter, reason = TaskKilled) + // A new task should now be allowed to become the authorized committer + assert( + outputCommitCoordinator.canCommit(stage, partition, attempt = nonAuthorizedCommitter + 2)) + // There can only be one authorized committer + assert( + !outputCommitCoordinator.canCommit(stage, partition, attempt = nonAuthorizedCommitter + 3)) + } } /** diff --git a/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala index 601694f57aad0..6de6d2fec622a 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.scheduler import java.io.{File, PrintWriter} +import java.net.URI import org.json4s.jackson.JsonMethods._ import org.scalatest.{BeforeAndAfter, FunSuite} @@ -145,7 +146,7 @@ class ReplayListenerSuite extends FunSuite with BeforeAndAfter { * log the events. */ private class EventMonster(conf: SparkConf) - extends EventLoggingListener("test", "testdir", conf) { + extends EventLoggingListener("test", new URI("testdir"), conf) { override def start() { } diff --git a/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleManagerSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleManagerSuite.scala index 6790388f96603..b834dc0e735eb 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleManagerSuite.scala @@ -54,7 +54,7 @@ class HashShuffleManagerSuite extends FunSuite with LocalSparkContext { sc = new SparkContext("local", "test", conf) val shuffleBlockManager = - SparkEnv.get.shuffleManager.shuffleBlockManager.asInstanceOf[FileShuffleBlockManager] + SparkEnv.get.shuffleManager.shuffleBlockResolver.asInstanceOf[FileShuffleBlockManager] val shuffle1 = shuffleBlockManager.forMapTask(1, 1, 1, new JavaSerializer(conf), new ShuffleWriteMetrics) diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala index c2903c8597997..b4de90b65d545 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala @@ -22,11 +22,11 @@ import scala.concurrent.duration._ import scala.language.implicitConversions import scala.language.postfixOps -import akka.actor.{ActorSystem, Props} import org.mockito.Mockito.{mock, when} -import org.scalatest.{BeforeAndAfter, FunSuite, Matchers, PrivateMethodTester} +import org.scalatest.{BeforeAndAfter, FunSuite, Matchers} import org.scalatest.concurrent.Eventually._ +import org.apache.spark.rpc.RpcEnv import org.apache.spark.{MapOutputTrackerMaster, SparkConf, SparkContext, SecurityManager} import org.apache.spark.network.BlockTransferService import org.apache.spark.network.nio.NioBlockTransferService @@ -34,13 +34,12 @@ import org.apache.spark.scheduler.LiveListenerBus import org.apache.spark.serializer.KryoSerializer import org.apache.spark.shuffle.hash.HashShuffleManager import org.apache.spark.storage.StorageLevel._ -import org.apache.spark.util.{AkkaUtils, SizeEstimator} /** Testsuite that tests block replication in BlockManager */ class BlockManagerReplicationSuite extends FunSuite with Matchers with BeforeAndAfter { private val conf = new SparkConf(false) - var actorSystem: ActorSystem = null + var rpcEnv: RpcEnv = null var master: BlockManagerMaster = null val securityMgr = new SecurityManager(conf) val mapOutputTracker = new MapOutputTrackerMaster(conf) @@ -61,7 +60,7 @@ class BlockManagerReplicationSuite extends FunSuite with Matchers with BeforeAnd maxMem: Long, name: String = SparkContext.DRIVER_IDENTIFIER): BlockManager = { val transfer = new NioBlockTransferService(conf, securityMgr) - val store = new BlockManager(name, actorSystem, master, serializer, maxMem, conf, + val store = new BlockManager(name, rpcEnv, master, serializer, maxMem, conf, mapOutputTracker, shuffleManager, transfer, securityMgr, 0) store.initialize("app-id") allStores += store @@ -69,12 +68,10 @@ class BlockManagerReplicationSuite extends FunSuite with Matchers with BeforeAnd } before { - val (actorSystem, boundPort) = AkkaUtils.createActorSystem( - "test", "localhost", 0, conf = conf, securityManager = securityMgr) - this.actorSystem = actorSystem + rpcEnv = RpcEnv.create("test", "localhost", 0, conf, securityMgr) conf.set("spark.authenticate", "false") - conf.set("spark.driver.port", boundPort.toString) + conf.set("spark.driver.port", rpcEnv.address.port.toString) conf.set("spark.storage.unrollFraction", "0.4") conf.set("spark.storage.unrollMemoryThreshold", "512") @@ -83,18 +80,17 @@ class BlockManagerReplicationSuite extends FunSuite with Matchers with BeforeAnd // to make cached peers refresh frequently conf.set("spark.storage.cachedPeersTtl", "10") - master = new BlockManagerMaster( - actorSystem.actorOf(Props(new BlockManagerMasterActor(true, conf, new LiveListenerBus))), - conf, true) + master = new BlockManagerMaster(rpcEnv.setupEndpoint("blockmanager", + new BlockManagerMasterEndpoint(rpcEnv, true, conf, new LiveListenerBus)), conf, true) allStores.clear() } after { allStores.foreach { _.stop() } allStores.clear() - actorSystem.shutdown() - actorSystem.awaitTermination() - actorSystem = null + rpcEnv.shutdown() + rpcEnv.awaitTermination() + rpcEnv = null master = null } @@ -262,7 +258,7 @@ class BlockManagerReplicationSuite extends FunSuite with Matchers with BeforeAnd val failableTransfer = mock(classOf[BlockTransferService]) // this wont actually work when(failableTransfer.hostName).thenReturn("some-hostname") when(failableTransfer.port).thenReturn(1000) - val failableStore = new BlockManager("failable-store", actorSystem, master, serializer, + val failableStore = new BlockManager("failable-store", rpcEnv, master, serializer, 10000, conf, mapOutputTracker, shuffleManager, failableTransfer, securityMgr, 0) failableStore.initialize("app-id") allStores += failableStore // so that this gets stopped after test diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index ecd1cba5b5abe..6dc5bc4cb08c4 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -19,24 +19,18 @@ package org.apache.spark.storage import java.nio.{ByteBuffer, MappedByteBuffer} import java.util.Arrays -import java.util.concurrent.TimeUnit import scala.collection.mutable.ArrayBuffer -import scala.concurrent.Await import scala.concurrent.duration._ import scala.language.implicitConversions import scala.language.postfixOps -import akka.actor._ -import akka.pattern.ask -import akka.util.Timeout - import org.mockito.Mockito.{mock, when} - import org.scalatest._ import org.scalatest.concurrent.Eventually._ import org.scalatest.concurrent.Timeouts._ +import org.apache.spark.rpc.RpcEnv import org.apache.spark.{MapOutputTrackerMaster, SparkConf, SparkContext, SecurityManager} import org.apache.spark.executor.DataReadMethod import org.apache.spark.network.nio.NioBlockTransferService @@ -53,7 +47,7 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfterEach private val conf = new SparkConf(false) var store: BlockManager = null var store2: BlockManager = null - var actorSystem: ActorSystem = null + var rpcEnv: RpcEnv = null var master: BlockManagerMaster = null conf.set("spark.authenticate", "false") val securityMgr = new SecurityManager(conf) @@ -72,28 +66,25 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfterEach maxMem: Long, name: String = SparkContext.DRIVER_IDENTIFIER): BlockManager = { val transfer = new NioBlockTransferService(conf, securityMgr) - val manager = new BlockManager(name, actorSystem, master, serializer, maxMem, conf, + val manager = new BlockManager(name, rpcEnv, master, serializer, maxMem, conf, mapOutputTracker, shuffleManager, transfer, securityMgr, 0) manager.initialize("app-id") manager } override def beforeEach(): Unit = { - val (actorSystem, boundPort) = AkkaUtils.createActorSystem( - "test", "localhost", 0, conf = conf, securityManager = securityMgr) - this.actorSystem = actorSystem + rpcEnv = RpcEnv.create("test", "localhost", 0, conf, securityMgr) // Set the arch to 64-bit and compressedOops to true to get a deterministic test-case System.setProperty("os.arch", "amd64") conf.set("os.arch", "amd64") conf.set("spark.test.useCompressedOops", "true") - conf.set("spark.driver.port", boundPort.toString) + conf.set("spark.driver.port", rpcEnv.address.port.toString) conf.set("spark.storage.unrollFraction", "0.4") conf.set("spark.storage.unrollMemoryThreshold", "512") - master = new BlockManagerMaster( - actorSystem.actorOf(Props(new BlockManagerMasterActor(true, conf, new LiveListenerBus))), - conf, true) + master = new BlockManagerMaster(rpcEnv.setupEndpoint("blockmanager", + new BlockManagerMasterEndpoint(rpcEnv, true, conf, new LiveListenerBus)), conf, true) val initialize = PrivateMethod[Unit]('initialize) SizeEstimator invokePrivate initialize() @@ -108,9 +99,9 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfterEach store2.stop() store2 = null } - actorSystem.shutdown() - actorSystem.awaitTermination() - actorSystem = null + rpcEnv.shutdown() + rpcEnv.awaitTermination() + rpcEnv = null master = null } @@ -148,6 +139,12 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfterEach assert(id2_.eq(id1), "Deserialized id2 is not the same object as original id1") } + test("BlockManagerId.isDriver() backwards-compatibility with legacy driver ids (SPARK-6716)") { + assert(BlockManagerId(SparkContext.DRIVER_IDENTIFIER, "XXX", 1).isDriver) + assert(BlockManagerId(SparkContext.LEGACY_DRIVER_IDENTIFIER, "XXX", 1).isDriver) + assert(!BlockManagerId("notADriverIdentifier", "XXX", 1).isDriver) + } + test("master + 1 manager interaction") { store = makeBlockManager(20000) val a1 = new Array[Byte](4000) @@ -357,10 +354,8 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfterEach master.removeExecutor(store.blockManagerId.executorId) assert(master.getLocations("a1").size == 0, "a1 was not removed from master") - implicit val timeout = Timeout(30, TimeUnit.SECONDS) - val reregister = !Await.result( - master.driverActor ? BlockManagerHeartbeat(store.blockManagerId), - timeout.duration).asInstanceOf[Boolean] + val reregister = !master.driverEndpoint.askWithReply[Boolean]( + BlockManagerHeartbeat(store.blockManagerId)) assert(reregister == true) } @@ -785,7 +780,7 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfterEach test("block store put failure") { // Use Java serializer so we can create an unserializable error. val transfer = new NioBlockTransferService(conf, securityMgr) - store = new BlockManager(SparkContext.DRIVER_IDENTIFIER, actorSystem, master, + store = new BlockManager(SparkContext.DRIVER_IDENTIFIER, rpcEnv, master, new JavaSerializer(conf), 1200, conf, mapOutputTracker, shuffleManager, transfer, securityMgr, 0) diff --git a/core/src/test/scala/org/apache/spark/util/AkkaUtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/AkkaUtilsSuite.scala index 6250d50fb7036..bec79fc4dc8f7 100644 --- a/core/src/test/scala/org/apache/spark/util/AkkaUtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/AkkaUtilsSuite.scala @@ -19,14 +19,11 @@ package org.apache.spark.util import java.util.concurrent.TimeoutException -import scala.concurrent.Await -import scala.util.{Failure, Try} - -import akka.actor._ - +import akka.actor.ActorNotFound import org.scalatest.FunSuite import org.apache.spark._ +import org.apache.spark.rpc.RpcEnv import org.apache.spark.scheduler.MapStatus import org.apache.spark.storage.BlockManagerId import org.apache.spark.SSLSampleConfigs._ @@ -39,39 +36,37 @@ class AkkaUtilsSuite extends FunSuite with LocalSparkContext with ResetSystemPro test("remote fetch security bad password") { val conf = new SparkConf + conf.set("spark.rpc", "akka") conf.set("spark.authenticate", "true") conf.set("spark.authenticate.secret", "good") val securityManager = new SecurityManager(conf) val hostname = "localhost" - val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, 0, - conf = conf, securityManager = securityManager) - System.setProperty("spark.hostPort", hostname + ":" + boundPort) + val rpcEnv = RpcEnv.create("spark", hostname, 0, conf, securityManager) + System.setProperty("spark.hostPort", rpcEnv.address.hostPort) assert(securityManager.isAuthenticationEnabled() === true) val masterTracker = new MapOutputTrackerMaster(conf) - masterTracker.trackerActor = actorSystem.actorOf( - Props(new MapOutputTrackerMasterActor(masterTracker, conf)), "MapOutputTracker") + masterTracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, + new MapOutputTrackerMasterEndpoint(rpcEnv, masterTracker, conf)) val badconf = new SparkConf + badconf.set("spark.rpc", "akka") badconf.set("spark.authenticate", "true") badconf.set("spark.authenticate.secret", "bad") val securityManagerBad = new SecurityManager(badconf) assert(securityManagerBad.isAuthenticationEnabled() === true) - val (slaveSystem, _) = AkkaUtils.createActorSystem("spark-slave", hostname, 0, - conf = conf, securityManager = securityManagerBad) + val slaveRpcEnv = RpcEnv.create("spark-slave", hostname, 0, conf, securityManagerBad) val slaveTracker = new MapOutputTrackerWorker(conf) - val selection = slaveSystem.actorSelection( - AkkaUtils.address(AkkaUtils.protocol(slaveSystem), "spark", "localhost", boundPort, "MapOutputTracker")) - val timeout = AkkaUtils.lookupTimeout(conf) intercept[akka.actor.ActorNotFound] { - slaveTracker.trackerActor = Await.result(selection.resolveOne(timeout), timeout) + slaveTracker.trackerEndpoint = + slaveRpcEnv.setupEndpointRef("spark", rpcEnv.address, MapOutputTracker.ENDPOINT_NAME) } - actorSystem.shutdown() - slaveSystem.shutdown() + rpcEnv.shutdown() + slaveRpcEnv.shutdown() } test("remote fetch security off") { @@ -81,28 +76,24 @@ class AkkaUtilsSuite extends FunSuite with LocalSparkContext with ResetSystemPro val securityManager = new SecurityManager(conf) val hostname = "localhost" - val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, 0, - conf = conf, securityManager = securityManager) - System.setProperty("spark.hostPort", hostname + ":" + boundPort) + val rpcEnv = RpcEnv.create("spark", hostname, 0, conf, securityManager) + System.setProperty("spark.hostPort", rpcEnv.address.hostPort) assert(securityManager.isAuthenticationEnabled() === false) val masterTracker = new MapOutputTrackerMaster(conf) - masterTracker.trackerActor = actorSystem.actorOf( - Props(new MapOutputTrackerMasterActor(masterTracker, conf)), "MapOutputTracker") + masterTracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, + new MapOutputTrackerMasterEndpoint(rpcEnv, masterTracker, conf)) val badconf = new SparkConf badconf.set("spark.authenticate", "false") badconf.set("spark.authenticate.secret", "good") val securityManagerBad = new SecurityManager(badconf) - val (slaveSystem, _) = AkkaUtils.createActorSystem("spark-slave", hostname, 0, - conf = badconf, securityManager = securityManagerBad) + val slaveRpcEnv = RpcEnv.create("spark-slave", hostname, 0, badconf, securityManagerBad) val slaveTracker = new MapOutputTrackerWorker(conf) - val selection = slaveSystem.actorSelection( - AkkaUtils.address(AkkaUtils.protocol(slaveSystem), "spark", "localhost", boundPort, "MapOutputTracker")) - val timeout = AkkaUtils.lookupTimeout(conf) - slaveTracker.trackerActor = Await.result(selection.resolveOne(timeout), timeout) + slaveTracker.trackerEndpoint = + slaveRpcEnv.setupEndpointRef("spark", rpcEnv.address, MapOutputTracker.ENDPOINT_NAME) assert(securityManagerBad.isAuthenticationEnabled() === false) @@ -120,8 +111,8 @@ class AkkaUtilsSuite extends FunSuite with LocalSparkContext with ResetSystemPro assert(slaveTracker.getServerStatuses(10, 0).toSeq === Seq((BlockManagerId("a", "hostA", 1000), size1000))) - actorSystem.shutdown() - slaveSystem.shutdown() + rpcEnv.shutdown() + slaveRpcEnv.shutdown() } test("remote fetch security pass") { @@ -131,15 +122,14 @@ class AkkaUtilsSuite extends FunSuite with LocalSparkContext with ResetSystemPro val securityManager = new SecurityManager(conf) val hostname = "localhost" - val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, 0, - conf = conf, securityManager = securityManager) - System.setProperty("spark.hostPort", hostname + ":" + boundPort) + val rpcEnv = RpcEnv.create("spark", hostname, 0, conf, securityManager) + System.setProperty("spark.hostPort", rpcEnv.address.hostPort) assert(securityManager.isAuthenticationEnabled() === true) val masterTracker = new MapOutputTrackerMaster(conf) - masterTracker.trackerActor = actorSystem.actorOf( - Props(new MapOutputTrackerMasterActor(masterTracker, conf)), "MapOutputTracker") + masterTracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, + new MapOutputTrackerMasterEndpoint(rpcEnv, masterTracker, conf)) val goodconf = new SparkConf goodconf.set("spark.authenticate", "true") @@ -148,13 +138,10 @@ class AkkaUtilsSuite extends FunSuite with LocalSparkContext with ResetSystemPro assert(securityManagerGood.isAuthenticationEnabled() === true) - val (slaveSystem, _) = AkkaUtils.createActorSystem("spark-slave", hostname, 0, - conf = goodconf, securityManager = securityManagerGood) + val slaveRpcEnv =RpcEnv.create("spark-slave", hostname, 0, goodconf, securityManagerGood) val slaveTracker = new MapOutputTrackerWorker(conf) - val selection = slaveSystem.actorSelection( - AkkaUtils.address(AkkaUtils.protocol(slaveSystem), "spark", "localhost", boundPort, "MapOutputTracker")) - val timeout = AkkaUtils.lookupTimeout(conf) - slaveTracker.trackerActor = Await.result(selection.resolveOne(timeout), timeout) + slaveTracker.trackerEndpoint = + slaveRpcEnv.setupEndpointRef("spark", rpcEnv.address, MapOutputTracker.ENDPOINT_NAME) masterTracker.registerShuffle(10, 1) masterTracker.incrementEpoch() @@ -170,47 +157,45 @@ class AkkaUtilsSuite extends FunSuite with LocalSparkContext with ResetSystemPro assert(slaveTracker.getServerStatuses(10, 0).toSeq === Seq((BlockManagerId("a", "hostA", 1000), size1000))) - actorSystem.shutdown() - slaveSystem.shutdown() + rpcEnv.shutdown() + slaveRpcEnv.shutdown() } test("remote fetch security off client") { val conf = new SparkConf + conf.set("spark.rpc", "akka") conf.set("spark.authenticate", "true") conf.set("spark.authenticate.secret", "good") val securityManager = new SecurityManager(conf) val hostname = "localhost" - val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, 0, - conf = conf, securityManager = securityManager) - System.setProperty("spark.hostPort", hostname + ":" + boundPort) + val rpcEnv = RpcEnv.create("spark", hostname, 0, conf, securityManager) + System.setProperty("spark.hostPort", rpcEnv.address.hostPort) assert(securityManager.isAuthenticationEnabled() === true) val masterTracker = new MapOutputTrackerMaster(conf) - masterTracker.trackerActor = actorSystem.actorOf( - Props(new MapOutputTrackerMasterActor(masterTracker, conf)), "MapOutputTracker") + masterTracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, + new MapOutputTrackerMasterEndpoint(rpcEnv, masterTracker, conf)) val badconf = new SparkConf + badconf.set("spark.rpc", "akka") badconf.set("spark.authenticate", "false") badconf.set("spark.authenticate.secret", "bad") val securityManagerBad = new SecurityManager(badconf) assert(securityManagerBad.isAuthenticationEnabled() === false) - val (slaveSystem, _) = AkkaUtils.createActorSystem("spark-slave", hostname, 0, - conf = badconf, securityManager = securityManagerBad) + val slaveRpcEnv = RpcEnv.create("spark-slave", hostname, 0, badconf, securityManagerBad) val slaveTracker = new MapOutputTrackerWorker(conf) - val selection = slaveSystem.actorSelection( - AkkaUtils.address(AkkaUtils.protocol(slaveSystem), "spark", "localhost", boundPort, "MapOutputTracker")) - val timeout = AkkaUtils.lookupTimeout(conf) intercept[akka.actor.ActorNotFound] { - slaveTracker.trackerActor = Await.result(selection.resolveOne(timeout), timeout) + slaveTracker.trackerEndpoint = + slaveRpcEnv.setupEndpointRef("spark", rpcEnv.address, MapOutputTracker.ENDPOINT_NAME) } - actorSystem.shutdown() - slaveSystem.shutdown() + rpcEnv.shutdown() + slaveRpcEnv.shutdown() } test("remote fetch ssl on") { @@ -218,26 +203,22 @@ class AkkaUtilsSuite extends FunSuite with LocalSparkContext with ResetSystemPro val securityManager = new SecurityManager(conf) val hostname = "localhost" - val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, 0, - conf = conf, securityManager = securityManager) - System.setProperty("spark.hostPort", hostname + ":" + boundPort) + val rpcEnv = RpcEnv.create("spark", hostname, 0, conf, securityManager) + System.setProperty("spark.hostPort", rpcEnv.address.hostPort) assert(securityManager.isAuthenticationEnabled() === false) val masterTracker = new MapOutputTrackerMaster(conf) - masterTracker.trackerActor = actorSystem.actorOf( - Props(new MapOutputTrackerMasterActor(masterTracker, conf)), "MapOutputTracker") + masterTracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, + new MapOutputTrackerMasterEndpoint(rpcEnv, masterTracker, conf)) val slaveConf = sparkSSLConfig() val securityManagerBad = new SecurityManager(slaveConf) - val (slaveSystem, _) = AkkaUtils.createActorSystem("spark-slave", hostname, 0, - conf = slaveConf, securityManager = securityManagerBad) + val slaveRpcEnv = RpcEnv.create("spark-slaves", hostname, 0, slaveConf, securityManagerBad) val slaveTracker = new MapOutputTrackerWorker(conf) - val selection = slaveSystem.actorSelection( - AkkaUtils.address(AkkaUtils.protocol(slaveSystem), "spark", "localhost", boundPort, "MapOutputTracker")) - val timeout = AkkaUtils.lookupTimeout(conf) - slaveTracker.trackerActor = Await.result(selection.resolveOne(timeout), timeout) + slaveTracker.trackerEndpoint = + slaveRpcEnv.setupEndpointRef("spark", rpcEnv.address, MapOutputTracker.ENDPOINT_NAME) assert(securityManagerBad.isAuthenticationEnabled() === false) @@ -255,8 +236,8 @@ class AkkaUtilsSuite extends FunSuite with LocalSparkContext with ResetSystemPro assert(slaveTracker.getServerStatuses(10, 0).toSeq === Seq((BlockManagerId("a", "hostA", 1000), size1000))) - actorSystem.shutdown() - slaveSystem.shutdown() + rpcEnv.shutdown() + slaveRpcEnv.shutdown() } @@ -267,28 +248,24 @@ class AkkaUtilsSuite extends FunSuite with LocalSparkContext with ResetSystemPro val securityManager = new SecurityManager(conf) val hostname = "localhost" - val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, 0, - conf = conf, securityManager = securityManager) - System.setProperty("spark.hostPort", hostname + ":" + boundPort) + val rpcEnv = RpcEnv.create("spark", hostname, 0, conf, securityManager) + System.setProperty("spark.hostPort", rpcEnv.address.hostPort) assert(securityManager.isAuthenticationEnabled() === true) val masterTracker = new MapOutputTrackerMaster(conf) - masterTracker.trackerActor = actorSystem.actorOf( - Props(new MapOutputTrackerMasterActor(masterTracker, conf)), "MapOutputTracker") + masterTracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, + new MapOutputTrackerMasterEndpoint(rpcEnv, masterTracker, conf)) val slaveConf = sparkSSLConfig() slaveConf.set("spark.authenticate", "true") slaveConf.set("spark.authenticate.secret", "good") val securityManagerBad = new SecurityManager(slaveConf) - val (slaveSystem, _) = AkkaUtils.createActorSystem("spark-slave", hostname, 0, - conf = slaveConf, securityManager = securityManagerBad) + val slaveRpcEnv = RpcEnv.create("spark-slave", hostname, 0, slaveConf, securityManagerBad) val slaveTracker = new MapOutputTrackerWorker(conf) - val selection = slaveSystem.actorSelection( - AkkaUtils.address(AkkaUtils.protocol(slaveSystem), "spark", "localhost", boundPort, "MapOutputTracker")) - val timeout = AkkaUtils.lookupTimeout(conf) - slaveTracker.trackerActor = Await.result(selection.resolveOne(timeout), timeout) + slaveTracker.trackerEndpoint = + slaveRpcEnv.setupEndpointRef("spark", rpcEnv.address, MapOutputTracker.ENDPOINT_NAME) assert(securityManagerBad.isAuthenticationEnabled() === true) @@ -305,45 +282,43 @@ class AkkaUtilsSuite extends FunSuite with LocalSparkContext with ResetSystemPro assert(slaveTracker.getServerStatuses(10, 0).toSeq === Seq((BlockManagerId("a", "hostA", 1000), size1000))) - actorSystem.shutdown() - slaveSystem.shutdown() + rpcEnv.shutdown() + slaveRpcEnv.shutdown() } test("remote fetch ssl on and security enabled - bad credentials") { val conf = sparkSSLConfig() + conf.set("spark.rpc", "akka") conf.set("spark.authenticate", "true") conf.set("spark.authenticate.secret", "good") val securityManager = new SecurityManager(conf) val hostname = "localhost" - val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, 0, - conf = conf, securityManager = securityManager) - System.setProperty("spark.hostPort", hostname + ":" + boundPort) + val rpcEnv = RpcEnv.create("spark", hostname, 0, conf, securityManager) + System.setProperty("spark.hostPort", rpcEnv.address.hostPort) assert(securityManager.isAuthenticationEnabled() === true) val masterTracker = new MapOutputTrackerMaster(conf) - masterTracker.trackerActor = actorSystem.actorOf( - Props(new MapOutputTrackerMasterActor(masterTracker, conf)), "MapOutputTracker") + masterTracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, + new MapOutputTrackerMasterEndpoint(rpcEnv, masterTracker, conf)) val slaveConf = sparkSSLConfig() + slaveConf.set("spark.rpc", "akka") slaveConf.set("spark.authenticate", "true") slaveConf.set("spark.authenticate.secret", "bad") val securityManagerBad = new SecurityManager(slaveConf) - val (slaveSystem, _) = AkkaUtils.createActorSystem("spark-slave", hostname, 0, - conf = slaveConf, securityManager = securityManagerBad) + val slaveRpcEnv = RpcEnv.create("spark-slave", hostname, 0, slaveConf, securityManagerBad) val slaveTracker = new MapOutputTrackerWorker(conf) - val selection = slaveSystem.actorSelection( - AkkaUtils.address(AkkaUtils.protocol(slaveSystem), "spark", "localhost", boundPort, "MapOutputTracker")) - val timeout = AkkaUtils.lookupTimeout(conf) intercept[akka.actor.ActorNotFound] { - slaveTracker.trackerActor = Await.result(selection.resolveOne(timeout), timeout) + slaveTracker.trackerEndpoint = + slaveRpcEnv.setupEndpointRef("spark", rpcEnv.address, MapOutputTracker.ENDPOINT_NAME) } - actorSystem.shutdown() - slaveSystem.shutdown() + rpcEnv.shutdown() + slaveRpcEnv.shutdown() } @@ -352,35 +327,30 @@ class AkkaUtilsSuite extends FunSuite with LocalSparkContext with ResetSystemPro val securityManager = new SecurityManager(conf) val hostname = "localhost" - val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, 0, - conf = conf, securityManager = securityManager) - System.setProperty("spark.hostPort", hostname + ":" + boundPort) + val rpcEnv = RpcEnv.create("spark", hostname, 0, conf, securityManager) + System.setProperty("spark.hostPort", rpcEnv.address.hostPort) assert(securityManager.isAuthenticationEnabled() === false) val masterTracker = new MapOutputTrackerMaster(conf) - masterTracker.trackerActor = actorSystem.actorOf( - Props(new MapOutputTrackerMasterActor(masterTracker, conf)), "MapOutputTracker") + masterTracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, + new MapOutputTrackerMasterEndpoint(rpcEnv, masterTracker, conf)) val slaveConf = sparkSSLConfig() val securityManagerBad = new SecurityManager(slaveConf) - val (slaveSystem, _) = AkkaUtils.createActorSystem("spark-slave", hostname, 0, - conf = slaveConf, securityManager = securityManagerBad) + val slaveRpcEnv = RpcEnv.create("spark-slave", hostname, 0, slaveConf, securityManagerBad) val slaveTracker = new MapOutputTrackerWorker(conf) - val selection = slaveSystem.actorSelection( - AkkaUtils.address(AkkaUtils.protocol(slaveSystem), "spark", "localhost", boundPort, "MapOutputTracker")) - val timeout = AkkaUtils.lookupTimeout(conf) - val result = Try(Await.result(selection.resolveOne(timeout * 2), timeout)) - - result match { - case Failure(ex: ActorNotFound) => - case Failure(ex: TimeoutException) => - case r => fail(s"$r is neither Failure(ActorNotFound) nor Failure(TimeoutException)") + try { + slaveRpcEnv.setupEndpointRef("spark", rpcEnv.address, MapOutputTracker.ENDPOINT_NAME) + fail("should receive either ActorNotFound or TimeoutException") + } catch { + case e: ActorNotFound => + case e: TimeoutException => } - actorSystem.shutdown() - slaveSystem.shutdown() + rpcEnv.shutdown() + slaveRpcEnv.shutdown() } } diff --git a/core/src/test/scala/org/apache/spark/util/EventLoopSuite.scala b/core/src/test/scala/org/apache/spark/util/EventLoopSuite.scala index 1026cb2aa7cae..47b535206c949 100644 --- a/core/src/test/scala/org/apache/spark/util/EventLoopSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/EventLoopSuite.scala @@ -203,4 +203,76 @@ class EventLoopSuite extends FunSuite with Timeouts { assert(!eventLoop.isActive) } } + + test("EventLoop: stop() in onStart should call onStop") { + @volatile var onStopCalled: Boolean = false + val eventLoop = new EventLoop[Int]("test") { + + override def onStart(): Unit = { + stop() + } + + override def onReceive(event: Int): Unit = { + } + + override def onError(e: Throwable): Unit = { + } + + override def onStop(): Unit = { + onStopCalled = true + } + } + eventLoop.start() + eventually(timeout(5 seconds), interval(5 millis)) { + assert(!eventLoop.isActive) + } + assert(onStopCalled) + } + + test("EventLoop: stop() in onReceive should call onStop") { + @volatile var onStopCalled: Boolean = false + val eventLoop = new EventLoop[Int]("test") { + + override def onReceive(event: Int): Unit = { + stop() + } + + override def onError(e: Throwable): Unit = { + } + + override def onStop(): Unit = { + onStopCalled = true + } + } + eventLoop.start() + eventLoop.post(1) + eventually(timeout(5 seconds), interval(5 millis)) { + assert(!eventLoop.isActive) + } + assert(onStopCalled) + } + + test("EventLoop: stop() in onError should call onStop") { + @volatile var onStopCalled: Boolean = false + val eventLoop = new EventLoop[Int]("test") { + + override def onReceive(event: Int): Unit = { + throw new RuntimeException("Oops") + } + + override def onError(e: Throwable): Unit = { + stop() + } + + override def onStop(): Unit = { + onStopCalled = true + } + } + eventLoop.start() + eventLoop.post(1) + eventually(timeout(5 seconds), interval(5 millis)) { + assert(!eventLoop.isActive) + } + assert(onStopCalled) + } } diff --git a/dev/run-tests-jenkins b/dev/run-tests-jenkins index 3a937b637e003..f10aa6b59e1af 100755 --- a/dev/run-tests-jenkins +++ b/dev/run-tests-jenkins @@ -55,13 +55,14 @@ TESTS_TIMEOUT="120m" # format: http://linux.die.net/man/1/timeout # To write a PR test: #+ * the file must reside within the dev/tests directory #+ * be an executable bash script -#+ * accept two arguments on the command line, the first being the Github PR long commit -#+ hash and the second the Github SHA1 hash +#+ * accept three arguments on the command line, the first being the Github PR long commit +#+ hash, the second the Github SHA1 hash, and the final the current PR hash #+ * and, lastly, return string output to be included in the pr message output that will #+ be posted to Github PR_TESTS=( "pr_merge_ability" "pr_public_classes" + "pr_new_dependencies" ) function post_message () { @@ -146,34 +147,38 @@ function send_archived_logs () { fi } +# post start message +{ + start_message="\ + [Test build ${BUILD_DISPLAY_NAME} has started](${BUILD_URL}consoleFull) for \ + PR $ghprbPullId at commit [\`${SHORT_COMMIT_HASH}\`](${COMMIT_URL})." + + post_message "$start_message" +} + # Environment variable to capture PR test output pr_message="" +# Ensure we save off the current HEAD to revert to +current_pr_head="`git rev-parse HEAD`" # Run pull request tests for t in "${PR_TESTS[@]}"; do this_test="${FWDIR}/dev/tests/${t}.sh" - # Ensure the test is a file and is executable - if [ -x "$this_test" ]; then - echo "ghprb: $ghprbActualCommit sha1: $sha1" - this_mssg="`bash \"${this_test}\" \"${ghprbActualCommit}\" \"${sha1}\" 2>/dev/null`" + # Ensure the test can be found and is a file + if [ -f "${this_test}" ]; then + echo "Running test: $t" + this_mssg="$(bash "${this_test}" "${ghprbActualCommit}" "${sha1}" "${current_pr_head}")" # Check if this is the merge test as we submit that note *before* and *after* # the tests run [ "$t" == "pr_merge_ability" ] && merge_note="${this_mssg}" pr_message="${pr_message}\n${this_mssg}" + # Ensure, after each test, that we're back on the current PR + git checkout -f "${current_pr_head}" &>/dev/null + else + echo "Cannot find test ${this_test}." fi done -# post start message -{ - start_message="\ - [Test build ${BUILD_DISPLAY_NAME} has started](${BUILD_URL}consoleFull) for \ - PR $ghprbPullId at commit [\`${SHORT_COMMIT_HASH}\`](${COMMIT_URL})." - - start_message="${start_message}\n${merge_note}" - - post_message "$start_message" -} - # run tests { timeout "${TESTS_TIMEOUT}" ./dev/run-tests @@ -222,7 +227,7 @@ done PR $ghprbPullId at commit [\`${SHORT_COMMIT_HASH}\`](${COMMIT_URL})." result_message="${result_message}\n${test_result_note}" - result_message="${result_message}\n${pr_message}" + result_message="${result_message}${pr_message}" post_message "$result_message" } diff --git a/dev/tests/pr_new_dependencies.sh b/dev/tests/pr_new_dependencies.sh new file mode 100755 index 0000000000000..370c7cc737bbd --- /dev/null +++ b/dev/tests/pr_new_dependencies.sh @@ -0,0 +1,117 @@ +#!/usr/bin/env bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# +# This script follows the base format for testing pull requests against +# another branch and returning results to be published. More details can be +# found at dev/run-tests-jenkins. +# +# Arg1: The Github Pull Request Actual Commit +#+ known as `ghprbActualCommit` in `run-tests-jenkins` +# Arg2: The SHA1 hash +#+ known as `sha1` in `run-tests-jenkins` +# Arg3: Current PR Commit Hash +#+ the PR hash for the current commit +# + +ghprbActualCommit="$1" +sha1="$2" +current_pr_head="$3" + +MVN_BIN="build/mvn" +CURR_CP_FILE="my-classpath.txt" +MASTER_CP_FILE="master-classpath.txt" + +# First switch over to the master branch +git checkout master &>/dev/null +# Find and copy all pom.xml files into a *.gate file that we can check +# against through various `git` changes +find -name "pom.xml" -exec cp {} {}.gate \; +# Switch back to the current PR +git checkout "${current_pr_head}" &>/dev/null + +# Check if any *.pom files from the current branch are different from the master +difference_q="" +for p in $(find -name "pom.xml"); do + [[ -f "${p}" && -f "${p}.gate" ]] && \ + difference_q="${difference_q}$(diff $p.gate $p)" +done + +# If no pom files were changed we can easily say no new dependencies were added +if [ -z "${difference_q}" ]; then + echo " * This patch does not change any dependencies." +else + # Else we need to manually build spark to determine what, if any, dependencies + # were added into the Spark assembly jar + ${MVN_BIN} clean package dependency:build-classpath -DskipTests 2>/dev/null | \ + sed -n -e '/Building Spark Project Assembly/,$p' | \ + grep --context=1 -m 2 "Dependencies classpath:" | \ + head -n 3 | \ + tail -n 1 | \ + tr ":" "\n" | \ + rev | \ + cut -d "/" -f 1 | \ + rev | \ + sort > ${CURR_CP_FILE} + + # Checkout the master branch to compare against + git checkout master &>/dev/null + + ${MVN_BIN} clean package dependency:build-classpath -DskipTests 2>/dev/null | \ + sed -n -e '/Building Spark Project Assembly/,$p' | \ + grep --context=1 -m 2 "Dependencies classpath:" | \ + head -n 3 | \ + tail -n 1 | \ + tr ":" "\n" | \ + rev | \ + cut -d "/" -f 1 | \ + rev | \ + sort > ${MASTER_CP_FILE} + + DIFF_RESULTS="`diff my-classpath.txt master-classpath.txt`" + + if [ -z "${DIFF_RESULTS}" ]; then + echo " * This patch does not change any dependencies." + else + # Pretty print the new dependencies + added_deps=$(echo "${DIFF_RESULTS}" | grep "<" | cut -d' ' -f2 | awk '{printf " * \`"$1"\`\\n"}') + removed_deps=$(echo "${DIFF_RESULTS}" | grep ">" | cut -d' ' -f2 | awk '{printf " * \`"$1"\`\\n"}') + added_deps_text=" * This patch **adds the following new dependencies:**\n${added_deps}" + removed_deps_text=" * This patch **removes the following dependencies:**\n${removed_deps}" + + # Construct the final returned message with proper + return_mssg="" + [ -n "${added_deps}" ] && return_mssg="${added_deps_text}" + if [ -n "${removed_deps}" ]; then + if [ -n "${return_mssg}" ]; then + return_mssg="${return_mssg}\n${removed_deps_text}" + else + return_mssg="${removed_deps_text}" + fi + fi + echo "${return_mssg}" + fi + + # Remove the files we've left over + [ -f "${CURR_CP_FILE}" ] && rm -f "${CURR_CP_FILE}" + [ -f "${MASTER_CP_FILE}" ] && rm -f "${MASTER_CP_FILE}" + + # Clean up our mess from the Maven builds just in case + ${MVN_BIN} clean &>/dev/null +fi diff --git a/docs/README.md b/docs/README.md index 8a54724c4beae..3773ea25c8b67 100644 --- a/docs/README.md +++ b/docs/README.md @@ -60,7 +60,7 @@ We use Sphinx to generate Python API docs, so you will need to install it by run ## API Docs (Scaladoc and Sphinx) -You can build just the Spark scaladoc by running `build/sbt doc` from the SPARK_PROJECT_ROOT directory. +You can build just the Spark scaladoc by running `build/sbt unidoc` from the SPARK_PROJECT_ROOT directory. Similarly, you can build just the PySpark docs by running `make html` from the SPARK_PROJECT_ROOT/python/docs directory. Documentation is only generated for classes that are listed as @@ -68,7 +68,7 @@ public in `__init__.py`. When you run `jekyll` in the `docs` directory, it will also copy over the scaladoc for the various Spark subprojects into the `docs` directory (and then also into the `_site` directory). We use a -jekyll plugin to run `build/sbt doc` before building the site so if you haven't run it (recently) it +jekyll plugin to run `build/sbt unidoc` before building the site so if you haven't run it (recently) it may take some time as it generates all of the scaladoc. The jekyll plugin also generates the PySpark docs [Sphinx](http://sphinx-doc.org/). diff --git a/docs/graphx-programming-guide.md b/docs/graphx-programming-guide.md index c601d793a2e9a..3f10cb2dc3d2a 100644 --- a/docs/graphx-programming-guide.md +++ b/docs/graphx-programming-guide.md @@ -899,6 +899,8 @@ class VertexRDD[VD] extends RDD[(VertexID, VD)] { // Transform the values without changing the ids (preserves the internal index) def mapValues[VD2](map: VD => VD2): VertexRDD[VD2] def mapValues[VD2](map: (VertexId, VD) => VD2): VertexRDD[VD2] + // Show only vertices unique to this set based on their VertexId's + def minus(other: RDD[(VertexId, VD)]) // Remove vertices from this set that appear in the other set def diff(other: VertexRDD[VD]): VertexRDD[VD] // Join operators that take advantage of the internal indexing to accelerate joins (substantially) diff --git a/docs/mllib-naive-bayes.md b/docs/mllib-naive-bayes.md index a83472f5be52e..9780ea52c4994 100644 --- a/docs/mllib-naive-bayes.md +++ b/docs/mllib-naive-bayes.md @@ -13,12 +13,15 @@ compute the conditional probability distribution of label given an observation and use it for prediction. MLlib supports [multinomial naive -Bayes](http://en.wikipedia.org/wiki/Naive_Bayes_classifier#Multinomial_naive_Bayes), -which is typically used for [document -classification](http://nlp.stanford.edu/IR-book/html/htmledition/naive-bayes-text-classification-1.html). +Bayes](http://en.wikipedia.org/wiki/Naive_Bayes_classifier#Multinomial_naive_Bayes) +and [Bernoulli naive Bayes] (http://nlp.stanford.edu/IR-book/html/htmledition/the-bernoulli-model-1.html). +These models are typically used for [document classification] +(http://nlp.stanford.edu/IR-book/html/htmledition/naive-bayes-text-classification-1.html). Within that context, each observation is a document and each -feature represents a term whose value is the frequency of the term. -Feature values must be nonnegative to represent term frequencies. +feature represents a term whose value is the frequency of the term (in multinomial naive Bayes) or +a zero or one indicating whether the term was found in the document (in Bernoulli naive Bayes). +Feature values must be nonnegative. The model type is selected with an optional parameter +"Multinomial" or "Bernoulli" with "Multinomial" as the default. [Additive smoothing](http://en.wikipedia.org/wiki/Lidstone_smoothing) can be used by setting the parameter $\lambda$ (default to $1.0$). For document classification, the input feature vectors are usually sparse, and sparse vectors should be supplied as input to take advantage of @@ -32,7 +35,7 @@ sparsity. Since the training data is only used once, it is not necessary to cach [NaiveBayes](api/scala/index.html#org.apache.spark.mllib.classification.NaiveBayes$) implements multinomial naive Bayes. It takes an RDD of [LabeledPoint](api/scala/index.html#org.apache.spark.mllib.regression.LabeledPoint) and an optional -smoothing parameter `lambda` as input, and output a +smoothing parameter `lambda` as input, an optional model type parameter (default is Multinomial), and outputs a [NaiveBayesModel](api/scala/index.html#org.apache.spark.mllib.classification.NaiveBayesModel), which can be used for evaluation and prediction. @@ -51,7 +54,7 @@ val splits = parsedData.randomSplit(Array(0.6, 0.4), seed = 11L) val training = splits(0) val test = splits(1) -val model = NaiveBayes.train(training, lambda = 1.0) +val model = NaiveBayes.train(training, lambda = 1.0, model = "Multinomial") val predictionAndLabel = test.map(p => (model.predict(p.features), p.label)) val accuracy = 1.0 * predictionAndLabel.filter(x => x._1 == x._2).count() / test.count() diff --git a/docs/programming-guide.md b/docs/programming-guide.md index f5b775da7930a..f4fabb0927b66 100644 --- a/docs/programming-guide.md +++ b/docs/programming-guide.md @@ -937,7 +937,7 @@ for details. Similar to map, but each input item can be mapped to 0 or more output items (so func should return a Seq rather than a single item). - mapPartitions(func) + mapPartitions(func) Similar to map, but runs separately on each partition (block) of the RDD, so func must be of type Iterator<T> => Iterator<U> when running on an RDD of type T. @@ -964,7 +964,7 @@ for details. Return a new dataset that contains the distinct elements of the source dataset. - groupByKey([numTasks]) + groupByKey([numTasks]) When called on a dataset of (K, V) pairs, returns a dataset of (K, Iterable<V>) pairs.
Note: If you are grouping in order to perform an aggregation (such as a sum or average) over each key, using reduceByKey or aggregateByKey will yield much better @@ -975,25 +975,25 @@ for details. - reduceByKey(func, [numTasks]) + reduceByKey(func, [numTasks]) When called on a dataset of (K, V) pairs, returns a dataset of (K, V) pairs where the values for each key are aggregated using the given reduce function func, which must be of type (V,V) => V. Like in groupByKey, the number of reduce tasks is configurable through an optional second argument. - aggregateByKey(zeroValue)(seqOp, combOp, [numTasks]) + aggregateByKey(zeroValue)(seqOp, combOp, [numTasks]) When called on a dataset of (K, V) pairs, returns a dataset of (K, U) pairs where the values for each key are aggregated using the given combine functions and a neutral "zero" value. Allows an aggregated value type that is different than the input value type, while avoiding unnecessary allocations. Like in groupByKey, the number of reduce tasks is configurable through an optional second argument. - sortByKey([ascending], [numTasks]) + sortByKey([ascending], [numTasks]) When called on a dataset of (K, V) pairs where K implements Ordered, returns a dataset of (K, V) pairs sorted by keys in ascending or descending order, as specified in the boolean ascending argument. - join(otherDataset, [numTasks]) + join(otherDataset, [numTasks]) When called on datasets of type (K, V) and (K, W), returns a dataset of (K, (V, W)) pairs with all pairs of elements for each key. Outer joins are supported through leftOuterJoin, rightOuterJoin, and fullOuterJoin. - cogroup(otherDataset, [numTasks]) + cogroup(otherDataset, [numTasks]) When called on datasets of type (K, V) and (K, W), returns a dataset of (K, (Iterable<V>, Iterable<W>)) tuples. This operation is also called groupWith. @@ -1006,17 +1006,17 @@ for details. process's stdin and lines output to its stdout are returned as an RDD of strings. - coalesce(numPartitions) + coalesce(numPartitions) Decrease the number of partitions in the RDD to numPartitions. Useful for running operations more efficiently after filtering down a large dataset. repartition(numPartitions) Reshuffle the data in the RDD randomly to create either more or fewer partitions and balance it across them. - This always shuffles all data over the network. + This always shuffles all data over the network. - repartitionAndSortWithinPartitions(partitioner) + repartitionAndSortWithinPartitions(partitioner) Repartition the RDD according to the given partitioner and, within each resulting partition, sort records by their keys. This is more efficient than calling repartition and then sorting within each partition because it can push the sorting down into the shuffle machinery. @@ -1080,7 +1080,7 @@ for details. SparkContext.objectFile(). - countByKey() + countByKey() Only available on RDDs of type (K, V). Returns a hashmap of (K, Int) pairs with the count of each key. @@ -1090,6 +1090,67 @@ for details. +### Shuffle operations + +Certain operations within Spark trigger an event known as the shuffle. The shuffle is Spark's +mechanism for re-distributing data so that is grouped differently across partitions. This typically +involves copying data across executors and machines, making the shuffle a complex and +costly operation. + +#### Background + +To understand what happens during the shuffle we can consider the example of the +[`reduceByKey`](#ReduceByLink) operation. The `reduceByKey` operation generates a new RDD where all +values for a single key are combined into a tuple - the key and the result of executing a reduce +function against all values associated with that key. The challenge is that not all values for a +single key necessarily reside on the same partition, or even the same machine, but they must be +co-located to compute the result. + +In Spark, data is generally not distributed across partitions to be in the necessary place for a +specific operation. During computations, a single task will operate on a single partition - thus, to +organize all the data for a single `reduceByKey` reduce task to execute, Spark needs to perform an +all-to-all operation. It must read from all partitions to find all the values for all keys, +and then bring together values across partitions to compute the final result for each key - +this is called the **shuffle**. + +Although the set of elements in each partition of newly shuffled data will be deterministic, and so +is the ordering of partitions themselves, the ordering of these elements is not. If one desires predictably +ordered data following shuffle then it's possible to use: + +* `mapPartitions` to sort each partition using, for example, `.sorted` +* `repartitionAndSortWithinPartitions` to efficiently sort partitions while simultaneously repartitioning +* `sortBy` to make a globally ordered RDD + +Operations which can cause a shuffle include **repartition** operations like +[`repartition`](#RepartitionLink), and [`coalesce`](#CoalesceLink), **'ByKey** operations +(except for counting) like [`groupByKey`](#GroupByLink) and [`reduceByKey`](#ReduceByLink), and +**join** operations like [`cogroup`](#CogroupLink) and [`join`](#JoinLink). + +#### Performance Impact +The **Shuffle** is an expensive operation since it involves disk I/O, data serialization, and +network I/O. To organize data for the shuffle, Spark generates sets of tasks - *map* tasks to +organize the data, and a set of *reduce* tasks to aggregate it. This nomenclature comes from +MapReduce and does not directly relate to Spark's `map` and `reduce` operations. + +Internally, results from individual map tasks are kept in memory until they can't fit. Then, these +are sorted based on the target partition and written to a single file. On the reduce side, tasks +read the relevant sorted blocks. + +Certain shuffle operations can consume significant amounts of heap memory since they employ +in-memory data structures to organize records before or after transferring them. Specifically, +`reduceByKey` and `aggregateByKey` create these structures on the map side and `'ByKey` operations +generate these on the reduce side. When data does not fit in memory Spark will spill these tables +to disk, incurring the additional overhead of disk I/O and increased garbage collection. + +Shuffle also generates a large number of intermediate files on disk. As of Spark 1.3, these files +are not cleaned up from Spark's temporary storage until Spark is stopped, which means that +long-running Spark jobs may consume available disk space. This is done so the shuffle doesn't need +to be re-computed if the lineage is re-computed. The temporary storage directory is specified by the +`spark.local.dir` configuration parameter when configuring the Spark context. + +Shuffle behavior can be tuned by adjusting a variety of configuration parameters. See the +'Shuffle Behavior' section within the [Spark Configuration Guide](configuration.html). + ## RDD Persistence One of the most important capabilities in Spark is *persisting* (or *caching*) a dataset in memory diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md index d9f3eb2b74b18..b7e68d4f71714 100644 --- a/docs/running-on-yarn.md +++ b/docs/running-on-yarn.md @@ -196,6 +196,15 @@ Most of the configs are the same for Spark on YARN as for other deployment modes It should be no larger than the global number of max attempts in the YARN configuration. + + spark.yarn.submit.waitAppCompletion + true + + In YARN cluster mode, controls whether the client waits to exit until the application completes. + If set to true, the client process will stay alive reporting the application's status. + Otherwise, the client process will exit after submission. + + # Launching Spark on YARN diff --git a/docs/spark-standalone.md b/docs/spark-standalone.md index 74d8653a8b845..0eed9adacf123 100644 --- a/docs/spark-standalone.md +++ b/docs/spark-standalone.md @@ -24,7 +24,7 @@ the master's web UI, which is [http://localhost:8080](http://localhost:8080) by Similarly, you can start one or more workers and connect them to the master via: - ./bin/spark-class org.apache.spark.deploy.worker.Worker spark://IP:PORT + ./sbin/start-slave.sh Once you have started a worker, look at the master's web UI ([http://localhost:8080](http://localhost:8080) by default). You should see the new node listed there, along with its number of CPUs and memory (minus one gigabyte left for the OS). @@ -81,6 +81,7 @@ Once you've set up this file, you can launch or stop your cluster with the follo - `sbin/start-master.sh` - Starts a master instance on the machine the script is executed on. - `sbin/start-slaves.sh` - Starts a slave instance on each machine specified in the `conf/slaves` file. +- `sbin/start-slave.sh` - Starts a slave instance on the machine the script is executed on. - `sbin/start-all.sh` - Starts both a master and a number of slaves as described above. - `sbin/stop-master.sh` - Stops the master that was started via the `bin/start-master.sh` script. - `sbin/stop-slaves.sh` - Stops all slave instances on the machines specified in the `conf/slaves` file. diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index c99a0b03442c4..4441d6a000a02 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1406,7 +1406,7 @@ DataFrame jdbcDF = sqlContext.load("jdbc", options) {% highlight python %} -df = sqlContext.load("jdbc", url="jdbc:postgresql:dbserver", dbtable="schema.tablename") +df = sqlContext.load(source="jdbc", url="jdbc:postgresql:dbserver", dbtable="schema.tablename") {% endhighlight %} diff --git a/docs/streaming-programming-guide.md b/docs/streaming-programming-guide.md index 6d6229625f3f9..262512a639046 100644 --- a/docs/streaming-programming-guide.md +++ b/docs/streaming-programming-guide.md @@ -704,7 +704,7 @@ create a DStream using data from Twitter's stream of tweets, you have to do the {% highlight scala %} import org.apache.spark.streaming.twitter._ -TwitterUtils.createStream(ssc) +TwitterUtils.createStream(ssc, None) {% endhighlight %}
diff --git a/ec2/spark_ec2.py b/ec2/spark_ec2.py index c467cd08ed742..879a52cef8ff0 100755 --- a/ec2/spark_ec2.py +++ b/ec2/spark_ec2.py @@ -456,6 +456,13 @@ def launch_cluster(conn, opts, cluster_name): master_group.authorize('tcp', 50070, 50070, authorized_address) master_group.authorize('tcp', 60070, 60070, authorized_address) master_group.authorize('tcp', 4040, 4045, authorized_address) + # HDFS NFS gateway requires 111,2049,4242 for tcp & udp + master_group.authorize('tcp', 111, 111, authorized_address) + master_group.authorize('udp', 111, 111, authorized_address) + master_group.authorize('tcp', 2049, 2049, authorized_address) + master_group.authorize('udp', 2049, 2049, authorized_address) + master_group.authorize('tcp', 4242, 4242, authorized_address) + master_group.authorize('udp', 4242, 4242, authorized_address) if opts.ganglia: master_group.authorize('tcp', 5080, 5080, authorized_address) if slave_group.rules == []: # Group was just now created @@ -802,7 +809,7 @@ def is_cluster_ssh_available(cluster_instances, opts): Check if SSH is available on all the instances in a cluster. """ for i in cluster_instances: - if not is_ssh_available(host=i.ip_address, opts=opts): + if not is_ssh_available(host=i.public_dns_name, opts=opts): return False else: return True diff --git a/examples/src/main/python/sql.py b/examples/src/main/python/sql.py index 47202fde7510b..d89361f324917 100644 --- a/examples/src/main/python/sql.py +++ b/examples/src/main/python/sql.py @@ -19,7 +19,7 @@ from pyspark import SparkContext from pyspark.sql import SQLContext -from pyspark.sql import Row, StructField, StructType, StringType, IntegerType +from pyspark.sql.types import Row, StructField, StructType, StringType, IntegerType if __name__ == "__main__": diff --git a/examples/src/main/scala/org/apache/spark/examples/LocalKMeans.scala b/examples/src/main/scala/org/apache/spark/examples/LocalKMeans.scala index 17624c20cff3d..f73eac1e2b906 100644 --- a/examples/src/main/scala/org/apache/spark/examples/LocalKMeans.scala +++ b/examples/src/main/scala/org/apache/spark/examples/LocalKMeans.scala @@ -40,8 +40,8 @@ object LocalKMeans { val convergeDist = 0.001 val rand = new Random(42) - def generateData = { - def generatePoint(i: Int) = { + def generateData: Array[DenseVector[Double]] = { + def generatePoint(i: Int): DenseVector[Double] = { DenseVector.fill(D){rand.nextDouble * R} } Array.tabulate(N)(generatePoint) diff --git a/examples/src/main/scala/org/apache/spark/examples/LocalLR.scala b/examples/src/main/scala/org/apache/spark/examples/LocalLR.scala index 92a683ad57ea1..a55e0dc8d36c2 100644 --- a/examples/src/main/scala/org/apache/spark/examples/LocalLR.scala +++ b/examples/src/main/scala/org/apache/spark/examples/LocalLR.scala @@ -37,8 +37,8 @@ object LocalLR { case class DataPoint(x: Vector[Double], y: Double) - def generateData = { - def generatePoint(i: Int) = { + def generateData: Array[DataPoint] = { + def generatePoint(i: Int): DataPoint = { val y = if(i % 2 == 0) -1 else 1 val x = DenseVector.fill(D){rand.nextGaussian + y * R} DataPoint(x, y) diff --git a/examples/src/main/scala/org/apache/spark/examples/LogQuery.scala b/examples/src/main/scala/org/apache/spark/examples/LogQuery.scala index 74620ad007d83..32e02eab8b031 100644 --- a/examples/src/main/scala/org/apache/spark/examples/LogQuery.scala +++ b/examples/src/main/scala/org/apache/spark/examples/LogQuery.scala @@ -54,8 +54,8 @@ object LogQuery { // scalastyle:on /** Tracks the total query count and number of aggregate bytes for a particular group. */ class Stats(val count: Int, val numBytes: Int) extends Serializable { - def merge(other: Stats) = new Stats(count + other.count, numBytes + other.numBytes) - override def toString = "bytes=%s\tn=%s".format(numBytes, count) + def merge(other: Stats): Stats = new Stats(count + other.count, numBytes + other.numBytes) + override def toString: String = "bytes=%s\tn=%s".format(numBytes, count) } def extractKey(line: String): (String, String, String) = { diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkLR.scala b/examples/src/main/scala/org/apache/spark/examples/SparkLR.scala index 257a7d29f922a..8c01a60844620 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkLR.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkLR.scala @@ -42,8 +42,8 @@ object SparkLR { case class DataPoint(x: Vector[Double], y: Double) - def generateData = { - def generatePoint(i: Int) = { + def generateData: Array[DataPoint] = { + def generatePoint(i: Int): DataPoint = { val y = if(i % 2 == 0) -1 else 1 val x = DenseVector.fill(D){rand.nextGaussian + y * R} DataPoint(x, y) diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkTC.scala b/examples/src/main/scala/org/apache/spark/examples/SparkTC.scala index f7f83086df3db..772cd897f5140 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkTC.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkTC.scala @@ -31,7 +31,7 @@ object SparkTC { val numVertices = 100 val rand = new Random(42) - def generateGraph = { + def generateGraph: Seq[(Int, Int)] = { val edges: mutable.Set[(Int, Int)] = mutable.Set.empty while (edges.size < numEdges) { val from = rand.nextInt(numVertices) diff --git a/examples/src/main/scala/org/apache/spark/examples/bagel/PageRankUtils.scala b/examples/src/main/scala/org/apache/spark/examples/bagel/PageRankUtils.scala index e322d4ce5a745..ab6e63deb3c95 100644 --- a/examples/src/main/scala/org/apache/spark/examples/bagel/PageRankUtils.scala +++ b/examples/src/main/scala/org/apache/spark/examples/bagel/PageRankUtils.scala @@ -90,7 +90,7 @@ class PRMessage() extends Message[String] with Serializable { } class CustomPartitioner(partitions: Int) extends Partitioner { - def numPartitions = partitions + def numPartitions: Int = partitions def getPartition(key: Any): Int = { val hash = key match { diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala index 1f4ca4fbe7778..0bc36ea65e1ab 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala @@ -178,7 +178,9 @@ object MovieLensALS { def computeRmse(model: MatrixFactorizationModel, data: RDD[Rating], implicitPrefs: Boolean) : Double = { - def mapPredictedRating(r: Double) = if (implicitPrefs) math.max(math.min(r, 1.0), 0.0) else r + def mapPredictedRating(r: Double): Double = { + if (implicitPrefs) math.max(math.min(r, 1.0), 0.0) else r + } val predictions: RDD[Rating] = model.predict(data.map(x => (x.user, x.product))) val predictionsAndRatings = predictions.map{ x => diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/ActorWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/ActorWordCount.scala index b433082dce1a2..92867b44be138 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/ActorWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/ActorWordCount.scala @@ -85,13 +85,13 @@ extends Actor with ActorHelper { lazy private val remotePublisher = context.actorSelection(urlOfPublisher) - override def preStart = remotePublisher ! SubscribeReceiver(context.self) + override def preStart(): Unit = remotePublisher ! SubscribeReceiver(context.self) - def receive = { + def receive: PartialFunction[Any, Unit] = { case msg => store(msg.asInstanceOf[T]) } - override def postStop() = remotePublisher ! UnsubscribeReceiver(context.self) + override def postStop(): Unit = remotePublisher ! UnsubscribeReceiver(context.self) } diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/RecoverableNetworkWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/RecoverableNetworkWordCount.scala index c3a05c89d817e..751b30ea15782 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/RecoverableNetworkWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/RecoverableNetworkWordCount.scala @@ -55,7 +55,8 @@ import org.apache.spark.util.IntParam */ object RecoverableNetworkWordCount { - def createContext(ip: String, port: Int, outputPath: String, checkpointDirectory: String) = { + def createContext(ip: String, port: Int, outputPath: String, checkpointDirectory: String) + : StreamingContext = { // If you do not see this printed, that means the StreamingContext has been loaded // from the new checkpoint diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/ZeroMQWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/ZeroMQWordCount.scala index 6510c70bd1866..e99d1baa72b9f 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/ZeroMQWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/ZeroMQWordCount.scala @@ -35,7 +35,7 @@ import org.apache.spark.SparkConf */ object SimpleZeroMQPublisher { - def main(args: Array[String]) = { + def main(args: Array[String]): Unit = { if (args.length < 2) { System.err.println("Usage: SimpleZeroMQPublisher ") System.exit(1) @@ -45,7 +45,7 @@ object SimpleZeroMQPublisher { val acs: ActorSystem = ActorSystem() val pubSocket = ZeroMQExtension(acs).newSocket(SocketType.Pub, Bind(url)) - implicit def stringToByteString(x: String) = ByteString(x) + implicit def stringToByteString(x: String): ByteString = ByteString(x) val messages: List[ByteString] = List("words ", "may ", "count ") while (true) { Thread.sleep(1000) @@ -86,7 +86,7 @@ object ZeroMQWordCount { // Create the context and set the batch size val ssc = new StreamingContext(sparkConf, Seconds(2)) - def bytesToStringIterator(x: Seq[ByteString]) = (x.map(_.utf8String)).iterator + def bytesToStringIterator(x: Seq[ByteString]): Iterator[String] = x.map(_.utf8String).iterator // For this stream, a zeroMQ publisher should be running. val lines = ZeroMQUtils.createStream(ssc, url, Subscribe(topic), bytesToStringIterator _) diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewGenerator.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewGenerator.scala index 8402491b62671..54d996b8ac990 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewGenerator.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewGenerator.scala @@ -94,7 +94,7 @@ object PageViewGenerator { while (true) { val socket = listener.accept() new Thread() { - override def run = { + override def run(): Unit = { println("Got client connected from: " + socket.getInetAddress) val out = new PrintWriter(socket.getOutputStream(), true) diff --git a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala index 2de2a7926bfd1..60e2994431b38 100644 --- a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala +++ b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala @@ -37,8 +37,7 @@ import org.apache.spark.streaming.dstream._ import org.apache.spark.streaming.StreamingContext import org.apache.spark.streaming.receiver.Receiver -import org.jboss.netty.channel.ChannelPipelineFactory -import org.jboss.netty.channel.Channels +import org.jboss.netty.channel.{ChannelPipeline, ChannelPipelineFactory, Channels} import org.jboss.netty.channel.socket.nio.NioServerSocketChannelFactory import org.jboss.netty.handler.codec.compression._ @@ -187,8 +186,8 @@ class FlumeReceiver( logInfo("Flume receiver stopped") } - override def preferredLocation = Some(host) - + override def preferredLocation: Option[String] = Option(host) + /** A Netty Pipeline factory that will decompress incoming data from * and the Netty client and compress data going back to the client. * @@ -198,13 +197,12 @@ class FlumeReceiver( */ private[streaming] class CompressionChannelPipelineFactory extends ChannelPipelineFactory { - - def getPipeline() = { + def getPipeline(): ChannelPipeline = { val pipeline = Channels.pipeline() val encoder = new ZlibEncoder(6) pipeline.addFirst("deflater", encoder) pipeline.addFirst("inflater", new ZlibDecoder()) pipeline + } } } -} diff --git a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala index e04d4088df7dc..2edea9b5b69ba 100644 --- a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala +++ b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala @@ -1,21 +1,20 @@ /* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. */ + package org.apache.spark.streaming.flume import java.net.InetSocketAddress @@ -213,7 +212,7 @@ class FlumePollingStreamSuite extends FunSuite with BeforeAndAfter with Logging assert(counter === totalEventsPerChannel * channels.size) } - def assertChannelIsEmpty(channel: MemoryChannel) = { + def assertChannelIsEmpty(channel: MemoryChannel): Unit = { val queueRemaining = channel.getClass.getDeclaredField("queueRemaining") queueRemaining.setAccessible(true) val m = queueRemaining.get(channel).getClass.getDeclaredMethod("availablePermits") diff --git a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala index 51d273af8da84..39e6754c81dbf 100644 --- a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala +++ b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala @@ -151,7 +151,9 @@ class FlumeStreamSuite extends FunSuite with BeforeAndAfter with Matchers with L } /** Class to create socket channel with compression */ - private class CompressionChannelFactory(compressionLevel: Int) extends NioClientSocketChannelFactory { + private class CompressionChannelFactory(compressionLevel: Int) + extends NioClientSocketChannelFactory { + override def newChannel(pipeline: ChannelPipeline): SocketChannel = { val encoder = new ZlibEncoder(compressionLevel) pipeline.addFirst("deflater", encoder) diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala index 04e65cb3d708c..1b1fc8051d052 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala @@ -129,8 +129,9 @@ class DirectKafkaInputDStream[ private[streaming] class DirectKafkaInputDStreamCheckpointData extends DStreamCheckpointData(this) { - def batchForTime = data.asInstanceOf[mutable.HashMap[ - Time, Array[OffsetRange.OffsetRangeTuple]]] + def batchForTime: mutable.HashMap[Time, Array[(String, Int, Long, Long)]] = { + data.asInstanceOf[mutable.HashMap[Time, Array[OffsetRange.OffsetRangeTuple]]] + } override def update(time: Time) { batchForTime.clear() diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDD.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDD.scala index 6d465bcb6bfc0..a0b8a0c565210 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDD.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDD.scala @@ -86,7 +86,7 @@ class KafkaRDD[ val part = thePart.asInstanceOf[KafkaRDDPartition] assert(part.fromOffset <= part.untilOffset, errBeginAfterEnd(part)) if (part.fromOffset == part.untilOffset) { - log.warn(s"Beginning offset ${part.fromOffset} is the same as ending offset " + + log.info(s"Beginning offset ${part.fromOffset} is the same as ending offset " + s"skipping ${part.topic} ${part.partition}") Iterator.empty } else { @@ -155,7 +155,7 @@ class KafkaRDD[ .dropWhile(_.offset < requestOffset) } - override def close() = consumer.close() + override def close(): Unit = consumer.close() override def getNext(): R = { if (iter == null || !iter.hasNext) { @@ -207,7 +207,7 @@ object KafkaRDD { fromOffsets: Map[TopicAndPartition, Long], untilOffsets: Map[TopicAndPartition, LeaderOffset], messageHandler: MessageAndMetadata[K, V] => R - ): KafkaRDD[K, V, U, T, R] = { + ): KafkaRDD[K, V, U, T, R] = { val leaders = untilOffsets.map { case (tp, lo) => tp -> (lo.host, lo.port) }.toMap diff --git a/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala b/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala index 24d78ecb3a97d..a19a72c58a705 100644 --- a/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala +++ b/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala @@ -139,7 +139,8 @@ class MQTTStreamSuite extends FunSuite with Eventually with BeforeAndAfter { msgTopic.publish(message) } catch { case e: MqttException if e.getReasonCode == MqttException.REASON_CODE_MAX_INFLIGHT => - Thread.sleep(50) // wait for Spark streaming to consume something from the message queue + // wait for Spark streaming to consume something from the message queue + Thread.sleep(50) } } } diff --git a/external/twitter/src/main/scala/org/apache/spark/streaming/twitter/TwitterInputDStream.scala b/external/twitter/src/main/scala/org/apache/spark/streaming/twitter/TwitterInputDStream.scala index 4eacc47da5699..7cf02d85d73d3 100644 --- a/external/twitter/src/main/scala/org/apache/spark/streaming/twitter/TwitterInputDStream.scala +++ b/external/twitter/src/main/scala/org/apache/spark/streaming/twitter/TwitterInputDStream.scala @@ -70,7 +70,7 @@ class TwitterReceiver( try { val newTwitterStream = new TwitterStreamFactory().getInstance(twitterAuth) newTwitterStream.addListener(new StatusListener { - def onStatus(status: Status) = { + def onStatus(status: Status): Unit = { store(status) } // Unimplemented diff --git a/external/zeromq/src/main/scala/org/apache/spark/streaming/zeromq/ZeroMQReceiver.scala b/external/zeromq/src/main/scala/org/apache/spark/streaming/zeromq/ZeroMQReceiver.scala index 554705878ee78..588e6bac7b14a 100644 --- a/external/zeromq/src/main/scala/org/apache/spark/streaming/zeromq/ZeroMQReceiver.scala +++ b/external/zeromq/src/main/scala/org/apache/spark/streaming/zeromq/ZeroMQReceiver.scala @@ -29,13 +29,16 @@ import org.apache.spark.streaming.receiver.ActorHelper /** * A receiver to subscribe to ZeroMQ stream. */ -private[streaming] class ZeroMQReceiver[T: ClassTag](publisherUrl: String, - subscribe: Subscribe, - bytesToObjects: Seq[ByteString] => Iterator[T]) +private[streaming] class ZeroMQReceiver[T: ClassTag]( + publisherUrl: String, + subscribe: Subscribe, + bytesToObjects: Seq[ByteString] => Iterator[T]) extends Actor with ActorHelper with Logging { - override def preStart() = ZeroMQExtension(context.system) - .newSocket(SocketType.Sub, Listener(self), Connect(publisherUrl), subscribe) + override def preStart(): Unit = { + ZeroMQExtension(context.system) + .newSocket(SocketType.Sub, Listener(self), Connect(publisherUrl), subscribe) + } def receive: Receive = { diff --git a/graphx/src/main/scala/org/apache/spark/graphx/EdgeContext.scala b/graphx/src/main/scala/org/apache/spark/graphx/EdgeContext.scala index d8be02e2023d5..23430179f12ec 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/EdgeContext.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/EdgeContext.scala @@ -62,7 +62,6 @@ object EdgeContext { * , _ + _) * }}} */ - def unapply[VD, ED, A](edge: EdgeContext[VD, ED, A]) = + def unapply[VD, ED, A](edge: EdgeContext[VD, ED, A]): Some[(VertexId, VertexId, VD, VD, ED)] = Some(edge.srcId, edge.dstId, edge.srcAttr, edge.dstAttr, edge.attr) } - diff --git a/graphx/src/main/scala/org/apache/spark/graphx/EdgeDirection.scala b/graphx/src/main/scala/org/apache/spark/graphx/EdgeDirection.scala index 6f03eb1439773..058c8c8aa1b24 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/EdgeDirection.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/EdgeDirection.scala @@ -34,12 +34,12 @@ class EdgeDirection private (private val name: String) extends Serializable { override def toString: String = "EdgeDirection." + name - override def equals(o: Any) = o match { + override def equals(o: Any): Boolean = o match { case other: EdgeDirection => other.name == name case _ => false } - override def hashCode = name.hashCode + override def hashCode: Int = name.hashCode } @@ -48,14 +48,14 @@ class EdgeDirection private (private val name: String) extends Serializable { */ object EdgeDirection { /** Edges arriving at a vertex. */ - final val In = new EdgeDirection("In") + final val In: EdgeDirection = new EdgeDirection("In") /** Edges originating from a vertex. */ - final val Out = new EdgeDirection("Out") + final val Out: EdgeDirection = new EdgeDirection("Out") /** Edges originating from *or* arriving at a vertex of interest. */ - final val Either = new EdgeDirection("Either") + final val Either: EdgeDirection = new EdgeDirection("Either") /** Edges originating from *and* arriving at a vertex of interest. */ - final val Both = new EdgeDirection("Both") + final val Both: EdgeDirection = new EdgeDirection("Both") } diff --git a/graphx/src/main/scala/org/apache/spark/graphx/EdgeTriplet.scala b/graphx/src/main/scala/org/apache/spark/graphx/EdgeTriplet.scala index 9d473d5ebda44..c8790cac3d8a0 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/EdgeTriplet.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/EdgeTriplet.scala @@ -62,7 +62,7 @@ class EdgeTriplet[VD, ED] extends Edge[ED] { def vertexAttr(vid: VertexId): VD = if (srcId == vid) srcAttr else { assert(dstId == vid); dstAttr } - override def toString = ((srcId, srcAttr), (dstId, dstAttr), attr).toString() + override def toString: String = ((srcId, srcAttr), (dstId, dstAttr), attr).toString() def toTuple: ((VertexId, VD), (VertexId, VD), ED) = ((srcId, srcAttr), (dstId, dstAttr), attr) } diff --git a/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala b/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala index 8494d06b1cdb7..36dc7b0f86c89 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala @@ -409,7 +409,7 @@ abstract class Graph[VD: ClassTag, ED: ClassTag] protected () extends Serializab * {{{ * val rawGraph: Graph[_, _] = Graph.textFile("twittergraph") * val inDeg: RDD[(VertexId, Int)] = - * aggregateMessages[Int](ctx => ctx.sendToDst(1), _ + _) + * rawGraph.aggregateMessages[Int](ctx => ctx.sendToDst(1), _ + _) * }}} * * @note By expressing computation at the edge level we achieve diff --git a/graphx/src/main/scala/org/apache/spark/graphx/VertexRDD.scala b/graphx/src/main/scala/org/apache/spark/graphx/VertexRDD.scala index ad4bfe077293a..a9f04b559c3d1 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/VertexRDD.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/VertexRDD.scala @@ -121,6 +121,22 @@ abstract class VertexRDD[VD]( */ def mapValues[VD2: ClassTag](f: (VertexId, VD) => VD2): VertexRDD[VD2] + /** + * For each VertexId present in both `this` and `other`, minus will act as a set difference + * operation returning only those unique VertexId's present in `this`. + * + * @param other an RDD to run the set operation against + */ + def minus(other: RDD[(VertexId, VD)]): VertexRDD[VD] + + /** + * For each VertexId present in both `this` and `other`, minus will act as a set difference + * operation returning only those unique VertexId's present in `this`. + * + * @param other a VertexRDD to run the set operation against + */ + def minus(other: VertexRDD[VD]): VertexRDD[VD] + /** * For each vertex present in both `this` and `other`, `diff` returns only those vertices with * differing values; for values that are different, keeps the values from `other`. This is diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartition.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartition.scala index 373af75448374..c561570809253 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartition.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartition.scala @@ -324,7 +324,7 @@ class EdgePartition[ * * @return an iterator over edges in the partition */ - def iterator = new Iterator[Edge[ED]] { + def iterator: Iterator[Edge[ED]] = new Iterator[Edge[ED]] { private[this] val edge = new Edge[ED] private[this] var pos = 0 @@ -351,7 +351,7 @@ class EdgePartition[ override def hasNext: Boolean = pos < EdgePartition.this.size - override def next() = { + override def next(): EdgeTriplet[VD, ED] = { val triplet = new EdgeTriplet[VD, ED] val localSrcId = localSrcIds(pos) val localDstId = localDstIds(pos) @@ -518,11 +518,11 @@ private class AggregatingEdgeContext[VD, ED, A]( _attr = attr } - override def srcId = _srcId - override def dstId = _dstId - override def srcAttr = _srcAttr - override def dstAttr = _dstAttr - override def attr = _attr + override def srcId: VertexId = _srcId + override def dstId: VertexId = _dstId + override def srcAttr: VD = _srcAttr + override def dstAttr: VD = _dstAttr + override def attr: ED = _attr override def sendToSrc(msg: A) { send(_localSrcId, msg) diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgeRDDImpl.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgeRDDImpl.scala index 43a3aea0f6196..c88b2f65a86cd 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgeRDDImpl.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgeRDDImpl.scala @@ -70,9 +70,9 @@ class EdgeRDDImpl[ED: ClassTag, VD: ClassTag] private[graphx] ( this } - override def getStorageLevel = partitionsRDD.getStorageLevel + override def getStorageLevel: StorageLevel = partitionsRDD.getStorageLevel - override def checkpoint() = { + override def checkpoint(): Unit = { partitionsRDD.checkpoint() } diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/ReplicatedVertexView.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/ReplicatedVertexView.scala index 8ab255bd4038c..1df86449fa0c2 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/ReplicatedVertexView.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/ReplicatedVertexView.scala @@ -50,7 +50,7 @@ class ReplicatedVertexView[VD: ClassTag, ED: ClassTag]( * Return a new `ReplicatedVertexView` where edges are reversed and shipping levels are swapped to * match. */ - def reverse() = { + def reverse(): ReplicatedVertexView[VD, ED] = { val newEdges = edges.mapEdgePartitions((pid, part) => part.reverse) new ReplicatedVertexView(newEdges, hasDstId, hasSrcId) } diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexPartitionBaseOps.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexPartitionBaseOps.scala index 4fd2548b7faf6..b90f9fa327052 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexPartitionBaseOps.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexPartitionBaseOps.scala @@ -88,6 +88,21 @@ private[graphx] abstract class VertexPartitionBaseOps this.withMask(newMask) } + /** Hides the VertexId's that are the same between `this` and `other`. */ + def minus(other: Self[VD]): Self[VD] = { + if (self.index != other.index) { + logWarning("Minus operations on two VertexPartitions with different indexes is slow.") + minus(createUsingIndex(other.iterator)) + } else { + self.withMask(self.mask.andNot(other.mask)) + } + } + + /** Hides the VertexId's that are the same between `this` and `other`. */ + def minus(other: Iterator[(VertexId, VD)]): Self[VD] = { + minus(createUsingIndex(other)) + } + /** * Hides vertices that are the same between this and other. For vertices that are different, keeps * the values from `other`. The indices of `this` and `other` must be the same. diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexRDDImpl.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexRDDImpl.scala index 125692ddaad83..33ac7b0ed6095 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexRDDImpl.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexRDDImpl.scala @@ -71,9 +71,9 @@ class VertexRDDImpl[VD] private[graphx] ( this } - override def getStorageLevel = partitionsRDD.getStorageLevel + override def getStorageLevel: StorageLevel = partitionsRDD.getStorageLevel - override def checkpoint() = { + override def checkpoint(): Unit = { partitionsRDD.checkpoint() } @@ -103,6 +103,31 @@ class VertexRDDImpl[VD] private[graphx] ( override def mapValues[VD2: ClassTag](f: (VertexId, VD) => VD2): VertexRDD[VD2] = this.mapVertexPartitions(_.map(f)) + override def minus(other: RDD[(VertexId, VD)]): VertexRDD[VD] = { + minus(this.aggregateUsingIndex(other, (a: VD, b: VD) => a)) + } + + override def minus (other: VertexRDD[VD]): VertexRDD[VD] = { + other match { + case other: VertexRDD[_] if this.partitioner == other.partitioner => + this.withPartitionsRDD[VD]( + partitionsRDD.zipPartitions( + other.partitionsRDD, preservesPartitioning = true) { + (thisIter, otherIter) => + val thisPart = thisIter.next() + val otherPart = otherIter.next() + Iterator(thisPart.minus(otherPart)) + }) + case _ => + this.withPartitionsRDD[VD]( + partitionsRDD.zipPartitions( + other.partitionBy(this.partitioner.get), preservesPartitioning = true) { + (partIter, msgs) => partIter.map(_.minus(msgs)) + } + ) + } + } + override def diff(other: RDD[(VertexId, VD)]): VertexRDD[VD] = { diff(this.aggregateUsingIndex(other, (a: VD, b: VD) => a)) } diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/ConnectedComponents.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/ConnectedComponents.scala index e2f6cc138958e..859f896039047 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/lib/ConnectedComponents.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/ConnectedComponents.scala @@ -37,7 +37,7 @@ object ConnectedComponents { */ def run[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED]): Graph[VertexId, ED] = { val ccGraph = graph.mapVertices { case (vid, _) => vid } - def sendMessage(edge: EdgeTriplet[VertexId, ED]) = { + def sendMessage(edge: EdgeTriplet[VertexId, ED]): Iterator[(VertexId, VertexId)] = { if (edge.srcAttr < edge.dstAttr) { Iterator((edge.dstId, edge.srcAttr)) } else if (edge.srcAttr > edge.dstAttr) { diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/LabelPropagation.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/LabelPropagation.scala index 82e9e06515179..2bcf8684b8b8e 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/lib/LabelPropagation.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/LabelPropagation.scala @@ -43,7 +43,7 @@ object LabelPropagation { */ def run[VD, ED: ClassTag](graph: Graph[VD, ED], maxSteps: Int): Graph[VertexId, ED] = { val lpaGraph = graph.mapVertices { case (vid, _) => vid } - def sendMessage(e: EdgeTriplet[VertexId, ED]) = { + def sendMessage(e: EdgeTriplet[VertexId, ED]): Iterator[(VertexId, Map[VertexId, VertexId])] = { Iterator((e.srcId, Map(e.dstAttr -> 1L)), (e.dstId, Map(e.srcAttr -> 1L))) } def mergeMessage(count1: Map[VertexId, Long], count2: Map[VertexId, Long]) @@ -54,7 +54,7 @@ object LabelPropagation { i -> (count1Val + count2Val) }.toMap } - def vertexProgram(vid: VertexId, attr: Long, message: Map[VertexId, Long]) = { + def vertexProgram(vid: VertexId, attr: Long, message: Map[VertexId, Long]): VertexId = { if (message.isEmpty) attr else message.maxBy(_._2)._1 } val initialMessage = Map[VertexId, Long]() diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala index 570440ba4441f..042e366a29f58 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala @@ -156,7 +156,7 @@ object PageRank extends Logging { (newPR, newPR - oldPR) } - def sendMessage(edge: EdgeTriplet[(Double, Double), Double]) = { + def sendMessage(edge: EdgeTriplet[(Double, Double), Double]): Iterator[(VertexId, Double)] = { if (edge.srcAttr._2 > tol) { Iterator((edge.dstId, edge.srcAttr._2 * edge.attr)) } else { diff --git a/graphx/src/main/scala/org/apache/spark/graphx/util/collection/GraphXPrimitiveKeyOpenHashMap.scala b/graphx/src/main/scala/org/apache/spark/graphx/util/collection/GraphXPrimitiveKeyOpenHashMap.scala index 57b01b6f2e1fb..e2754ea699da9 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/util/collection/GraphXPrimitiveKeyOpenHashMap.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/util/collection/GraphXPrimitiveKeyOpenHashMap.scala @@ -56,7 +56,7 @@ class GraphXPrimitiveKeyOpenHashMap[@specialized(Long, Int) K: ClassTag, private var _oldValues: Array[V] = null - override def size = keySet.size + override def size: Int = keySet.size /** Get the value for a given key */ def apply(k: K): V = { @@ -112,7 +112,7 @@ class GraphXPrimitiveKeyOpenHashMap[@specialized(Long, Int) K: ClassTag, } } - override def iterator = new Iterator[(K, V)] { + override def iterator: Iterator[(K, V)] = new Iterator[(K, V)] { var pos = 0 var nextPair: (K, V) = computeNextPair() @@ -128,9 +128,9 @@ class GraphXPrimitiveKeyOpenHashMap[@specialized(Long, Int) K: ClassTag, } } - def hasNext = nextPair != null + def hasNext: Boolean = nextPair != null - def next() = { + def next(): (K, V) = { val pair = nextPair nextPair = computeNextPair() pair diff --git a/graphx/src/test/scala/org/apache/spark/graphx/VertexRDDSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/VertexRDDSuite.scala index 4f7a442ab503d..c9443d11c76cf 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/VertexRDDSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/VertexRDDSuite.scala @@ -47,6 +47,35 @@ class VertexRDDSuite extends FunSuite with LocalSparkContext { } } + test("minus") { + withSpark { sc => + val vertexA = VertexRDD(sc.parallelize(0 until 75, 2).map(i => (i.toLong, 0))).cache() + val vertexB = VertexRDD(sc.parallelize(25 until 100, 2).map(i => (i.toLong, 1))).cache() + val vertexC = vertexA.minus(vertexB) + assert(vertexC.map(_._1).collect.toSet === (0 until 25).toSet) + } + } + + test("minus with RDD[(VertexId, VD)]") { + withSpark { sc => + val vertexA = VertexRDD(sc.parallelize(0 until 75, 2).map(i => (i.toLong, 0))).cache() + val vertexB: RDD[(VertexId, Int)] = + sc.parallelize(25 until 100, 2).map(i => (i.toLong, 1)).cache() + val vertexC = vertexA.minus(vertexB) + assert(vertexC.map(_._1).collect.toSet === (0 until 25).toSet) + } + } + + test("minus with non-equal number of partitions") { + withSpark { sc => + val vertexA = VertexRDD(sc.parallelize(0 until 75, 5).map(i => (i.toLong, 0))) + val vertexB = VertexRDD(sc.parallelize(50 until 100, 2).map(i => (i.toLong, 1))) + assert(vertexA.partitions.size != vertexB.partitions.size) + val vertexC = vertexA.minus(vertexB) + assert(vertexC.map(_._1).collect.toSet === (0 until 50).toSet) + } + } + test("diff") { withSpark { sc => val n = 100 @@ -71,7 +100,7 @@ class VertexRDDSuite extends FunSuite with LocalSparkContext { } } - test("diff vertices with the non-equal number of partitions") { + test("diff vertices with non-equal number of partitions") { withSpark { sc => val vertexA = VertexRDD(sc.parallelize(0 until 24, 3).map(i => (i.toLong, 0))) val vertexB = VertexRDD(sc.parallelize(8 until 16, 2).map(i => (i.toLong, 1))) @@ -96,7 +125,7 @@ class VertexRDDSuite extends FunSuite with LocalSparkContext { } } - test("leftJoin vertices with the non-equal number of partitions") { + test("leftJoin vertices with non-equal number of partitions") { withSpark { sc => val vertexA = VertexRDD(sc.parallelize(0 until 100, 2).map(i => (i.toLong, 1))) val vertexB = VertexRDD( diff --git a/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java b/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java index 2da5f7278729e..d8279145d8e90 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java +++ b/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java @@ -86,10 +86,14 @@ public AbstractCommandBuilder() { */ List buildJavaCommand(String extraClassPath) throws IOException { List cmd = new ArrayList(); - if (javaHome == null) { - cmd.add(join(File.separator, System.getProperty("java.home"), "bin", "java")); - } else { + String envJavaHome; + + if (javaHome != null) { cmd.add(join(File.separator, javaHome, "bin", "java")); + } else if ((envJavaHome = System.getenv("JAVA_HOME")) != null) { + cmd.add(join(File.separator, envJavaHome, "bin", "java")); + } else { + cmd.add(join(File.separator, System.getProperty("java.home"), "bin", "java")); } // Load extra JAVA_OPTS from conf/java-opts, if it exists. @@ -182,59 +186,25 @@ List buildClassPath(String appClassPath) throws IOException { addToClassPath(cp, String.format("%s/core/target/jars/*", sparkHome)); } - String assembly = findAssembly(); + final String assembly = AbstractCommandBuilder.class.getProtectionDomain().getCodeSource(). + getLocation().getPath(); addToClassPath(cp, assembly); - // When Hive support is needed, Datanucleus jars must be included on the classpath. Datanucleus - // jars do not work if only included in the uber jar as plugin.xml metadata is lost. Both sbt - // and maven will populate "lib_managed/jars/" with the datanucleus jars when Spark is built - // with Hive, so first check if the datanucleus jars exist, and then ensure the current Spark - // assembly is built for Hive, before actually populating the CLASSPATH with the jars. - // - // This block also serves as a check for SPARK-1703, when the assembly jar is built with - // Java 7 and ends up with too many files, causing issues with other JDK versions. - boolean needsDataNucleus = false; - JarFile assemblyJar = null; - try { - assemblyJar = new JarFile(assembly); - needsDataNucleus = assemblyJar.getEntry("org/apache/hadoop/hive/ql/exec/") != null; - } catch (IOException ioe) { - if (ioe.getMessage().indexOf("invalid CEN header") >= 0) { - System.err.println( - "Loading Spark jar failed.\n" + - "This is likely because Spark was compiled with Java 7 and run\n" + - "with Java 6 (see SPARK-1703). Please use Java 7 to run Spark\n" + - "or build Spark with Java 6."); - System.exit(1); - } else { - throw ioe; - } - } finally { - if (assemblyJar != null) { - try { - assemblyJar.close(); - } catch (IOException e) { - // Ignore. - } - } + // Datanucleus jars must be included on the classpath. Datanucleus jars do not work if only + // included in the uber jar as plugin.xml metadata is lost. Both sbt and maven will populate + // "lib_managed/jars/" with the datanucleus jars when Spark is built with Hive + File libdir; + if (new File(sparkHome, "RELEASE").isFile()) { + libdir = new File(sparkHome, "lib"); + } else { + libdir = new File(sparkHome, "lib_managed/jars"); } - if (needsDataNucleus) { - System.err.println("Spark assembly has been built with Hive, including Datanucleus jars " + - "in classpath."); - File libdir; - if (new File(sparkHome, "RELEASE").isFile()) { - libdir = new File(sparkHome, "lib"); - } else { - libdir = new File(sparkHome, "lib_managed/jars"); - } - - checkState(libdir.isDirectory(), "Library directory '%s' does not exist.", - libdir.getAbsolutePath()); - for (File jar : libdir.listFiles()) { - if (jar.getName().startsWith("datanucleus-")) { - addToClassPath(cp, jar.getAbsolutePath()); - } + checkState(libdir.isDirectory(), "Library directory '%s' does not exist.", + libdir.getAbsolutePath()); + for (File jar : libdir.listFiles()) { + if (jar.getName().startsWith("datanucleus-")) { + addToClassPath(cp, jar.getAbsolutePath()); } } @@ -270,7 +240,6 @@ String getScalaVersion() { if (scala != null) { return scala; } - String sparkHome = getSparkHome(); File scala210 = new File(sparkHome, "assembly/target/scala-2.10"); File scala211 = new File(sparkHome, "assembly/target/scala-2.11"); @@ -330,30 +299,6 @@ String getenv(String key) { return firstNonEmpty(childEnv.get(key), System.getenv(key)); } - private String findAssembly() { - String sparkHome = getSparkHome(); - File libdir; - if (new File(sparkHome, "RELEASE").isFile()) { - libdir = new File(sparkHome, "lib"); - checkState(libdir.isDirectory(), "Library directory '%s' does not exist.", - libdir.getAbsolutePath()); - } else { - libdir = new File(sparkHome, String.format("assembly/target/scala-%s", getScalaVersion())); - } - - final Pattern re = Pattern.compile("spark-assembly.*hadoop.*\\.jar"); - FileFilter filter = new FileFilter() { - @Override - public boolean accept(File file) { - return file.isFile() && re.matcher(file.getName()).matches(); - } - }; - File[] assemblies = libdir.listFiles(filter); - checkState(assemblies != null && assemblies.length > 0, "No assemblies found in '%s'.", libdir); - checkState(assemblies.length == 1, "Multiple assemblies found in '%s'.", libdir); - return assemblies[0].getAbsolutePath(); - } - private String getConfDir() { String confDir = getenv("SPARK_CONF_DIR"); return confDir != null ? confDir : join(File.separator, getSparkHome(), "conf"); diff --git a/make-distribution.sh b/make-distribution.sh index 9ed1abfe8c598..738a9c4d69601 100755 --- a/make-distribution.sh +++ b/make-distribution.sh @@ -199,7 +199,6 @@ echo "Build flags: $@" >> "$DISTDIR/RELEASE" # Copy jars cp "$SPARK_HOME"/assembly/target/scala*/*assembly*hadoop*.jar "$DISTDIR/lib/" cp "$SPARK_HOME"/examples/target/scala*/spark-examples*.jar "$DISTDIR/lib/" -cp "$SPARK_HOME"/launcher/target/spark-launcher_$SCALA_VERSION-$VERSION.jar "$DISTDIR/lib/" # This will fail if the -Pyarn profile is not provided # In this case, silence the error and ignore the return code of this command cp "$SPARK_HOME"/network/yarn/target/scala*/spark-*-yarn-shuffle.jar "$DISTDIR/lib/" &> /dev/null || : diff --git a/mllib/pom.xml b/mllib/pom.xml index 4c183543e3fa8..5dfab36c76907 100644 --- a/mllib/pom.xml +++ b/mllib/pom.xml @@ -64,7 +64,7 @@ org.scalanlp breeze_${scala.binary.version} - 0.11.1 + 0.11.2 diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index 21f61d80dd95a..34625745dd0a8 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -31,7 +31,7 @@ import org.apache.spark.storage.StorageLevel * Params for logistic regression. */ private[classification] trait LogisticRegressionParams extends ProbabilisticClassifierParams - with HasRegParam with HasMaxIter with HasThreshold + with HasRegParam with HasMaxIter with HasFitIntercept with HasThreshold /** @@ -55,6 +55,9 @@ class LogisticRegression /** @group setParam */ def setMaxIter(value: Int): this.type = set(maxIter, value) + /** @group setParam */ + def setFitIntercept(value: Boolean): this.type = set(fitIntercept, value) + /** @group setParam */ def setThreshold(value: Double): this.type = set(threshold, value) @@ -67,7 +70,8 @@ class LogisticRegression } // Train model - val lr = new LogisticRegressionWithLBFGS + val lr = new LogisticRegressionWithLBFGS() + .setIntercept(paramMap(fitIntercept)) lr.optimizer .setRegParam(paramMap(regParam)) .setNumIterations(paramMap(maxIter)) @@ -180,7 +184,6 @@ class LogisticRegressionModel private[ml] ( * The behavior of this can be adjusted using [[threshold]]. */ override protected def predict(features: Vector): Double = { - println(s"LR.predict with threshold: ${paramMap(threshold)}") if (score(features) > paramMap(threshold)) 1 else 0 } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala new file mode 100644 index 0000000000000..05f91dc9105fe --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.ml.UnaryTransformer +import org.apache.spark.ml.param.{DoubleParam, ParamMap} +import org.apache.spark.mllib.feature +import org.apache.spark.mllib.linalg.{VectorUDT, Vector} +import org.apache.spark.sql.types.DataType + +/** + * :: AlphaComponent :: + * Normalize a vector to have unit norm using the given p-norm. + */ +@AlphaComponent +class Normalizer extends UnaryTransformer[Vector, Vector, Normalizer] { + + /** + * Normalization in L^p^ space, p = 2 by default. + * @group param + */ + val p = new DoubleParam(this, "p", "the p norm value", Some(2)) + + /** @group getParam */ + def getP: Double = get(p) + + /** @group setParam */ + def setP(value: Double): this.type = set(p, value) + + override protected def createTransformFunc(paramMap: ParamMap): Vector => Vector = { + val normalizer = new feature.Normalizer(paramMap(p)) + normalizer.transform + } + + override protected def outputDataType: DataType = new VectorUDT() +} + diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/sharedParams.scala index 5d660d1e151a7..0739fdbfcbaae 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/sharedParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/sharedParams.scala @@ -106,6 +106,18 @@ private[ml] trait HasProbabilityCol extends Params { def getProbabilityCol: String = get(probabilityCol) } +private[ml] trait HasFitIntercept extends Params { + /** + * param for fitting the intercept term, defaults to true + * @group param + */ + val fitIntercept: BooleanParam = + new BooleanParam(this, "fitIntercept", "indicates whether to fit an intercept term", Some(true)) + + /** @group getParam */ + def getFitIntercept: Boolean = get(fitIntercept) +} + private[ml] trait HasThreshold extends Params { /** * param for threshold in (binary) prediction diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala index 514b4ef98dc5b..52c9e95d6012f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala @@ -320,7 +320,7 @@ object ALS extends Logging { /** Trait for least squares solvers applied to the normal equation. */ private[recommendation] trait LeastSquaresNESolver extends Serializable { - /** Solves a least squares problem (possibly with other constraints). */ + /** Solves a least squares problem with regularization (possibly with other constraints). */ def solve(ne: NormalEquation, lambda: Double): Array[Float] } @@ -332,20 +332,19 @@ object ALS extends Logging { /** * Solves a least squares problem with L2 regularization: * - * min norm(A x - b)^2^ + lambda * n * norm(x)^2^ + * min norm(A x - b)^2^ + lambda * norm(x)^2^ * * @param ne a [[NormalEquation]] instance that contains AtA, Atb, and n (number of instances) - * @param lambda regularization constant, which will be scaled by n + * @param lambda regularization constant * @return the solution x */ override def solve(ne: NormalEquation, lambda: Double): Array[Float] = { val k = ne.k // Add scaled lambda to the diagonals of AtA. - val scaledlambda = lambda * ne.n var i = 0 var j = 2 while (i < ne.triK) { - ne.ata(i) += scaledlambda + ne.ata(i) += lambda i += j j += 1 } @@ -391,7 +390,7 @@ object ALS extends Logging { override def solve(ne: NormalEquation, lambda: Double): Array[Float] = { val rank = ne.k initialize(rank) - fillAtA(ne.ata, lambda * ne.n) + fillAtA(ne.ata, lambda) val x = NNLS.solve(ata, ne.atb, workspace) ne.reset() x.map(x => x.toFloat) @@ -420,7 +419,15 @@ object ALS extends Logging { } } - /** Representing a normal equation (ALS' subproblem). */ + /** + * Representing a normal equation to solve the following weighted least squares problem: + * + * minimize \sum,,i,, c,,i,, (a,,i,,^T^ x - b,,i,,)^2^ + lambda * x^T^ x. + * + * Its normal equation is given by + * + * \sum,,i,, c,,i,, (a,,i,, a,,i,,^T^ x - b,,i,, a,,i,,) + lambda * x = 0. + */ private[recommendation] class NormalEquation(val k: Int) extends Serializable { /** Number of entries in the upper triangular part of a k-by-k matrix. */ @@ -429,8 +436,6 @@ object ALS extends Logging { val ata = new Array[Double](triK) /** A^T^ * b */ val atb = new Array[Double](k) - /** Number of observations. */ - var n = 0 private val da = new Array[Double](k) private val upper = "U" @@ -444,28 +449,13 @@ object ALS extends Logging { } /** Adds an observation. */ - def add(a: Array[Float], b: Float): this.type = { - require(a.length == k) - copyToDouble(a) - blas.dspr(upper, k, 1.0, da, 1, ata) - blas.daxpy(k, b.toDouble, da, 1, atb, 1) - n += 1 - this - } - - /** - * Adds an observation with implicit feedback. Note that this does not increment the counter. - */ - def addImplicit(a: Array[Float], b: Float, alpha: Double): this.type = { + def add(a: Array[Float], b: Double, c: Double = 1.0): this.type = { + require(c >= 0.0) require(a.length == k) - // Extension to the original paper to handle b < 0. confidence is a function of |b| instead - // so that it is never negative. - val confidence = 1.0 + alpha * math.abs(b) copyToDouble(a) - blas.dspr(upper, k, confidence - 1.0, da, 1, ata) - // For b <= 0, the corresponding preference is 0. So the term below is only added for b > 0. - if (b > 0) { - blas.daxpy(k, confidence, da, 1, atb, 1) + blas.dspr(upper, k, c, da, 1, ata) + if (b != 0.0) { + blas.daxpy(k, c * b, da, 1, atb, 1) } this } @@ -475,7 +465,6 @@ object ALS extends Logging { require(other.k == k) blas.daxpy(ata.length, 1.0, other.ata, 1, ata, 1) blas.daxpy(atb.length, 1.0, other.atb, 1, atb, 1) - n += other.n this } @@ -483,7 +472,6 @@ object ALS extends Logging { def reset(): Unit = { ju.Arrays.fill(ata, 0.0) ju.Arrays.fill(atb, 0.0) - n = 0 } } @@ -1114,6 +1102,7 @@ object ALS extends Logging { ls.merge(YtY.get) } var i = srcPtrs(j) + var numExplicits = 0 while (i < srcPtrs(j + 1)) { val encoded = srcEncodedIndices(i) val blockId = srcEncoder.blockId(encoded) @@ -1121,13 +1110,23 @@ object ALS extends Logging { val srcFactor = sortedSrcFactors(blockId)(localIndex) val rating = ratings(i) if (implicitPrefs) { - ls.addImplicit(srcFactor, rating, alpha) + // Extension to the original paper to handle b < 0. confidence is a function of |b| + // instead so that it is never negative. c1 is confidence - 1.0. + val c1 = alpha * math.abs(rating) + // For rating <= 0, the corresponding preference is 0. So the term below is only added + // for rating > 0. Because YtY is already added, we need to adjust the scaling here. + if (rating > 0) { + numExplicits += 1 + ls.add(srcFactor, (c1 + 1.0) / c1, c1) + } } else { ls.add(srcFactor, rating) + numExplicits += 1 } i += 1 } - dstFactors(j) = solver.solve(ls, regParam) + // Weight lambda by the number of explicit ratings based on the ALS-WR paper. + dstFactors(j) = solver.solve(ls, numExplicits * regParam) j += 1 } dstFactors @@ -1141,7 +1140,7 @@ object ALS extends Logging { private def computeYtY(factorBlocks: RDD[(Int, FactorBlock)], rank: Int): NormalEquation = { factorBlocks.values.aggregate(new NormalEquation(rank))( seqOp = (ne, factors) => { - factors.foreach(ne.add(_, 0.0f)) + factors.foreach(ne.add(_, 0.0)) ne }, combOp = (ne1, ne2) => ne1.merge(ne2)) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/MatrixFactorizationModelWrapper.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/MatrixFactorizationModelWrapper.scala new file mode 100644 index 0000000000000..ecd3b16598438 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/MatrixFactorizationModelWrapper.scala @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.api.python + +import org.apache.spark.api.java.JavaRDD +import org.apache.spark.mllib.recommendation.{MatrixFactorizationModel, Rating} +import org.apache.spark.rdd.RDD + +/** + * A Wrapper of MatrixFactorizationModel to provide helper method for Python. + */ +private[python] class MatrixFactorizationModelWrapper(model: MatrixFactorizationModel) + extends MatrixFactorizationModel(model.rank, model.userFeatures, model.productFeatures) { + + def predict(userAndProducts: JavaRDD[Array[Any]]): RDD[Rating] = + predict(SerDe.asTupleRDD(userAndProducts.rdd)) + + def getUserFeatures: RDD[Array[Any]] = { + SerDe.fromTuple2RDD(userFeatures.asInstanceOf[RDD[(Any, Any)]]) + } + + def getProductFeatures: RDD[Array[Any]] = { + SerDe.fromTuple2RDD(productFeatures.asInstanceOf[RDD[(Any, Any)]]) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala index e39156734794c..6c386cacfb7ca 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala @@ -58,7 +58,6 @@ import org.apache.spark.util.Utils */ private[python] class PythonMLLibAPI extends Serializable { - /** * Loads and serializes labeled points saved with `RDD#saveAsTextFile`. * @param jsc Java SparkContext @@ -78,7 +77,13 @@ private[python] class PythonMLLibAPI extends Serializable { initialWeights: Vector): JList[Object] = { try { val model = learner.run(data.rdd.persist(StorageLevel.MEMORY_AND_DISK), initialWeights) - List(model.weights, model.intercept).map(_.asInstanceOf[Object]).asJava + if (model.isInstanceOf[LogisticRegressionModel]) { + val lrModel = model.asInstanceOf[LogisticRegressionModel] + List(lrModel.weights, lrModel.intercept, lrModel.numFeatures, lrModel.numClasses) + .map(_.asInstanceOf[Object]).asJava + } else { + List(model.weights, model.intercept).map(_.asInstanceOf[Object]).asJava + } } finally { data.rdd.unpersist(blocking = false) } @@ -191,9 +196,11 @@ private[python] class PythonMLLibAPI extends Serializable { miniBatchFraction: Double, initialWeights: Vector, regType: String, - intercept: Boolean): JList[Object] = { + intercept: Boolean, + validateData: Boolean): JList[Object] = { val SVMAlg = new SVMWithSGD() SVMAlg.setIntercept(intercept) + .setValidateData(validateData) SVMAlg.optimizer .setNumIterations(numIterations) .setRegParam(regParam) @@ -217,9 +224,11 @@ private[python] class PythonMLLibAPI extends Serializable { initialWeights: Vector, regParam: Double, regType: String, - intercept: Boolean): JList[Object] = { + intercept: Boolean, + validateData: Boolean): JList[Object] = { val LogRegAlg = new LogisticRegressionWithSGD() LogRegAlg.setIntercept(intercept) + .setValidateData(validateData) LogRegAlg.optimizer .setNumIterations(numIterations) .setRegParam(regParam) @@ -243,9 +252,13 @@ private[python] class PythonMLLibAPI extends Serializable { regType: String, intercept: Boolean, corrections: Int, - tolerance: Double): JList[Object] = { + tolerance: Double, + validateData: Boolean, + numClasses: Int): JList[Object] = { val LogRegAlg = new LogisticRegressionWithLBFGS() LogRegAlg.setIntercept(intercept) + .setValidateData(validateData) + .setNumClasses(numClasses) LogRegAlg.optimizer .setNumIterations(numIterations) .setRegParam(regParam) @@ -346,24 +359,7 @@ private[python] class PythonMLLibAPI extends Serializable { model.predictSoft(data) } - /** - * A Wrapper of MatrixFactorizationModel to provide helpfer method for Python - */ - private[python] class MatrixFactorizationModelWrapper(model: MatrixFactorizationModel) - extends MatrixFactorizationModel(model.rank, model.userFeatures, model.productFeatures) { - - def predict(userAndProducts: JavaRDD[Array[Any]]): RDD[Rating] = - predict(SerDe.asTupleRDD(userAndProducts.rdd)) - - def getUserFeatures: RDD[Array[Any]] = { - SerDe.fromTuple2RDD(userFeatures.asInstanceOf[RDD[(Any, Any)]]) - } - def getProductFeatures: RDD[Array[Any]] = { - SerDe.fromTuple2RDD(productFeatures.asInstanceOf[RDD[(Any, Any)]]) - } - - } /** * Java stub for Python mllib ALS.train(). This stub returns a handle @@ -480,13 +476,15 @@ private[python] class PythonMLLibAPI extends Serializable { learningRate: Double, numPartitions: Int, numIterations: Int, - seed: Long): Word2VecModelWrapper = { + seed: Long, + minCount: Int): Word2VecModelWrapper = { val word2vec = new Word2Vec() .setVectorSize(vectorSize) .setLearningRate(learningRate) .setNumPartitions(numPartitions) .setNumIterations(numIterations) .setSeed(seed) + .setMinCount(minCount) try { val model = word2vec.fit(dataJRDD.rdd.persist(StorageLevel.MEMORY_AND_DISK_SER)) new Word2VecModelWrapper(model) @@ -520,6 +518,10 @@ private[python] class PythonMLLibAPI extends Serializable { val words = result.map(_._1) List(words, similarity).map(_.asInstanceOf[Object]).asJava } + + def getVectors: JMap[String, JList[Float]] = { + model.getVectors.map({case (k, v) => (k, v.toList.asJava)}).asJava + } } /** @@ -1117,7 +1119,10 @@ private[spark] object SerDe extends Serializable { iter.flatMap { row => val obj = unpickle.loads(row) if (batched) { - obj.asInstanceOf[JArrayList[_]].asScala + obj match { + case list: JArrayList[_] => list.asScala + case arr: Array[_] => arr + } } else { Seq(obj) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala index e7c3599ff619c..057b628c6a586 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala @@ -62,6 +62,15 @@ class LogisticRegressionModel ( s" but was given weights of length ${weights.size}") } + private val dataWithBiasSize: Int = weights.size / (numClasses - 1) + + private val weightsArray: Array[Double] = weights match { + case dv: DenseVector => dv.values + case _ => + throw new IllegalArgumentException( + s"weights only supports dense vector but got type ${weights.getClass}.") + } + /** * Constructs a [[LogisticRegressionModel]] with weights and intercept for binary classification. */ @@ -74,6 +83,7 @@ class LogisticRegressionModel ( * Sets the threshold that separates positive predictions from negative predictions * in Binary Logistic Regression. An example with prediction score greater than or equal to * this threshold is identified as an positive, and negative otherwise. The default value is 0.5. + * It is only used for binary classification. */ @Experimental def setThreshold(threshold: Double): this.type = { @@ -84,6 +94,7 @@ class LogisticRegressionModel ( /** * :: Experimental :: * Returns the threshold (if any) used for converting raw prediction scores into 0/1 predictions. + * It is only used for binary classification. */ @Experimental def getThreshold: Option[Double] = threshold @@ -91,6 +102,7 @@ class LogisticRegressionModel ( /** * :: Experimental :: * Clears the threshold so that `predict` will output raw prediction scores. + * It is only used for binary classification. */ @Experimental def clearThreshold(): this.type = { @@ -106,7 +118,6 @@ class LogisticRegressionModel ( // If dataMatrix and weightMatrix have the same dimension, it's binary logistic regression. if (numClasses == 2) { - require(numFeatures == weightMatrix.size) val margin = dot(weightMatrix, dataMatrix) + intercept val score = 1.0 / (1.0 + math.exp(-margin)) threshold match { @@ -114,30 +125,9 @@ class LogisticRegressionModel ( case None => score } } else { - val dataWithBiasSize = weightMatrix.size / (numClasses - 1) - - val weightsArray = weightMatrix match { - case dv: DenseVector => dv.values - case _ => - throw new IllegalArgumentException( - s"weights only supports dense vector but got type ${weightMatrix.getClass}.") - } - - val margins = (0 until numClasses - 1).map { i => - var margin = 0.0 - dataMatrix.foreachActive { (index, value) => - if (value != 0.0) margin += value * weightsArray((i * dataWithBiasSize) + index) - } - // Intercept is required to be added into margin. - if (dataMatrix.size + 1 == dataWithBiasSize) { - margin += weightsArray((i * dataWithBiasSize) + dataMatrix.size) - } - margin - } - /** - * Find the one with maximum margins. If the maxMargin is negative, then the prediction - * result will be the first class. + * Compute and find the one with maximum margins. If the maxMargin is negative, then the + * prediction result will be the first class. * * PS, if you want to compute the probabilities for each outcome instead of the outcome * with maximum probability, remember to subtract the maxMargin from margins if maxMargin @@ -145,13 +135,20 @@ class LogisticRegressionModel ( */ var bestClass = 0 var maxMargin = 0.0 - var i = 0 - while(i < margins.size) { - if (margins(i) > maxMargin) { - maxMargin = margins(i) + val withBias = dataMatrix.size + 1 == dataWithBiasSize + (0 until numClasses - 1).foreach { i => + var margin = 0.0 + dataMatrix.foreachActive { (index, value) => + if (value != 0.0) margin += value * weightsArray((i * dataWithBiasSize) + index) + } + // Intercept is required to be added into margin. + if (withBias) { + margin += weightsArray((i * dataWithBiasSize) + dataMatrix.size) + } + if (margin > maxMargin) { + maxMargin = margin bestClass = i + 1 } - i += 1 } bestClass.toDouble } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala index d60e82c410979..c9b3ff0172e2e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala @@ -21,9 +21,12 @@ import java.lang.{Iterable => JIterable} import scala.collection.JavaConverters._ -import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, argmax => brzArgmax, sum => brzSum} +import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, argmax => brzArgmax, sum => brzSum, Axis} +import breeze.numerics.{exp => brzExp, log => brzLog} + import org.json4s.JsonDSL._ import org.json4s.jackson.JsonMethods._ +import org.json4s.{DefaultFormats, JValue} import org.apache.spark.{Logging, SparkContext, SparkException} import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector} @@ -32,6 +35,7 @@ import org.apache.spark.mllib.util.{Loader, Saveable} import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, SQLContext} + /** * Model for Naive Bayes Classifiers. * @@ -39,11 +43,17 @@ import org.apache.spark.sql.{DataFrame, SQLContext} * @param pi log of class priors, whose dimension is C, number of labels * @param theta log of class conditional probabilities, whose dimension is C-by-D, * where D is number of features + * @param modelType The type of NB model to fit can be "Multinomial" or "Bernoulli" */ class NaiveBayesModel private[mllib] ( val labels: Array[Double], val pi: Array[Double], - val theta: Array[Array[Double]]) extends ClassificationModel with Serializable with Saveable { + val theta: Array[Array[Double]], + val modelType: String) + extends ClassificationModel with Serializable with Saveable { + + private[mllib] def this(labels: Array[Double], pi: Array[Double], theta: Array[Array[Double]]) = + this(labels, pi, theta, "Multinomial") /** A Java-friendly constructor that takes three Iterable parameters. */ private[mllib] def this( @@ -53,19 +63,19 @@ class NaiveBayesModel private[mllib] ( this(labels.asScala.toArray, pi.asScala.toArray, theta.asScala.toArray.map(_.asScala.toArray)) private val brzPi = new BDV[Double](pi) - private val brzTheta = new BDM[Double](theta.length, theta(0).length) - - { - // Need to put an extra pair of braces to prevent Scala treating `i` as a member. - var i = 0 - while (i < theta.length) { - var j = 0 - while (j < theta(i).length) { - brzTheta(i, j) = theta(i)(j) - j += 1 - } - i += 1 - } + private val brzTheta = new BDM(theta(0).length, theta.length, theta.flatten).t + + // Bernoulli scoring requires log(condprob) if 1, log(1-condprob) if 0. + // This precomputes log(1.0 - exp(theta)) and its sum which are used for the linear algebra + // application of this condition (in predict function). + private val (brzNegTheta, brzNegThetaSum) = modelType match { + case "Multinomial" => (None, None) + case "Bernoulli" => + val negTheta = brzLog((brzExp(brzTheta.copy) :*= (-1.0)) :+= 1.0) // log(1.0 - exp(x)) + (Option(negTheta), Option(brzSum(negTheta, Axis._1))) + case _ => + // This should never happen. + throw new UnknownError(s"NaiveBayesModel was created with an unknown ModelType: $modelType") } override def predict(testData: RDD[Vector]): RDD[Double] = { @@ -77,22 +87,78 @@ class NaiveBayesModel private[mllib] ( } override def predict(testData: Vector): Double = { - labels(brzArgmax(brzPi + brzTheta * testData.toBreeze)) + modelType match { + case "Multinomial" => + labels (brzArgmax (brzPi + brzTheta * testData.toBreeze) ) + case "Bernoulli" => + labels (brzArgmax (brzPi + + (brzTheta - brzNegTheta.get) * testData.toBreeze + brzNegThetaSum.get)) + case _ => + // This should never happen. + throw new UnknownError(s"NaiveBayesModel was created with an unknown ModelType: $modelType") + } } override def save(sc: SparkContext, path: String): Unit = { - val data = NaiveBayesModel.SaveLoadV1_0.Data(labels, pi, theta) - NaiveBayesModel.SaveLoadV1_0.save(sc, path, data) + val data = NaiveBayesModel.SaveLoadV2_0.Data(labels, pi, theta, modelType) + NaiveBayesModel.SaveLoadV2_0.save(sc, path, data) } - override protected def formatVersion: String = "1.0" + override protected def formatVersion: String = "2.0" } object NaiveBayesModel extends Loader[NaiveBayesModel] { import org.apache.spark.mllib.util.Loader._ - private object SaveLoadV1_0 { + private[mllib] object SaveLoadV2_0 { + + def thisFormatVersion: String = "2.0" + + /** Hard-code class name string in case it changes in the future */ + def thisClassName: String = "org.apache.spark.mllib.classification.NaiveBayesModel" + + /** Model data for model import/export */ + case class Data( + labels: Array[Double], + pi: Array[Double], + theta: Array[Array[Double]], + modelType: String) + + def save(sc: SparkContext, path: String, data: Data): Unit = { + val sqlContext = new SQLContext(sc) + import sqlContext.implicits._ + + // Create JSON metadata. + val metadata = compact(render( + ("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~ + ("numFeatures" -> data.theta(0).length) ~ ("numClasses" -> data.pi.length))) + sc.parallelize(Seq(metadata), 1).saveAsTextFile(metadataPath(path)) + + // Create Parquet data. + val dataRDD: DataFrame = sc.parallelize(Seq(data), 1).toDF() + dataRDD.saveAsParquetFile(dataPath(path)) + } + + def load(sc: SparkContext, path: String): NaiveBayesModel = { + val sqlContext = new SQLContext(sc) + // Load Parquet data. + val dataRDD = sqlContext.parquetFile(dataPath(path)) + // Check schema explicitly since erasure makes it hard to use match-case for checking. + checkSchema[Data](dataRDD.schema) + val dataArray = dataRDD.select("labels", "pi", "theta", "modelType").take(1) + assert(dataArray.size == 1, s"Unable to load NaiveBayesModel data from: ${dataPath(path)}") + val data = dataArray(0) + val labels = data.getAs[Seq[Double]](0).toArray + val pi = data.getAs[Seq[Double]](1).toArray + val theta = data.getAs[Seq[Seq[Double]]](2).map(_.toArray).toArray + val modelType = data.getString(3) + new NaiveBayesModel(labels, pi, theta, modelType) + } + + } + + private[mllib] object SaveLoadV1_0 { def thisFormatVersion: String = "1.0" @@ -100,7 +166,10 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] { def thisClassName: String = "org.apache.spark.mllib.classification.NaiveBayesModel" /** Model data for model import/export */ - case class Data(labels: Array[Double], pi: Array[Double], theta: Array[Array[Double]]) + case class Data( + labels: Array[Double], + pi: Array[Double], + theta: Array[Array[Double]]) def save(sc: SparkContext, path: String, data: Data): Unit = { val sqlContext = new SQLContext(sc) @@ -136,26 +205,32 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] { override def load(sc: SparkContext, path: String): NaiveBayesModel = { val (loadedClassName, version, metadata) = loadMetadata(sc, path) val classNameV1_0 = SaveLoadV1_0.thisClassName - (loadedClassName, version) match { + val classNameV2_0 = SaveLoadV2_0.thisClassName + val (model, numFeatures, numClasses) = (loadedClassName, version) match { case (className, "1.0") if className == classNameV1_0 => val (numFeatures, numClasses) = ClassificationModel.getNumFeaturesClasses(metadata) val model = SaveLoadV1_0.load(sc, path) - assert(model.pi.size == numClasses, - s"NaiveBayesModel.load expected $numClasses classes," + - s" but class priors vector pi had ${model.pi.size} elements") - assert(model.theta.size == numClasses, - s"NaiveBayesModel.load expected $numClasses classes," + - s" but class conditionals array theta had ${model.theta.size} elements") - assert(model.theta.forall(_.size == numFeatures), - s"NaiveBayesModel.load expected $numFeatures features," + - s" but class conditionals array theta had elements of size:" + - s" ${model.theta.map(_.size).mkString(",")}") - model + (model, numFeatures, numClasses) + case (className, "2.0") if className == classNameV2_0 => + val (numFeatures, numClasses) = ClassificationModel.getNumFeaturesClasses(metadata) + val model = SaveLoadV2_0.load(sc, path) + (model, numFeatures, numClasses) case _ => throw new Exception( s"NaiveBayesModel.load did not recognize model with (className, format version):" + s"($loadedClassName, $version). Supported:\n" + s" ($classNameV1_0, 1.0)") } + assert(model.pi.size == numClasses, + s"NaiveBayesModel.load expected $numClasses classes," + + s" but class priors vector pi had ${model.pi.size} elements") + assert(model.theta.size == numClasses, + s"NaiveBayesModel.load expected $numClasses classes," + + s" but class conditionals array theta had ${model.theta.size} elements") + assert(model.theta.forall(_.size == numFeatures), + s"NaiveBayesModel.load expected $numFeatures features," + + s" but class conditionals array theta had elements of size:" + + s" ${model.theta.map(_.size).mkString(",")}") + model } } @@ -167,9 +242,14 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] { * document classification. By making every vector a 0-1 vector, it can also be used as * Bernoulli NB ([[http://tinyurl.com/p7c96j6]]). The input feature values must be nonnegative. */ -class NaiveBayes private (private var lambda: Double) extends Serializable with Logging { - def this() = this(1.0) +class NaiveBayes private ( + private var lambda: Double, + private var modelType: String) extends Serializable with Logging { + + def this(lambda: Double) = this(lambda, "Multinomial") + + def this() = this(1.0, "Multinomial") /** Set the smoothing parameter. Default: 1.0. */ def setLambda(lambda: Double): NaiveBayes = { @@ -177,9 +257,24 @@ class NaiveBayes private (private var lambda: Double) extends Serializable with this } - /** Get the smoothing parameter. Default: 1.0. */ + /** Get the smoothing parameter. */ def getLambda: Double = lambda + /** + * Set the model type using a string (case-sensitive). + * Supported options: "Multinomial" and "Bernoulli". + * (default: Multinomial) + */ + def setModelType(modelType:String): NaiveBayes = { + require(NaiveBayes.supportedModelTypes.contains(modelType), + s"NaiveBayes was created with an unknown ModelType: $modelType") + this.modelType = modelType + this + } + + /** Get the model type. */ + def getModelType: String = this.modelType + /** * Run the algorithm with the configured parameters on an input RDD of LabeledPoint entries. * @@ -213,21 +308,30 @@ class NaiveBayes private (private var lambda: Double) extends Serializable with mergeCombiners = (c1: (Long, BDV[Double]), c2: (Long, BDV[Double])) => (c1._1 + c2._1, c1._2 += c2._2) ).collect() + val numLabels = aggregated.length var numDocuments = 0L aggregated.foreach { case (_, (n, _)) => numDocuments += n } val numFeatures = aggregated.head match { case (_, (_, v)) => v.size } + val labels = new Array[Double](numLabels) val pi = new Array[Double](numLabels) val theta = Array.fill(numLabels)(new Array[Double](numFeatures)) + val piLogDenom = math.log(numDocuments + numLabels * lambda) var i = 0 aggregated.foreach { case (label, (n, sumTermFreqs)) => labels(i) = label - val thetaLogDenom = math.log(brzSum(sumTermFreqs) + numFeatures * lambda) pi(i) = math.log(n + lambda) - piLogDenom + val thetaLogDenom = modelType match { + case "Multinomial" => math.log(brzSum(sumTermFreqs) + numFeatures * lambda) + case "Bernoulli" => math.log(n + 2.0 * lambda) + case _ => + // This should never happen. + throw new UnknownError(s"NaiveBayes was created with an unknown ModelType: $modelType") + } var j = 0 while (j < numFeatures) { theta(i)(j) = math.log(sumTermFreqs(j) + lambda) - thetaLogDenom @@ -236,7 +340,7 @@ class NaiveBayes private (private var lambda: Double) extends Serializable with i += 1 } - new NaiveBayesModel(labels, pi, theta) + new NaiveBayesModel(labels, pi, theta, modelType) } } @@ -244,13 +348,16 @@ class NaiveBayes private (private var lambda: Double) extends Serializable with * Top-level methods for calling naive Bayes. */ object NaiveBayes { + + /* Set of modelTypes that NaiveBayes supports */ + private[mllib] val supportedModelTypes = Set("Multinomial", "Bernoulli") + /** * Trains a Naive Bayes model given an RDD of `(label, features)` pairs. * - * This is the Multinomial NB ([[http://tinyurl.com/lsdw6p]]) which can handle all kinds of - * discrete data. For example, by converting documents into TF-IDF vectors, it can be used for - * document classification. By making every vector a 0-1 vector, it can also be used as - * Bernoulli NB ([[http://tinyurl.com/p7c96j6]]). + * This is the default Multinomial NB ([[http://tinyurl.com/lsdw6p]]) which can handle all + * kinds of discrete data. For example, by converting documents into TF-IDF vectors, it + * can be used for document classification. * * This version of the method uses a default smoothing parameter of 1.0. * @@ -264,16 +371,40 @@ object NaiveBayes { /** * Trains a Naive Bayes model given an RDD of `(label, features)` pairs. * - * This is the Multinomial NB ([[http://tinyurl.com/lsdw6p]]) which can handle all kinds of - * discrete data. For example, by converting documents into TF-IDF vectors, it can be used for - * document classification. By making every vector a 0-1 vector, it can also be used as - * Bernoulli NB ([[http://tinyurl.com/p7c96j6]]). + * This is the default Multinomial NB ([[http://tinyurl.com/lsdw6p]]) which can handle all + * kinds of discrete data. For example, by converting documents into TF-IDF vectors, it + * can be used for document classification. * * @param input RDD of `(label, array of features)` pairs. Every vector should be a frequency * vector or a count vector. * @param lambda The smoothing parameter */ def train(input: RDD[LabeledPoint], lambda: Double): NaiveBayesModel = { - new NaiveBayes(lambda).run(input) + new NaiveBayes(lambda, "Multinomial").run(input) + } + + /** + * Trains a Naive Bayes model given an RDD of `(label, features)` pairs. + * + * The model type can be set to either Multinomial NB ([[http://tinyurl.com/lsdw6p]]) + * or Bernoulli NB ([[http://tinyurl.com/p7c96j6]]). The Multinomial NB can handle + * discrete count data and can be called by setting the model type to "multinomial". + * For example, it can be used with word counts or TF_IDF vectors of documents. + * The Bernoulli model fits presence or absence (0-1) counts. By making every vector a + * 0-1 vector and setting the model type to "bernoulli", the fits and predicts as + * Bernoulli NB. + * + * @param input RDD of `(label, array of features)` pairs. Every vector should be a frequency + * vector or a count vector. + * @param lambda The smoothing parameter + * + * @param modelType The type of NB model to fit from the enumeration NaiveBayesModels, can be + * multinomial or bernoulli + */ + def train(input: RDD[LabeledPoint], lambda: Double, modelType: String): NaiveBayesModel = { + require(supportedModelTypes.contains(modelType), + s"NaiveBayes was created with an unknown ModelType: $modelType") + new NaiveBayes(lambda, modelType).run(input) } + } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionWithSGD.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionWithSGD.scala index b89f38cf5aba4..7d33df3221fbf 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionWithSGD.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionWithSGD.scala @@ -63,6 +63,8 @@ class StreamingLogisticRegressionWithSGD private[mllib] ( protected val algorithm = new LogisticRegressionWithSGD( stepSize, numIterations, regParam, miniBatchFraction) + protected var model: Option[LogisticRegressionModel] = None + /** Set the step size for gradient descent. Default: 0.1. */ def setStepSize(stepSize: Double): this.type = { this.algorithm.optimizer.setStepSize(stepSize) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala index 5e17c8da61134..9d63a08e211bc 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala @@ -19,7 +19,7 @@ package org.apache.spark.mllib.clustering import java.util.Random -import breeze.linalg.{DenseVector => BDV, normalize, axpy => brzAxpy} +import breeze.linalg.{DenseVector => BDV, normalize} import org.apache.spark.Logging import org.apache.spark.annotation.Experimental diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala index 59a79e5c6a4ac..b2d9053f70145 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala @@ -25,14 +25,21 @@ import scala.collection.mutable.ArrayBuilder import com.github.fommil.netlib.BLAS.{getInstance => blas} +import org.json4s.DefaultFormats +import org.json4s.JsonDSL._ +import org.json4s.jackson.JsonMethods._ + import org.apache.spark.Logging +import org.apache.spark.SparkContext import org.apache.spark.SparkContext._ import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.JavaRDD import org.apache.spark.mllib.linalg.{Vector, Vectors} +import org.apache.spark.mllib.util.{Loader, Saveable} import org.apache.spark.rdd._ import org.apache.spark.util.Utils import org.apache.spark.util.random.XORShiftRandom +import org.apache.spark.sql.{SQLContext, Row} /** * Entry in vocabulary @@ -422,7 +429,7 @@ class Word2Vec extends Serializable with Logging { */ @Experimental class Word2VecModel private[mllib] ( - private val model: Map[String, Array[Float]]) extends Serializable { + private val model: Map[String, Array[Float]]) extends Serializable with Saveable { private def cosineSimilarity(v1: Array[Float], v2: Array[Float]): Double = { require(v1.length == v2.length, "Vectors should have the same length") @@ -432,7 +439,13 @@ class Word2VecModel private[mllib] ( if (norm1 == 0 || norm2 == 0) return 0.0 blas.sdot(n, v1, 1, v2,1) / norm1 / norm2 } - + + override protected def formatVersion = "1.0" + + def save(sc: SparkContext, path: String): Unit = { + Word2VecModel.SaveLoadV1_0.save(sc, path, model) + } + /** * Transforms a word to its vector representation * @param word a word @@ -475,7 +488,7 @@ class Word2VecModel private[mllib] ( .tail .toArray } - + /** * Returns a map of words to their vector representations. */ @@ -483,3 +496,71 @@ class Word2VecModel private[mllib] ( model } } + +@Experimental +object Word2VecModel extends Loader[Word2VecModel] { + + private object SaveLoadV1_0 { + + val formatVersionV1_0 = "1.0" + + val classNameV1_0 = "org.apache.spark.mllib.feature.Word2VecModel" + + case class Data(word: String, vector: Array[Float]) + + def load(sc: SparkContext, path: String): Word2VecModel = { + val dataPath = Loader.dataPath(path) + val sqlContext = new SQLContext(sc) + val dataFrame = sqlContext.parquetFile(dataPath) + + val dataArray = dataFrame.select("word", "vector").collect() + + // Check schema explicitly since erasure makes it hard to use match-case for checking. + Loader.checkSchema[Data](dataFrame.schema) + + val word2VecMap = dataArray.map(i => (i.getString(0), i.getSeq[Float](1).toArray)).toMap + new Word2VecModel(word2VecMap) + } + + def save(sc: SparkContext, path: String, model: Map[String, Array[Float]]): Unit = { + + val sqlContext = new SQLContext(sc) + import sqlContext.implicits._ + + val vectorSize = model.values.head.size + val numWords = model.size + val metadata = compact(render + (("class" -> classNameV1_0) ~ ("version" -> formatVersionV1_0) ~ + ("vectorSize" -> vectorSize) ~ ("numWords" -> numWords))) + sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path)) + + val dataArray = model.toSeq.map { case (w, v) => Data(w, v) } + sc.parallelize(dataArray.toSeq, 1).toDF().saveAsParquetFile(Loader.dataPath(path)) + } + } + + override def load(sc: SparkContext, path: String): Word2VecModel = { + + val (loadedClassName, loadedVersion, metadata) = Loader.loadMetadata(sc, path) + implicit val formats = DefaultFormats + val expectedVectorSize = (metadata \ "vectorSize").extract[Int] + val expectedNumWords = (metadata \ "numWords").extract[Int] + val classNameV1_0 = SaveLoadV1_0.classNameV1_0 + (loadedClassName, loadedVersion) match { + case (classNameV1_0, "1.0") => + val model = SaveLoadV1_0.load(sc, path) + val vectorSize = model.getVectors.values.head.size + val numWords = model.getVectors.size + require(expectedVectorSize == vectorSize, + s"Word2VecModel requires each word to be mapped to a vector of size " + + s"$expectedVectorSize, got vector of size $vectorSize") + require(expectedNumWords == numWords, + s"Word2VecModel requires $expectedNumWords words, but got $numWords") + model + case _ => throw new Exception( + s"Word2VecModel.load did not recognize model with (className, format version):" + + s"($loadedClassName, $loadedVersion). Supported:\n" + + s" ($classNameV1_0, 1.0)") + } + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala index 961111507f2c2..9a89a6f3a515f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala @@ -531,7 +531,6 @@ class RowMatrix( val rand = new XORShiftRandom(indx) val scaled = new Array[Double](p.size) iter.flatMap { row => - val buf = new ListBuffer[((Int, Int), Double)]() row match { case SparseVector(size, indices, values) => val nnz = indices.size @@ -540,8 +539,9 @@ class RowMatrix( scaled(k) = values(k) / q(indices(k)) k += 1 } - k = 0 - while (k < nnz) { + + Iterator.tabulate (nnz) { k => + val buf = new ListBuffer[((Int, Int), Double)]() val i = indices(k) val iVal = scaled(k) if (iVal != 0 && rand.nextDouble() < p(i)) { @@ -555,8 +555,8 @@ class RowMatrix( l += 1 } } - k += 1 - } + buf + }.flatten case DenseVector(values) => val n = values.size var i = 0 @@ -564,8 +564,8 @@ class RowMatrix( scaled(i) = values(i) / q(i) i += 1 } - i = 0 - while (i < n) { + Iterator.tabulate (n) { i => + val buf = new ListBuffer[((Int, Int), Double)]() val iVal = scaled(i) if (iVal != 0 && rand.nextDouble() < p(i)) { var j = i + 1 @@ -577,10 +577,9 @@ class RowMatrix( j += 1 } } - i += 1 - } + buf + }.flatten } - buf } }.reduceByKey(_ + _).map { case ((i, j), sim) => MatrixEntry(i.toLong, j.toLong, sim) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala index ce95c063db970..cea8f3f47307b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala @@ -60,7 +60,7 @@ abstract class StreamingLinearAlgorithm[ A <: GeneralizedLinearAlgorithm[M]] extends Logging { /** The model to be updated and used for prediction. */ - protected var model: Option[M] = None + protected var model: Option[M] /** The algorithm to use for updating. */ protected val algorithm: A @@ -114,7 +114,7 @@ abstract class StreamingLinearAlgorithm[ if (model.isEmpty) { throw new IllegalArgumentException("Model must be initialized before starting prediction.") } - data.map(model.get.predict) + data.map{x => model.get.predict(x)} } /** Java-friendly version of `predictOn`. */ @@ -132,7 +132,7 @@ abstract class StreamingLinearAlgorithm[ if (model.isEmpty) { throw new IllegalArgumentException("Model must be initialized before starting prediction") } - data.mapValues(model.get.predict) + data.mapValues{x => model.get.predict(x)} } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionWithSGD.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionWithSGD.scala index e5e6301127a28..a49153bf73c0d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionWithSGD.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionWithSGD.scala @@ -59,6 +59,8 @@ class StreamingLinearRegressionWithSGD private[mllib] ( val algorithm = new LinearRegressionWithSGD(stepSize, numIterations, miniBatchFraction) + protected var model: Option[LinearRegressionModel] = None + /** Set the step size for gradient descent. Default: 0.1. */ def setStepSize(stepSize: Double): this.type = { this.algorithm.optimizer.setStepSize(stepSize) diff --git a/mllib/src/test/java/org/apache/spark/mllib/classification/JavaNaiveBayesSuite.java b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaNaiveBayesSuite.java index 1c90522a0714a..71fb7f13c39c2 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/classification/JavaNaiveBayesSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaNaiveBayesSuite.java @@ -17,20 +17,22 @@ package org.apache.spark.mllib.classification; +import java.io.Serializable; +import java.util.Arrays; +import java.util.List; + +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.api.java.function.Function; import org.apache.spark.mllib.linalg.Vector; import org.apache.spark.mllib.linalg.Vectors; import org.apache.spark.mllib.regression.LabeledPoint; -import org.junit.After; -import org.junit.Assert; -import org.junit.Before; -import org.junit.Test; -import java.io.Serializable; -import java.util.Arrays; -import java.util.List; public class JavaNaiveBayesSuite implements Serializable { private transient JavaSparkContext sc; @@ -102,4 +104,11 @@ public Vector call(LabeledPoint v) throws Exception { // Should be able to get the first prediction. predictions.first(); } + + @Test + public void testModelTypeSetters() { + NaiveBayes nb = new NaiveBayes() + .setModelType("Bernoulli") + .setModelType("Multinomial"); + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala index b3d1bfcfbee0f..35d8c2e16c6cd 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala @@ -46,6 +46,7 @@ class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext { assert(lr.getPredictionCol == "prediction") assert(lr.getRawPredictionCol == "rawPrediction") assert(lr.getProbabilityCol == "probability") + assert(lr.getFitIntercept == true) val model = lr.fit(dataset) model.transform(dataset) .select("label", "probability", "prediction", "rawPrediction") @@ -55,6 +56,14 @@ class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext { assert(model.getPredictionCol == "prediction") assert(model.getRawPredictionCol == "rawPrediction") assert(model.getProbabilityCol == "probability") + assert(model.intercept !== 0.0) + } + + test("logistic regression doesn't fit intercept when fitIntercept is off") { + val lr = new LogisticRegression + lr.setFitIntercept(false) + val model = lr.fit(dataset) + assert(model.intercept === 0.0) } test("logistic regression with setters") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/NormalizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/NormalizerSuite.scala new file mode 100644 index 0000000000000..a18c335952b96 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/NormalizerSuite.scala @@ -0,0 +1,109 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import org.scalatest.FunSuite + +import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors} +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.mllib.util.TestingUtils._ +import org.apache.spark.sql.{DataFrame, Row, SQLContext} + +private case class DataSet(features: Vector) + +class NormalizerSuite extends FunSuite with MLlibTestSparkContext { + + @transient var data: Array[Vector] = _ + @transient var dataFrame: DataFrame = _ + @transient var normalizer: Normalizer = _ + @transient var l1Normalized: Array[Vector] = _ + @transient var l2Normalized: Array[Vector] = _ + + override def beforeAll(): Unit = { + super.beforeAll() + + data = Array( + Vectors.sparse(3, Seq((0, -2.0), (1, 2.3))), + Vectors.dense(0.0, 0.0, 0.0), + Vectors.dense(0.6, -1.1, -3.0), + Vectors.sparse(3, Seq((1, 0.91), (2, 3.2))), + Vectors.sparse(3, Seq((0, 5.7), (1, 0.72), (2, 2.7))), + Vectors.sparse(3, Seq()) + ) + l1Normalized = Array( + Vectors.sparse(3, Seq((0, -0.465116279), (1, 0.53488372))), + Vectors.dense(0.0, 0.0, 0.0), + Vectors.dense(0.12765957, -0.23404255, -0.63829787), + Vectors.sparse(3, Seq((1, 0.22141119), (2, 0.7785888))), + Vectors.dense(0.625, 0.07894737, 0.29605263), + Vectors.sparse(3, Seq()) + ) + l2Normalized = Array( + Vectors.sparse(3, Seq((0, -0.65617871), (1, 0.75460552))), + Vectors.dense(0.0, 0.0, 0.0), + Vectors.dense(0.184549876, -0.3383414, -0.922749378), + Vectors.sparse(3, Seq((1, 0.27352993), (2, 0.96186349))), + Vectors.dense(0.897906166, 0.113419726, 0.42532397), + Vectors.sparse(3, Seq()) + ) + + val sqlContext = new SQLContext(sc) + dataFrame = sqlContext.createDataFrame(sc.parallelize(data, 2).map(DataSet)) + normalizer = new Normalizer() + .setInputCol("features") + .setOutputCol("normalized_features") + } + + def collectResult(result: DataFrame): Array[Vector] = { + result.select("normalized_features").collect().map { + case Row(features: Vector) => features + } + } + + def assertTypeOfVector(lhs: Array[Vector], rhs: Array[Vector]): Unit = { + assert((lhs, rhs).zipped.forall { + case (v1: DenseVector, v2: DenseVector) => true + case (v1: SparseVector, v2: SparseVector) => true + case _ => false + }, "The vector type should be preserved after normalization.") + } + + def assertValues(lhs: Array[Vector], rhs: Array[Vector]): Unit = { + assert((lhs, rhs).zipped.forall { (vector1, vector2) => + vector1 ~== vector2 absTol 1E-5 + }, "The vector value is not correct after normalization.") + } + + test("Normalization with default parameter") { + val result = collectResult(normalizer.transform(dataFrame)) + + assertTypeOfVector(data, result) + + assertValues(result, l2Normalized) + } + + test("Normalization with setter") { + normalizer.setP(1) + + val result = collectResult(normalizer.transform(dataFrame)) + + assertTypeOfVector(data, result) + + assertValues(result, l1Normalized) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala index 0bb06e9e8ac9c..fc7349330cf86 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala @@ -22,6 +22,7 @@ import java.util.Random import scala.collection.mutable import scala.collection.mutable.ArrayBuffer +import scala.language.existentials import com.github.fommil.netlib.BLAS.{getInstance => blas} import org.scalatest.FunSuite @@ -68,39 +69,42 @@ class ALSSuite extends FunSuite with MLlibTestSparkContext with Logging { } } - test("normal equation construction with explict feedback") { + test("normal equation construction") { val k = 2 val ne0 = new NormalEquation(k) - .add(Array(1.0f, 2.0f), 3.0f) - .add(Array(4.0f, 5.0f), 6.0f) + .add(Array(1.0f, 2.0f), 3.0) + .add(Array(4.0f, 5.0f), 6.0, 2.0) // weighted assert(ne0.k === k) assert(ne0.triK === k * (k + 1) / 2) - assert(ne0.n === 2) // NumPy code that computes the expected values: // A = np.matrix("1 2; 4 5") // b = np.matrix("3; 6") - // ata = A.transpose() * A - // atb = A.transpose() * b - assert(Vectors.dense(ne0.ata) ~== Vectors.dense(17.0, 22.0, 29.0) relTol 1e-8) - assert(Vectors.dense(ne0.atb) ~== Vectors.dense(27.0, 36.0) relTol 1e-8) + // C = np.matrix(np.diag([1, 2])) + // ata = A.transpose() * C * A + // atb = A.transpose() * C * b + assert(Vectors.dense(ne0.ata) ~== Vectors.dense(33.0, 42.0, 54.0) relTol 1e-8) + assert(Vectors.dense(ne0.atb) ~== Vectors.dense(51.0, 66.0) relTol 1e-8) val ne1 = new NormalEquation(2) - .add(Array(7.0f, 8.0f), 9.0f) + .add(Array(7.0f, 8.0f), 9.0) ne0.merge(ne1) - assert(ne0.n === 3) // NumPy code that computes the expected values: // A = np.matrix("1 2; 4 5; 7 8") // b = np.matrix("3; 6; 9") - // ata = A.transpose() * A - // atb = A.transpose() * b - assert(Vectors.dense(ne0.ata) ~== Vectors.dense(66.0, 78.0, 93.0) relTol 1e-8) - assert(Vectors.dense(ne0.atb) ~== Vectors.dense(90.0, 108.0) relTol 1e-8) + // C = np.matrix(np.diag([1, 2, 1])) + // ata = A.transpose() * C * A + // atb = A.transpose() * C * b + assert(Vectors.dense(ne0.ata) ~== Vectors.dense(82.0, 98.0, 118.0) relTol 1e-8) + assert(Vectors.dense(ne0.atb) ~== Vectors.dense(114.0, 138.0) relTol 1e-8) intercept[IllegalArgumentException] { - ne0.add(Array(1.0f), 2.0f) + ne0.add(Array(1.0f), 2.0) } intercept[IllegalArgumentException] { - ne0.add(Array(1.0f, 2.0f, 3.0f), 4.0f) + ne0.add(Array(1.0f, 2.0f, 3.0f), 4.0) + } + intercept[IllegalArgumentException] { + ne0.add(Array(1.0f, 2.0f), 0.0, -1.0) } intercept[IllegalArgumentException] { val ne2 = new NormalEquation(3) @@ -108,41 +112,16 @@ class ALSSuite extends FunSuite with MLlibTestSparkContext with Logging { } ne0.reset() - assert(ne0.n === 0) assert(ne0.ata.forall(_ == 0.0)) assert(ne0.atb.forall(_ == 0.0)) } - test("normal equation construction with implicit feedback") { - val k = 2 - val alpha = 0.5 - val ne0 = new NormalEquation(k) - .addImplicit(Array(-5.0f, -4.0f), -3.0f, alpha) - .addImplicit(Array(-2.0f, -1.0f), 0.0f, alpha) - .addImplicit(Array(1.0f, 2.0f), 3.0f, alpha) - assert(ne0.k === k) - assert(ne0.triK === k * (k + 1) / 2) - assert(ne0.n === 0) // addImplicit doesn't increase the count. - // NumPy code that computes the expected values: - // alpha = 0.5 - // A = np.matrix("-5 -4; -2 -1; 1 2") - // b = np.matrix("-3; 0; 3") - // b1 = b > 0 - // c = 1.0 + alpha * np.abs(b) - // C = np.diag(c.A1) - // I = np.eye(3) - // ata = A.transpose() * (C - I) * A - // atb = A.transpose() * C * b1 - assert(Vectors.dense(ne0.ata) ~== Vectors.dense(39.0, 33.0, 30.0) relTol 1e-8) - assert(Vectors.dense(ne0.atb) ~== Vectors.dense(2.5, 5.0) relTol 1e-8) - } - test("CholeskySolver") { val k = 2 val ne0 = new NormalEquation(k) - .add(Array(1.0f, 2.0f), 4.0f) - .add(Array(1.0f, 3.0f), 9.0f) - .add(Array(1.0f, 4.0f), 16.0f) + .add(Array(1.0f, 2.0f), 4.0) + .add(Array(1.0f, 3.0f), 9.0) + .add(Array(1.0f, 4.0f), 16.0) val ne1 = new NormalEquation(k) .merge(ne0) @@ -154,13 +133,12 @@ class ALSSuite extends FunSuite with MLlibTestSparkContext with Logging { // x0 = np.linalg.lstsq(A, b)[0] assert(Vectors.dense(x0) ~== Vectors.dense(-8.333333, 6.0) relTol 1e-6) - assert(ne0.n === 0) assert(ne0.ata.forall(_ == 0.0)) assert(ne0.atb.forall(_ == 0.0)) - val x1 = chol.solve(ne1, 0.5).map(_.toDouble) + val x1 = chol.solve(ne1, 1.5).map(_.toDouble) // NumPy code that computes the expected solution, where lambda is scaled by n: - // x0 = np.linalg.solve(A.transpose() * A + 0.5 * 3 * np.eye(2), A.transpose() * b) + // x0 = np.linalg.solve(A.transpose() * A + 1.5 * np.eye(2), A.transpose() * b) assert(Vectors.dense(x1) ~== Vectors.dense(-0.1155556, 3.28) relTol 1e-6) } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala index 5a27c7d2309c5..f9fe3e006ccb8 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala @@ -19,6 +19,9 @@ package org.apache.spark.mllib.classification import scala.util.Random +import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, argmax => brzArgmax, sum => brzSum, Axis} +import breeze.stats.distributions.{Multinomial => BrzMultinomial} + import org.scalatest.FunSuite import org.apache.spark.SparkException @@ -41,37 +44,48 @@ object NaiveBayesSuite { // Generate input of the form Y = (theta * x).argmax() def generateNaiveBayesInput( - pi: Array[Double], // 1XC - theta: Array[Array[Double]], // CXD - nPoints: Int, - seed: Int): Seq[LabeledPoint] = { + pi: Array[Double], // 1XC + theta: Array[Array[Double]], // CXD + nPoints: Int, + seed: Int, + modelType: String = "Multinomial", + sample: Int = 10): Seq[LabeledPoint] = { val D = theta(0).length val rnd = new Random(seed) - val _pi = pi.map(math.pow(math.E, _)) val _theta = theta.map(row => row.map(math.pow(math.E, _))) for (i <- 0 until nPoints) yield { val y = calcLabel(rnd.nextDouble(), _pi) - val xi = Array.tabulate[Double](D) { j => - if (rnd.nextDouble() < _theta(y)(j)) 1 else 0 + val xi = modelType match { + case "Bernoulli" => Array.tabulate[Double] (D) { j => + if (rnd.nextDouble () < _theta(y)(j) ) 1 else 0 + } + case "Multinomial" => + val mult = BrzMultinomial(BDV(_theta(y))) + val emptyMap = (0 until D).map(x => (x, 0.0)).toMap + val counts = emptyMap ++ mult.sample(sample).groupBy(x => x).map { + case (index, reps) => (index, reps.size.toDouble) + } + counts.toArray.sortBy(_._1).map(_._2) + case _ => + // This should never happen. + throw new UnknownError(s"NaiveBayesSuite found unknown ModelType: $modelType") } LabeledPoint(y, Vectors.dense(xi)) } } - private val smallPi = Array(0.5, 0.3, 0.2).map(math.log) + /** Bernoulli NaiveBayes with binary labels, 3 features */ + private val binaryBernoulliModel = new NaiveBayesModel(labels = Array(0.0, 1.0), + pi = Array(0.2, 0.8), theta = Array(Array(0.1, 0.3, 0.6), Array(0.2, 0.4, 0.4)), + "Bernoulli") - private val smallTheta = Array( - Array(0.91, 0.03, 0.03, 0.03), // label 0 - Array(0.03, 0.91, 0.03, 0.03), // label 1 - Array(0.03, 0.03, 0.91, 0.03) // label 2 - ).map(_.map(math.log)) - - /** Binary labels, 3 features */ - private val binaryModel = new NaiveBayesModel(labels = Array(0.0, 1.0), pi = Array(0.2, 0.8), - theta = Array(Array(0.1, 0.3, 0.6), Array(0.2, 0.4, 0.4))) + /** Multinomial NaiveBayes with binary labels, 3 features */ + private val binaryMultinomialModel = new NaiveBayesModel(labels = Array(0.0, 1.0), + pi = Array(0.2, 0.8), theta = Array(Array(0.1, 0.3, 0.6), Array(0.2, 0.4, 0.4)), + "Multinomial") } class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext { @@ -85,6 +99,24 @@ class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext { assert(numOfPredictions < input.length / 5) } + def validateModelFit( + piData: Array[Double], + thetaData: Array[Array[Double]], + model: NaiveBayesModel) = { + def closeFit(d1: Double, d2: Double, precision: Double): Boolean = { + (d1 - d2).abs <= precision + } + val modelIndex = (0 until piData.length).zip(model.labels.map(_.toInt)) + for (i <- modelIndex) { + assert(closeFit(math.exp(piData(i._2)), math.exp(model.pi(i._1)), 0.05)) + } + for (i <- modelIndex) { + for (j <- 0 until thetaData(i._2).length) { + assert(closeFit(math.exp(thetaData(i._2)(j)), math.exp(model.theta(i._1)(j)), 0.05)) + } + } + } + test("get, set params") { val nb = new NaiveBayes() nb.setLambda(2.0) @@ -93,19 +125,53 @@ class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext { assert(nb.getLambda === 3.0) } - test("Naive Bayes") { - val nPoints = 10000 + test("Naive Bayes Multinomial") { + val nPoints = 1000 + val pi = Array(0.5, 0.1, 0.4).map(math.log) + val theta = Array( + Array(0.70, 0.10, 0.10, 0.10), // label 0 + Array(0.10, 0.70, 0.10, 0.10), // label 1 + Array(0.10, 0.10, 0.70, 0.10) // label 2 + ).map(_.map(math.log)) + + val testData = NaiveBayesSuite.generateNaiveBayesInput( + pi, theta, nPoints, 42, "Multinomial") + val testRDD = sc.parallelize(testData, 2) + testRDD.cache() + + val model = NaiveBayes.train(testRDD, 1.0, "Multinomial") + validateModelFit(pi, theta, model) + + val validationData = NaiveBayesSuite.generateNaiveBayesInput( + pi, theta, nPoints, 17, "Multinomial") + val validationRDD = sc.parallelize(validationData, 2) + + // Test prediction on RDD. + validatePrediction(model.predict(validationRDD.map(_.features)).collect(), validationData) - val pi = NaiveBayesSuite.smallPi - val theta = NaiveBayesSuite.smallTheta + // Test prediction on Array. + validatePrediction(validationData.map(row => model.predict(row.features)), validationData) + } - val testData = NaiveBayesSuite.generateNaiveBayesInput(pi, theta, nPoints, 42) + test("Naive Bayes Bernoulli") { + val nPoints = 10000 + val pi = Array(0.5, 0.3, 0.2).map(math.log) + val theta = Array( + Array(0.50, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.40), // label 0 + Array(0.02, 0.70, 0.10, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02), // label 1 + Array(0.02, 0.02, 0.60, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.30) // label 2 + ).map(_.map(math.log)) + + val testData = NaiveBayesSuite.generateNaiveBayesInput( + pi, theta, nPoints, 45, "Bernoulli") val testRDD = sc.parallelize(testData, 2) testRDD.cache() - val model = NaiveBayes.train(testRDD) + val model = NaiveBayes.train(testRDD, 1.0, "Bernoulli") + validateModelFit(pi, theta, model) - val validationData = NaiveBayesSuite.generateNaiveBayesInput(pi, theta, nPoints, 17) + val validationData = NaiveBayesSuite.generateNaiveBayesInput( + pi, theta, nPoints, 20, "Bernoulli") val validationRDD = sc.parallelize(validationData, 2) // Test prediction on RDD. @@ -142,19 +208,41 @@ class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext { } } - test("model save/load") { - val model = NaiveBayesSuite.binaryModel + test("model save/load: 2.0 to 2.0") { + val tempDir = Utils.createTempDir() + val path = tempDir.toURI.toString + + Seq(NaiveBayesSuite.binaryBernoulliModel, NaiveBayesSuite.binaryMultinomialModel).map { + model => + // Save model, load it back, and compare. + try { + model.save(sc, path) + val sameModel = NaiveBayesModel.load(sc, path) + assert(model.labels === sameModel.labels) + assert(model.pi === sameModel.pi) + assert(model.theta === sameModel.theta) + assert(model.modelType === sameModel.modelType) + } finally { + Utils.deleteRecursively(tempDir) + } + } + } + + test("model save/load: 1.0 to 2.0") { + val model = NaiveBayesSuite.binaryMultinomialModel val tempDir = Utils.createTempDir() val path = tempDir.toURI.toString - // Save model, load it back, and compare. + // Save model as version 1.0, load it back, and compare. try { - model.save(sc, path) + val data = NaiveBayesModel.SaveLoadV1_0.Data(model.labels, model.pi, model.theta) + NaiveBayesModel.SaveLoadV1_0.save(sc, path, data) val sameModel = NaiveBayesModel.load(sc, path) assert(model.labels === sameModel.labels) assert(model.pi === sameModel.pi) assert(model.theta === sameModel.theta) + assert(model.modelType === "Multinomial") } finally { Utils.deleteRecursively(tempDir) } @@ -172,8 +260,8 @@ class NaiveBayesClusterSuite extends FunSuite with LocalClusterSparkContext { LabeledPoint(random.nextInt(2), Vectors.dense(Array.fill(n)(random.nextDouble()))) } } - // If we serialize data directly in the task closure, the size of the serialized task would be - // greater than 1MB and hence Spark would throw an error. + // If we serialize data directly in the task closure, the size of the serialized task + // would be greater than 1MB and hence Spark would throw an error. val model = NaiveBayes.train(examples) val predictions = model.predict(examples.map(_.features)) } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionSuite.scala index 8b3e6e5ce9249..d50c43d439187 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionSuite.scala @@ -132,4 +132,31 @@ class StreamingLogisticRegressionSuite extends FunSuite with TestSuiteBase { assert(errors.forall(x => x <= 0.4)) } + // Test training combined with prediction + test("training and prediction") { + // create model initialized with zero weights + val model = new StreamingLogisticRegressionWithSGD() + .setInitialWeights(Vectors.dense(-0.1)) + .setStepSize(0.01) + .setNumIterations(10) + + // generate sequence of simulated data for testing + val numBatches = 10 + val nPoints = 100 + val testInput = (0 until numBatches).map { i => + LogisticRegressionSuite.generateLogisticInput(0.0, 5.0, nPoints, 42 * (i + 1)) + } + + // train and predict + val ssc = setupStreams(testInput, (inputDStream: DStream[LabeledPoint]) => { + model.trainOn(inputDStream) + model.predictOnValues(inputDStream.map(x => (x.label, x.features))) + }) + + val output: Seq[Seq[(Double, Double)]] = runStreams(ssc, numBatches, numBatches) + + // assert that prediction error improves, ensuring that the updated model is being used + val error = output.map(batch => batch.map(p => math.abs(p._1 - p._2)).sum / nPoints).toList + assert(error.head > 0.8 & error.last < 0.2) + } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala index 52278690dbd89..98a98a7599bcb 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala @@ -21,6 +21,9 @@ import org.scalatest.FunSuite import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.mllib.util.TestingUtils._ +import org.apache.spark.util.Utils + class Word2VecSuite extends FunSuite with MLlibTestSparkContext { // TODO: add more tests @@ -51,4 +54,27 @@ class Word2VecSuite extends FunSuite with MLlibTestSparkContext { assert(syms(0)._1 == "taiwan") assert(syms(1)._1 == "japan") } + + test("model load / save") { + + val word2VecMap = Map( + ("china", Array(0.50f, 0.50f, 0.50f, 0.50f)), + ("japan", Array(0.40f, 0.50f, 0.50f, 0.50f)), + ("taiwan", Array(0.60f, 0.50f, 0.50f, 0.50f)), + ("korea", Array(0.45f, 0.60f, 0.60f, 0.60f)) + ) + val model = new Word2VecModel(word2VecMap) + + val tempDir = Utils.createTempDir() + val path = tempDir.toURI.toString + + try { + model.save(sc, path) + val sameModel = Word2VecModel.load(sc, path) + assert(sameModel.getVectors.mapValues(_.toSeq) === model.getVectors.mapValues(_.toSeq)) + } finally { + Utils.deleteRecursively(tempDir) + } + + } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala index 70b43ddb7daf5..24fd8df691817 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala @@ -139,4 +139,32 @@ class StreamingLinearRegressionSuite extends FunSuite with TestSuiteBase { val errors = output.map(batch => batch.map(p => math.abs(p._1 - p._2)).sum / nPoints) assert(errors.forall(x => x <= 0.1)) } + + // Test training combined with prediction + test("training and prediction") { + // create model initialized with zero weights + val model = new StreamingLinearRegressionWithSGD() + .setInitialWeights(Vectors.dense(0.0, 0.0)) + .setStepSize(0.2) + .setNumIterations(25) + + // generate sequence of simulated data for testing + val numBatches = 10 + val nPoints = 100 + val testInput = (0 until numBatches).map { i => + LinearDataGenerator.generateLinearInput(0.0, Array(10.0, 10.0), nPoints, 42 * (i + 1)) + } + + // train and predict + val ssc = setupStreams(testInput, (inputDStream: DStream[LabeledPoint]) => { + model.trainOn(inputDStream) + model.predictOnValues(inputDStream.map(x => (x.label, x.features))) + }) + + val output: Seq[Seq[(Double, Double)]] = runStreams(ssc, numBatches, numBatches) + + // assert that prediction error improves, ensuring that the updated model is being used + val error = output.map(batch => batch.map(p => math.abs(p._1 - p._2)).sum / nPoints).toList + assert((error.head - error.last) > 2) + } } diff --git a/network/common/pom.xml b/network/common/pom.xml index 7b51845206f4a..22c738bde6d42 100644 --- a/network/common/pom.xml +++ b/network/common/pom.xml @@ -80,6 +80,11 @@ mockito-all test + + org.slf4j + slf4j-log4j12 + test + diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java b/network/common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java index 91d1e8a538a77..0f999f5dfe8d8 100644 --- a/network/common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java +++ b/network/common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java @@ -72,9 +72,11 @@ public void encode(ChannelHandlerContext ctx, Message in, List out) { in.encode(header); assert header.writableBytes() == 0; - out.add(header); if (body != null && bodyLength > 0) { - out.add(body); + out.add(new MessageWithHeader(header, body, bodyLength)); + } else { + out.add(header); } } + } diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/MessageWithHeader.java b/network/common/src/main/java/org/apache/spark/network/protocol/MessageWithHeader.java new file mode 100644 index 0000000000000..d686a951467cf --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/protocol/MessageWithHeader.java @@ -0,0 +1,109 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.protocol; + +import java.io.IOException; +import java.nio.channels.WritableByteChannel; + +import com.google.common.base.Preconditions; +import io.netty.buffer.ByteBuf; +import io.netty.channel.FileRegion; +import io.netty.util.AbstractReferenceCounted; +import io.netty.util.ReferenceCountUtil; + +/** + * A wrapper message that holds two separate pieces (a header and a body). + * + * The header must be a ByteBuf, while the body can be a ByteBuf or a FileRegion. + */ +class MessageWithHeader extends AbstractReferenceCounted implements FileRegion { + + private final ByteBuf header; + private final int headerLength; + private final Object body; + private final long bodyLength; + private long totalBytesTransferred; + + MessageWithHeader(ByteBuf header, Object body, long bodyLength) { + Preconditions.checkArgument(body instanceof ByteBuf || body instanceof FileRegion, + "Body must be a ByteBuf or a FileRegion."); + this.header = header; + this.headerLength = header.readableBytes(); + this.body = body; + this.bodyLength = bodyLength; + } + + @Override + public long count() { + return headerLength + bodyLength; + } + + @Override + public long position() { + return 0; + } + + @Override + public long transfered() { + return totalBytesTransferred; + } + + /** + * This code is more complicated than you would think because we might require multiple + * transferTo invocations in order to transfer a single MessageWithHeader to avoid busy waiting. + * + * The contract is that the caller will ensure position is properly set to the total number + * of bytes transferred so far (i.e. value returned by transfered()). + */ + @Override + public long transferTo(final WritableByteChannel target, final long position) throws IOException { + Preconditions.checkArgument(position == totalBytesTransferred, "Invalid position."); + // Bytes written for header in this call. + long writtenHeader = 0; + if (header.readableBytes() > 0) { + writtenHeader = copyByteBuf(header, target); + totalBytesTransferred += writtenHeader; + if (header.readableBytes() > 0) { + return writtenHeader; + } + } + + // Bytes written for body in this call. + long writtenBody = 0; + if (body instanceof FileRegion) { + writtenBody = ((FileRegion) body).transferTo(target, totalBytesTransferred - headerLength); + } else if (body instanceof ByteBuf) { + writtenBody = copyByteBuf((ByteBuf) body, target); + } + totalBytesTransferred += writtenBody; + + return writtenHeader + writtenBody; + } + + @Override + protected void deallocate() { + header.release(); + ReferenceCountUtil.release(body); + } + + private int copyByteBuf(ByteBuf buf, WritableByteChannel target) throws IOException { + int written = target.write(buf.nioBuffer()); + buf.skipBytes(written); + return written; + } +} diff --git a/project/spark-style/src/main/scala/org/apache/spark/scalastyle/NonASCIICharacterChecker.scala b/network/common/src/test/java/org/apache/spark/network/ByteArrayWritableChannel.java similarity index 55% rename from project/spark-style/src/main/scala/org/apache/spark/scalastyle/NonASCIICharacterChecker.scala rename to network/common/src/test/java/org/apache/spark/network/ByteArrayWritableChannel.java index 3d43c35299555..b525ed69fc9fb 100644 --- a/project/spark-style/src/main/scala/org/apache/spark/scalastyle/NonASCIICharacterChecker.scala +++ b/network/common/src/test/java/org/apache/spark/network/ByteArrayWritableChannel.java @@ -15,25 +15,41 @@ * limitations under the License. */ +package org.apache.spark.network; -package org.apache.spark.scalastyle +import java.nio.ByteBuffer; +import java.nio.channels.WritableByteChannel; -import java.util.regex.Pattern +public class ByteArrayWritableChannel implements WritableByteChannel { -import org.scalastyle.{PositionError, ScalariformChecker, ScalastyleError} + private final byte[] data; + private int offset; -import scalariform.lexer.Token -import scalariform.parser.CompilationUnit + public ByteArrayWritableChannel(int size) { + this.data = new byte[size]; + this.offset = 0; + } + + public byte[] getData() { + return data; + } -class NonASCIICharacterChecker extends ScalariformChecker { - val errorKey: String = "non.ascii.character.disallowed" + @Override + public int write(ByteBuffer src) { + int available = src.remaining(); + src.get(data, offset, available); + offset += available; + return available; + } + + @Override + public void close() { - override def verify(ast: CompilationUnit): List[ScalastyleError] = { - ast.tokens.filter(hasNonAsciiChars).map(x => PositionError(x.offset)).toList } - private def hasNonAsciiChars(x: Token) = - x.rawText.trim.nonEmpty && !Pattern.compile( """\p{ASCII}+""", Pattern.DOTALL) - .matcher(x.text.trim).matches() + @Override + public boolean isOpen() { + return true; + } } diff --git a/network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java b/network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java index 43dc0cf8c7194..860dd6d9b3915 100644 --- a/network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java +++ b/network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java @@ -17,26 +17,34 @@ package org.apache.spark.network; +import java.util.List; + +import com.google.common.primitives.Ints; +import io.netty.buffer.Unpooled; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.FileRegion; import io.netty.channel.embedded.EmbeddedChannel; +import io.netty.handler.codec.MessageToMessageEncoder; import org.junit.Test; import static org.junit.Assert.assertEquals; -import org.apache.spark.network.protocol.Message; -import org.apache.spark.network.protocol.StreamChunkId; -import org.apache.spark.network.protocol.ChunkFetchRequest; import org.apache.spark.network.protocol.ChunkFetchFailure; +import org.apache.spark.network.protocol.ChunkFetchRequest; import org.apache.spark.network.protocol.ChunkFetchSuccess; -import org.apache.spark.network.protocol.RpcRequest; -import org.apache.spark.network.protocol.RpcFailure; -import org.apache.spark.network.protocol.RpcResponse; +import org.apache.spark.network.protocol.Message; import org.apache.spark.network.protocol.MessageDecoder; import org.apache.spark.network.protocol.MessageEncoder; +import org.apache.spark.network.protocol.RpcFailure; +import org.apache.spark.network.protocol.RpcRequest; +import org.apache.spark.network.protocol.RpcResponse; +import org.apache.spark.network.protocol.StreamChunkId; import org.apache.spark.network.util.NettyUtils; public class ProtocolSuite { private void testServerToClient(Message msg) { - EmbeddedChannel serverChannel = new EmbeddedChannel(new MessageEncoder()); + EmbeddedChannel serverChannel = new EmbeddedChannel(new FileRegionEncoder(), + new MessageEncoder()); serverChannel.writeOutbound(msg); EmbeddedChannel clientChannel = new EmbeddedChannel( @@ -51,7 +59,8 @@ private void testServerToClient(Message msg) { } private void testClientToServer(Message msg) { - EmbeddedChannel clientChannel = new EmbeddedChannel(new MessageEncoder()); + EmbeddedChannel clientChannel = new EmbeddedChannel(new FileRegionEncoder(), + new MessageEncoder()); clientChannel.writeOutbound(msg); EmbeddedChannel serverChannel = new EmbeddedChannel( @@ -83,4 +92,25 @@ public void responses() { testServerToClient(new RpcFailure(0, "this is an error")); testServerToClient(new RpcFailure(0, "")); } + + /** + * Handler to transform a FileRegion into a byte buffer. EmbeddedChannel doesn't actually transfer + * bytes, but messages, so this is needed so that the frame decoder on the receiving side can + * understand what MessageWithHeader actually contains. + */ + private static class FileRegionEncoder extends MessageToMessageEncoder { + + @Override + public void encode(ChannelHandlerContext ctx, FileRegion in, List out) + throws Exception { + + ByteArrayWritableChannel channel = new ByteArrayWritableChannel(Ints.checkedCast(in.count())); + while (in.transfered() < in.count()) { + in.transferTo(channel, in.transfered()); + } + out.add(Unpooled.wrappedBuffer(channel.getData())); + } + + } + } diff --git a/network/common/src/test/java/org/apache/spark/network/protocol/MessageWithHeaderSuite.java b/network/common/src/test/java/org/apache/spark/network/protocol/MessageWithHeaderSuite.java new file mode 100644 index 0000000000000..ff985096d72d5 --- /dev/null +++ b/network/common/src/test/java/org/apache/spark/network/protocol/MessageWithHeaderSuite.java @@ -0,0 +1,129 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.protocol; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.channels.WritableByteChannel; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.channel.FileRegion; +import io.netty.util.AbstractReferenceCounted; +import org.junit.Test; + +import static org.junit.Assert.*; + +import org.apache.spark.network.ByteArrayWritableChannel; + +public class MessageWithHeaderSuite { + + @Test + public void testSingleWrite() throws Exception { + testFileRegionBody(8, 8); + } + + @Test + public void testShortWrite() throws Exception { + testFileRegionBody(8, 1); + } + + @Test + public void testByteBufBody() throws Exception { + ByteBuf header = Unpooled.copyLong(42); + ByteBuf body = Unpooled.copyLong(84); + MessageWithHeader msg = new MessageWithHeader(header, body, body.readableBytes()); + + ByteBuf result = doWrite(msg, 1); + assertEquals(msg.count(), result.readableBytes()); + assertEquals(42, result.readLong()); + assertEquals(84, result.readLong()); + } + + private void testFileRegionBody(int totalWrites, int writesPerCall) throws Exception { + ByteBuf header = Unpooled.copyLong(42); + int headerLength = header.readableBytes(); + TestFileRegion region = new TestFileRegion(totalWrites, writesPerCall); + MessageWithHeader msg = new MessageWithHeader(header, region, region.count()); + + ByteBuf result = doWrite(msg, totalWrites / writesPerCall); + assertEquals(headerLength + region.count(), result.readableBytes()); + assertEquals(42, result.readLong()); + for (long i = 0; i < 8; i++) { + assertEquals(i, result.readLong()); + } + } + + private ByteBuf doWrite(MessageWithHeader msg, int minExpectedWrites) throws Exception { + int writes = 0; + ByteArrayWritableChannel channel = new ByteArrayWritableChannel((int) msg.count()); + while (msg.transfered() < msg.count()) { + msg.transferTo(channel, msg.transfered()); + writes++; + } + assertTrue("Not enough writes!", minExpectedWrites <= writes); + return Unpooled.wrappedBuffer(channel.getData()); + } + + private static class TestFileRegion extends AbstractReferenceCounted implements FileRegion { + + private final int writeCount; + private final int writesPerCall; + private int written; + + TestFileRegion(int totalWrites, int writesPerCall) { + this.writeCount = totalWrites; + this.writesPerCall = writesPerCall; + } + + @Override + public long count() { + return 8 * writeCount; + } + + @Override + public long position() { + return 0; + } + + @Override + public long transfered() { + return 8 * written; + } + + @Override + public long transferTo(WritableByteChannel target, long position) throws IOException { + for (int i = 0; i < writesPerCall; i++) { + ByteBuf buf = Unpooled.copyLong((position / 8) + i); + ByteBuffer nio = buf.nioBuffer(); + while (nio.remaining() > 0) { + target.write(nio); + } + buf.release(); + written++; + } + return 8 * writesPerCall; + } + + @Override + protected void deallocate() { + } + + } + +} diff --git a/network/common/src/test/resources/log4j.properties b/network/common/src/test/resources/log4j.properties new file mode 100644 index 0000000000000..e8da774f7ca9e --- /dev/null +++ b/network/common/src/test/resources/log4j.properties @@ -0,0 +1,27 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Set everything to be logged to the file target/unit-tests.log +log4j.rootCategory=DEBUG, file +log4j.appender.file=org.apache.log4j.FileAppender +log4j.appender.file.append=true +log4j.appender.file.file=target/unit-tests.log +log4j.appender.file.layout=org.apache.log4j.PatternLayout +log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n + +# Silence verbose logs from 3rd-party libraries. +log4j.logger.io.netty=INFO diff --git a/pom.xml b/pom.xml index 23bb16130b504..42bd926a2fcb8 100644 --- a/pom.xml +++ b/pom.xml @@ -141,7 +141,7 @@ 2.4.0 2.0.8 3.1.0 - 1.7.6 + 1.7.7 0.7.1 1.8.3 @@ -1265,6 +1265,7 @@ create-source-jar jar-no-fork + test-jar-no-fork @@ -1452,7 +1453,8 @@ ${basedir}/src/test/scala scalastyle-config.xml scalastyle-output.xml - UTF-8 + ${project.build.sourceEncoding} + ${project.reporting.outputEncoding} @@ -1472,6 +1474,25 @@ org.scalatest scalatest-maven-plugin + + + org.apache.maven.plugins + maven-jar-plugin + + + prepare-test-jar + prepare-package + + test-jar + + + + log4j.properties + + + + + diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 56f5dbe53fad4..c2d828f982fe0 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -50,7 +50,20 @@ object MimaExcludes { ProblemFilters.exclude[IncompatibleResultTypeProblem]( "org.apache.spark.broadcast.HttpBroadcastFactory.newBroadcast"), ProblemFilters.exclude[IncompatibleResultTypeProblem]( - "org.apache.spark.broadcast.TorrentBroadcastFactory.newBroadcast") + "org.apache.spark.broadcast.TorrentBroadcastFactory.newBroadcast"), + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.scheduler.OutputCommitCoordinator$OutputCommitCoordinatorActor") + ) ++ Seq( + // SPARK-4655 - Making Stage an Abstract class broke binary compatility even though + // the stage class is defined as private[spark] + ProblemFilters.exclude[AbstractClassProblem]("org.apache.spark.scheduler.Stage") + ) ++ Seq( + // SPARK-6510 Add a Graph#minus method acting as Set#difference + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.graphx.VertexRDD.minus") + ) ++ Seq( + // SPARK-6492 Fix deadlock in SparkContext.stop() + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.org$" + + "apache$spark$SparkContext$$SPARK_CONTEXT_CONSTRUCTOR_LOCK") ) case v if v.startsWith("1.3") => diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index ac37c605de4b6..d3faa551a4b14 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -360,15 +360,15 @@ object Unidoc { packages .map(_.filterNot(_.getName.contains("$"))) .map(_.filterNot(_.getCanonicalPath.contains("akka"))) - .map(_.filterNot(_.getCanonicalPath.contains("deploy"))) - .map(_.filterNot(_.getCanonicalPath.contains("network"))) - .map(_.filterNot(_.getCanonicalPath.contains("shuffle"))) - .map(_.filterNot(_.getCanonicalPath.contains("executor"))) + .map(_.filterNot(_.getCanonicalPath.contains("org/apache/spark/deploy"))) + .map(_.filterNot(_.getCanonicalPath.contains("org/apache/spark/network"))) + .map(_.filterNot(_.getCanonicalPath.contains("org/apache/spark/shuffle"))) + .map(_.filterNot(_.getCanonicalPath.contains("org/apache/spark/executor"))) .map(_.filterNot(_.getCanonicalPath.contains("python"))) - .map(_.filterNot(_.getCanonicalPath.contains("collection"))) - .map(_.filterNot(_.getCanonicalPath.contains("sql/catalyst"))) - .map(_.filterNot(_.getCanonicalPath.contains("sql/execution"))) - .map(_.filterNot(_.getCanonicalPath.contains("sql/hive/test"))) + .map(_.filterNot(_.getCanonicalPath.contains("org/apache/spark/util/collection"))) + .map(_.filterNot(_.getCanonicalPath.contains("org/apache/spark/sql/catalyst"))) + .map(_.filterNot(_.getCanonicalPath.contains("org/apache/spark/sql/execution"))) + .map(_.filterNot(_.getCanonicalPath.contains("org/apache/spark/sql/hive/test"))) } lazy val settings = scalaJavaUnidocSettings ++ Seq ( diff --git a/project/plugins.sbt b/project/plugins.sbt index ee45b6a51905e..7096b0d3ee7de 100644 --- a/project/plugins.sbt +++ b/project/plugins.sbt @@ -19,7 +19,7 @@ addSbtPlugin("com.github.mpeltonen" % "sbt-idea" % "1.6.0") addSbtPlugin("net.virtual-void" % "sbt-dependency-graph" % "0.7.4") -addSbtPlugin("org.scalastyle" %% "scalastyle-sbt-plugin" % "0.6.0") +addSbtPlugin("org.scalastyle" %% "scalastyle-sbt-plugin" % "0.7.0") addSbtPlugin("com.typesafe" % "sbt-mima-plugin" % "0.1.6") diff --git a/project/project/SparkPluginBuild.scala b/project/project/SparkPluginBuild.scala index 8863f272da415..471d00bd8223f 100644 --- a/project/project/SparkPluginBuild.scala +++ b/project/project/SparkPluginBuild.scala @@ -24,20 +24,6 @@ import sbt.Keys._ * becomes available for scalastyle sbt plugin. */ object SparkPluginDef extends Build { - lazy val root = Project("plugins", file(".")) dependsOn(sparkStyle, sbtPomReader) - lazy val sparkStyle = Project("spark-style", file("spark-style"), settings = styleSettings) + lazy val root = Project("plugins", file(".")) dependsOn(sbtPomReader) lazy val sbtPomReader = uri("https://github.com/ScrapCodes/sbt-pom-reader.git#ignore_artifact_id") - - // There is actually no need to publish this artifact. - def styleSettings = Defaults.defaultSettings ++ Seq ( - name := "spark-style", - organization := "org.apache.spark", - scalaVersion := "2.10.4", - scalacOptions := Seq("-unchecked", "-deprecation"), - libraryDependencies ++= Dependencies.scalaStyle - ) - - object Dependencies { - val scalaStyle = Seq("org.scalastyle" %% "scalastyle" % "0.4.0") - } } diff --git a/python/docs/index.rst b/python/docs/index.rst index d150de9d5c502..f7eede9c3c82a 100644 --- a/python/docs/index.rst +++ b/python/docs/index.rst @@ -29,6 +29,14 @@ Core classes: A Resilient Distributed Dataset (RDD), the basic abstraction in Spark. + :class:`pyspark.sql.SQLContext` + + Main entry point for DataFrame and SQL functionality. + + :class:`pyspark.sql.DataFrame` + + A distributed collection of data grouped into named columns. + Indices and tables ================== diff --git a/python/docs/pyspark.streaming.rst b/python/docs/pyspark.streaming.rst index 7890d9dcaac21..50822c93faba1 100644 --- a/python/docs/pyspark.streaming.rst +++ b/python/docs/pyspark.streaming.rst @@ -10,7 +10,7 @@ Module contents :show-inheritance: pyspark.streaming.kafka module ----------------------------- +------------------------------ .. automodule:: pyspark.streaming.kafka :members: :undoc-members: diff --git a/python/pyspark/mllib/classification.py b/python/pyspark/mllib/classification.py index 6766f3ebb8894..2466e8ac43458 100644 --- a/python/pyspark/mllib/classification.py +++ b/python/pyspark/mllib/classification.py @@ -22,7 +22,7 @@ from pyspark import RDD from pyspark.mllib.common import callMLlibFunc, _py2java, _java2py -from pyspark.mllib.linalg import SparseVector, _convert_to_vector +from pyspark.mllib.linalg import DenseVector, SparseVector, _convert_to_vector from pyspark.mllib.regression import LabeledPoint, LinearModel, _regression_train_wrapper from pyspark.mllib.util import Saveable, Loader, inherit_doc @@ -31,13 +31,13 @@ 'SVMModel', 'SVMWithSGD', 'NaiveBayesModel', 'NaiveBayes'] -class LinearBinaryClassificationModel(LinearModel): +class LinearClassificationModel(LinearModel): """ - Represents a linear binary classification model that predicts to whether an - example is positive (1.0) or negative (0.0). + A private abstract class representing a multiclass classification model. + The categories are represented by int values: 0, 1, 2, etc. """ def __init__(self, weights, intercept): - super(LinearBinaryClassificationModel, self).__init__(weights, intercept) + super(LinearClassificationModel, self).__init__(weights, intercept) self._threshold = None def setThreshold(self, value): @@ -47,14 +47,26 @@ def setThreshold(self, value): Sets the threshold that separates positive predictions from negative predictions. An example with prediction score greater than or equal to this threshold is identified as an positive, and negative otherwise. + It is used for binary classification only. """ self._threshold = value + @property + def threshold(self): + """ + .. note:: Experimental + + Returns the threshold (if any) used for converting raw prediction scores + into 0/1 predictions. It is used for binary classification only. + """ + return self._threshold + def clearThreshold(self): """ .. note:: Experimental Clears the threshold so that `predict` will output raw prediction scores. + It is used for binary classification only. """ self._threshold = None @@ -66,7 +78,7 @@ def predict(self, test): raise NotImplementedError -class LogisticRegressionModel(LinearBinaryClassificationModel): +class LogisticRegressionModel(LinearClassificationModel): """A linear binary classification model derived from logistic regression. @@ -112,10 +124,39 @@ class LogisticRegressionModel(LinearBinaryClassificationModel): ... os.removedirs(path) ... except: ... pass + >>> multi_class_data = [ + ... LabeledPoint(0.0, [0.0, 1.0, 0.0]), + ... LabeledPoint(1.0, [1.0, 0.0, 0.0]), + ... LabeledPoint(2.0, [0.0, 0.0, 1.0]) + ... ] + >>> mcm = LogisticRegressionWithLBFGS.train(data=sc.parallelize(multi_class_data), numClasses=3) + >>> mcm.predict([0.0, 0.5, 0.0]) + 0 + >>> mcm.predict([0.8, 0.0, 0.0]) + 1 + >>> mcm.predict([0.0, 0.0, 0.3]) + 2 """ - def __init__(self, weights, intercept): + def __init__(self, weights, intercept, numFeatures, numClasses): super(LogisticRegressionModel, self).__init__(weights, intercept) + self._numFeatures = int(numFeatures) + self._numClasses = int(numClasses) self._threshold = 0.5 + if self._numClasses == 2: + self._dataWithBiasSize = None + self._weightsMatrix = None + else: + self._dataWithBiasSize = self._coeff.size / (self._numClasses - 1) + self._weightsMatrix = self._coeff.toArray().reshape(self._numClasses - 1, + self._dataWithBiasSize) + + @property + def numFeatures(self): + return self._numFeatures + + @property + def numClasses(self): + return self._numClasses def predict(self, x): """ @@ -126,20 +167,38 @@ def predict(self, x): return x.map(lambda v: self.predict(v)) x = _convert_to_vector(x) - margin = self.weights.dot(x) + self._intercept - if margin > 0: - prob = 1 / (1 + exp(-margin)) + if self.numClasses == 2: + margin = self.weights.dot(x) + self._intercept + if margin > 0: + prob = 1 / (1 + exp(-margin)) + else: + exp_margin = exp(margin) + prob = exp_margin / (1 + exp_margin) + if self._threshold is None: + return prob + else: + return 1 if prob > self._threshold else 0 else: - exp_margin = exp(margin) - prob = exp_margin / (1 + exp_margin) - if self._threshold is None: - return prob - else: - return 1 if prob > self._threshold else 0 + best_class = 0 + max_margin = 0.0 + if x.size + 1 == self._dataWithBiasSize: + for i in range(0, self._numClasses - 1): + margin = x.dot(self._weightsMatrix[i][0:x.size]) + \ + self._weightsMatrix[i][x.size] + if margin > max_margin: + max_margin = margin + best_class = i + 1 + else: + for i in range(0, self._numClasses - 1): + margin = x.dot(self._weightsMatrix[i]) + if margin > max_margin: + max_margin = margin + best_class = i + 1 + return best_class def save(self, sc, path): java_model = sc._jvm.org.apache.spark.mllib.classification.LogisticRegressionModel( - _py2java(sc, self._coeff), self.intercept) + _py2java(sc, self._coeff), self.intercept, self.numFeatures, self.numClasses) java_model.save(sc._jsc.sc(), path) @classmethod @@ -148,8 +207,10 @@ def load(cls, sc, path): sc._jsc.sc(), path) weights = _java2py(sc, java_model.weights()) intercept = java_model.intercept() + numFeatures = java_model.numFeatures() + numClasses = java_model.numClasses() threshold = java_model.getThreshold().get() - model = LogisticRegressionModel(weights, intercept) + model = LogisticRegressionModel(weights, intercept, numFeatures, numClasses) model.setThreshold(threshold) return model @@ -158,7 +219,8 @@ class LogisticRegressionWithSGD(object): @classmethod def train(cls, data, iterations=100, step=1.0, miniBatchFraction=1.0, - initialWeights=None, regParam=0.01, regType="l2", intercept=False): + initialWeights=None, regParam=0.01, regType="l2", intercept=False, + validateData=True): """ Train a logistic regression model on the given data. @@ -184,11 +246,14 @@ def train(cls, data, iterations=100, step=1.0, miniBatchFraction=1.0, or not of the augmented representation for training data (i.e. whether bias features are activated or not). + :param validateData: Boolean parameter which indicates if the + algorithm should validate data before training. + (default: True) """ def train(rdd, i): return callMLlibFunc("trainLogisticRegressionModelWithSGD", rdd, int(iterations), float(step), float(miniBatchFraction), i, float(regParam), regType, - bool(intercept)) + bool(intercept), bool(validateData)) return _regression_train_wrapper(train, LogisticRegressionModel, data, initialWeights) @@ -197,7 +262,7 @@ class LogisticRegressionWithLBFGS(object): @classmethod def train(cls, data, iterations=100, initialWeights=None, regParam=0.01, regType="l2", - intercept=False, corrections=10, tolerance=1e-4): + intercept=False, corrections=10, tolerance=1e-4, validateData=True, numClasses=2): """ Train a logistic regression model on the given data. @@ -223,6 +288,11 @@ def train(cls, data, iterations=100, initialWeights=None, regParam=0.01, regType update (default: 10). :param tolerance: The convergence tolerance of iterations for L-BFGS (default: 1e-4). + :param validateData: Boolean parameter which indicates if the + algorithm should validate data before training. + (default: True) + :param numClasses: The number of classes (i.e., outcomes) a label can take + in Multinomial Logistic Regression (default: 2). >>> data = [ ... LabeledPoint(0.0, [0.0, 1.0]), @@ -237,12 +307,20 @@ def train(cls, data, iterations=100, initialWeights=None, regParam=0.01, regType def train(rdd, i): return callMLlibFunc("trainLogisticRegressionModelWithLBFGS", rdd, int(iterations), i, float(regParam), regType, bool(intercept), int(corrections), - float(tolerance)) - + float(tolerance), bool(validateData), int(numClasses)) + + if initialWeights is None: + if numClasses == 2: + initialWeights = [0.0] * len(data.first().features) + else: + if intercept: + initialWeights = [0.0] * (len(data.first().features) + 1) * (numClasses - 1) + else: + initialWeights = [0.0] * len(data.first().features) * (numClasses - 1) return _regression_train_wrapper(train, LogisticRegressionModel, data, initialWeights) -class SVMModel(LinearBinaryClassificationModel): +class SVMModel(LinearClassificationModel): """A support vector machine. @@ -325,7 +403,8 @@ class SVMWithSGD(object): @classmethod def train(cls, data, iterations=100, step=1.0, regParam=0.01, - miniBatchFraction=1.0, initialWeights=None, regType="l2", intercept=False): + miniBatchFraction=1.0, initialWeights=None, regType="l2", + intercept=False, validateData=True): """ Train a support vector machine on the given data. @@ -351,11 +430,14 @@ def train(cls, data, iterations=100, step=1.0, regParam=0.01, or not of the augmented representation for training data (i.e. whether bias features are activated or not). + :param validateData: Boolean parameter which indicates if the + algorithm should validate data before training. + (default: True) """ def train(rdd, i): return callMLlibFunc("trainSVMModelWithSGD", rdd, int(iterations), float(step), float(regParam), float(miniBatchFraction), i, regType, - bool(intercept)) + bool(intercept), bool(validateData)) return _regression_train_wrapper(train, SVMModel, data, initialWeights) diff --git a/python/pyspark/mllib/feature.py b/python/pyspark/mllib/feature.py index 0ffe092a07365..3cda1205e1391 100644 --- a/python/pyspark/mllib/feature.py +++ b/python/pyspark/mllib/feature.py @@ -244,6 +244,12 @@ def transform(self, x): x = _convert_to_vector(x) return JavaVectorTransformer.transform(self, x) + def idf(self): + """ + Returns the current IDF vector. + """ + return self.call('idf') + class IDF(object): """ @@ -331,6 +337,12 @@ def findSynonyms(self, word, num): words, similarity = self.call("findSynonyms", word, num) return zip(words, similarity) + def getVectors(self): + """ + Returns a map of words to their vector representations. + """ + return self.call("getVectors") + class Word2Vec(object): """ @@ -373,6 +385,7 @@ def __init__(self): self.numPartitions = 1 self.numIterations = 1 self.seed = random.randint(0, sys.maxint) + self.minCount = 5 def setVectorSize(self, vectorSize): """ @@ -411,6 +424,14 @@ def setSeed(self, seed): self.seed = seed return self + def setMinCount(self, minCount): + """ + Sets minCount, the minimum number of times a token must appear + to be included in the word2vec model's vocabulary (default: 5). + """ + self.minCount = minCount + return self + def fit(self, data): """ Computes the vector representation of each word in vocabulary. @@ -422,7 +443,8 @@ def fit(self, data): raise TypeError("data should be an RDD of list of string") jmodel = callMLlibFunc("trainWord2Vec", data, int(self.vectorSize), float(self.learningRate), int(self.numPartitions), - int(self.numIterations), long(self.seed)) + int(self.numIterations), long(self.seed), + int(self.minCount)) return Word2VecModel(jmodel) diff --git a/python/pyspark/mllib/linalg.py b/python/pyspark/mllib/linalg.py index f5aad28afda0f..51c1490b1618d 100644 --- a/python/pyspark/mllib/linalg.py +++ b/python/pyspark/mllib/linalg.py @@ -173,7 +173,24 @@ def toArray(self): class DenseVector(Vector): """ - A dense vector represented by a value array. + A dense vector represented by a value array. We use numpy array for + storage and arithmetics will be delegated to the underlying numpy + array. + + >>> v = Vectors.dense([1.0, 2.0]) + >>> u = Vectors.dense([3.0, 4.0]) + >>> v + u + DenseVector([4.0, 6.0]) + >>> 2 - v + DenseVector([1.0, 0.0]) + >>> v / 2 + DenseVector([0.5, 1.0]) + >>> v * u + DenseVector([3.0, 8.0]) + >>> u / v + DenseVector([3.0, 2.0]) + >>> u % 2 + DenseVector([1.0, 0.0]) """ def __init__(self, ar): if isinstance(ar, basestring): @@ -292,6 +309,25 @@ def __ne__(self, other): def __getattr__(self, item): return getattr(self.array, item) + def _delegate(op): + def func(self, other): + if isinstance(other, DenseVector): + other = other.array + return DenseVector(getattr(self.array, op)(other)) + return func + + __neg__ = _delegate("__neg__") + __add__ = _delegate("__add__") + __sub__ = _delegate("__sub__") + __mul__ = _delegate("__mul__") + __div__ = _delegate("__div__") + __mod__ = _delegate("__mod__") + __radd__ = _delegate("__radd__") + __rsub__ = _delegate("__rsub__") + __rmul__ = _delegate("__rmul__") + __rdiv__ = _delegate("__rdiv__") + __rmod__ = _delegate("__rmod__") + class SparseVector(Vector): """ @@ -634,6 +670,16 @@ def toArray(self): """ return self.values.reshape((self.numRows, self.numCols), order='F') + def __getitem__(self, indices): + i, j = indices + if i < 0 or i >= self.numRows: + raise ValueError("Row index %d is out of range [0, %d)" + % (i, self.numRows)) + if j >= self.numCols or j < 0: + raise ValueError("Column index %d is out of range [0, %d)" + % (j, self.numCols)) + return self.values[i + j * self.numRows] + def __eq__(self, other): return (isinstance(other, DenseMatrix) and self.numRows == other.numRows and diff --git a/python/pyspark/mllib/recommendation.py b/python/pyspark/mllib/recommendation.py index 1a4527b12cef2..c5c4c13dae105 100644 --- a/python/pyspark/mllib/recommendation.py +++ b/python/pyspark/mllib/recommendation.py @@ -52,7 +52,7 @@ class MatrixFactorizationModel(JavaModelWrapper, JavaSaveable, JavaLoader): >>> ratings = sc.parallelize([r1, r2, r3]) >>> model = ALS.trainImplicit(ratings, 1, seed=10) >>> model.predict(2, 2) - 0.43... + 0.4... >>> testset = sc.parallelize([(1, 2), (1, 1)]) >>> model = ALS.train(ratings, 2, seed=0) @@ -82,14 +82,16 @@ class MatrixFactorizationModel(JavaModelWrapper, JavaSaveable, JavaLoader): >>> model = ALS.trainImplicit(ratings, 1, nonnegative=True, seed=10) >>> model.predict(2,2) - 0.43... + 0.4... >>> import os, tempfile >>> path = tempfile.mkdtemp() >>> model.save(sc, path) >>> sameModel = MatrixFactorizationModel.load(sc, path) >>> sameModel.predict(2,2) - 0.43... + 0.4... + >>> sameModel.predictAll(testset).collect() + [Rating(... >>> try: ... os.removedirs(path) ... except OSError: @@ -111,6 +113,12 @@ def userFeatures(self): def productFeatures(self): return self.call("getProductFeatures") + @classmethod + def load(cls, sc, path): + model = cls._load_java(sc, path) + wrapper = sc._jvm.MatrixFactorizationModelWrapper(model) + return MatrixFactorizationModel(wrapper) + class ALS(object): diff --git a/python/pyspark/mllib/regression.py b/python/pyspark/mllib/regression.py index 209f1ee473b5b..cd7310a64f4ae 100644 --- a/python/pyspark/mllib/regression.py +++ b/python/pyspark/mllib/regression.py @@ -167,13 +167,19 @@ def load(cls, sc, path): # return the result of a call to the appropriate JVM stub. # _regression_train_wrapper is responsible for setup and error checking. def _regression_train_wrapper(train_func, modelClass, data, initial_weights): + from pyspark.mllib.classification import LogisticRegressionModel first = data.first() if not isinstance(first, LabeledPoint): raise ValueError("data should be an RDD of LabeledPoint, but got %s" % first) if initial_weights is None: initial_weights = [0.0] * len(data.first().features) - weights, intercept = train_func(data, _convert_to_vector(initial_weights)) - return modelClass(weights, intercept) + if (modelClass == LogisticRegressionModel): + weights, intercept, numFeatures, numClasses = train_func( + data, _convert_to_vector(initial_weights)) + return modelClass(weights, intercept, numFeatures, numClasses) + else: + weights, intercept = train_func(data, _convert_to_vector(initial_weights)) + return modelClass(weights, intercept) class LinearRegressionWithSGD(object): diff --git a/python/pyspark/mllib/stat/_statistics.py b/python/pyspark/mllib/stat/_statistics.py index 218ac148ca992..1d83e9d483f8e 100644 --- a/python/pyspark/mllib/stat/_statistics.py +++ b/python/pyspark/mllib/stat/_statistics.py @@ -49,6 +49,12 @@ def max(self): def min(self): return self.call("min").toArray() + def normL1(self): + return self.call("normL1").toArray() + + def normL2(self): + return self.call("normL2").toArray() + class Statistics(object): diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py index 155019638f806..61ef398487c0c 100644 --- a/python/pyspark/mllib/tests.py +++ b/python/pyspark/mllib/tests.py @@ -36,11 +36,14 @@ else: import unittest +from pyspark.mllib.common import _to_java_object_rdd from pyspark.mllib.linalg import Vector, SparseVector, DenseVector, VectorUDT, _convert_to_vector,\ DenseMatrix, Vectors, Matrices from pyspark.mllib.regression import LabeledPoint from pyspark.mllib.random import RandomRDDs from pyspark.mllib.stat import Statistics +from pyspark.mllib.feature import Word2Vec +from pyspark.mllib.feature import IDF from pyspark.serializers import PickleSerializer from pyspark.sql import SQLContext from pyspark.tests import ReusedPySparkTestCase as PySparkTestCase @@ -134,6 +137,13 @@ def test_sparse_vector_indexing(self): for ind in [4, -5, 7.8]: self.assertRaises(ValueError, sv.__getitem__, ind) + def test_matrix_indexing(self): + mat = DenseMatrix(3, 2, [0, 1, 4, 6, 8, 10]) + expected = [[0, 6], [1, 8], [4, 10]] + for i in range(3): + for j in range(2): + self.assertEquals(mat[i, j], expected[i][j]) + class ListTests(PySparkTestCase): @@ -347,6 +357,19 @@ def test_col_with_different_rdds(self): summary = Statistics.colStats(data) self.assertEqual(10, summary.count()) + def test_col_norms(self): + data = RandomRDDs.normalVectorRDD(self.sc, 1000, 10, 10) + summary = Statistics.colStats(data) + self.assertEqual(10, len(summary.normL1())) + self.assertEqual(10, len(summary.normL2())) + + data2 = self.sc.parallelize(xrange(10)).map(lambda x: Vectors.dense(x)) + summary2 = Statistics.colStats(data2) + self.assertEqual(array([45.0]), summary2.normL1()) + import math + expectedNormL2 = math.sqrt(sum(map(lambda x: x*x, xrange(10)))) + self.assertTrue(math.fabs(summary2.normL2()[0] - expectedNormL2) < 1e-14) + class VectorUDTTests(PySparkTestCase): @@ -620,6 +643,60 @@ def test_right_number_of_results(self): self.assertEqual(len(chi), num_cols) self.assertIsNotNone(chi[1000]) + +class SerDeTest(PySparkTestCase): + def test_to_java_object_rdd(self): # SPARK-6660 + data = RandomRDDs.uniformRDD(self.sc, 10, 5, seed=0L) + self.assertEqual(_to_java_object_rdd(data).count(), 10) + + +class FeatureTest(PySparkTestCase): + def test_idf_model(self): + data = [ + Vectors.dense([1, 2, 6, 0, 2, 3, 1, 1, 0, 0, 3]), + Vectors.dense([1, 3, 0, 1, 3, 0, 0, 2, 0, 0, 1]), + Vectors.dense([1, 4, 1, 0, 0, 4, 9, 0, 1, 2, 0]), + Vectors.dense([2, 1, 0, 3, 0, 0, 5, 0, 2, 3, 9]) + ] + model = IDF().fit(self.sc.parallelize(data, 2)) + idf = model.idf() + self.assertEqual(len(idf), 11) + + +class Word2VecTests(PySparkTestCase): + def test_word2vec_setters(self): + data = [ + ["I", "have", "a", "pen"], + ["I", "like", "soccer", "very", "much"], + ["I", "live", "in", "Tokyo"] + ] + model = Word2Vec() \ + .setVectorSize(2) \ + .setLearningRate(0.01) \ + .setNumPartitions(2) \ + .setNumIterations(10) \ + .setSeed(1024) \ + .setMinCount(3) + self.assertEquals(model.vectorSize, 2) + self.assertTrue(model.learningRate < 0.02) + self.assertEquals(model.numPartitions, 2) + self.assertEquals(model.numIterations, 10) + self.assertEquals(model.seed, 1024) + self.assertEquals(model.minCount, 3) + + def test_word2vec_get_vectors(self): + data = [ + ["a", "b", "c", "d", "e", "f", "g"], + ["a", "b", "c", "d", "e", "f"], + ["a", "b", "c", "d", "e"], + ["a", "b", "c", "d"], + ["a", "b", "c"], + ["a", "b"], + ["a"] + ] + model = Word2Vec().fit(self.sc.parallelize(data)) + self.assertEquals(len(model.getVectors()), 3) + if __name__ == "__main__": if not _have_scipy: print "NOTE: Skipping SciPy tests as it does not seem to be installed" diff --git a/python/pyspark/mllib/tree.py b/python/pyspark/mllib/tree.py index bf288d76447bd..a7a4d2aaf855b 100644 --- a/python/pyspark/mllib/tree.py +++ b/python/pyspark/mllib/tree.py @@ -286,21 +286,18 @@ def trainClassifier(cls, data, numClasses, categoricalFeaturesInfo, numTrees, :param numTrees: Number of trees in the random forest. :param featureSubsetStrategy: Number of features to consider for splits at each node. - Supported: "auto" (default), "all", "sqrt", "log2", - "onethird". - If "auto" is set, this parameter is set based on - numTrees: - if numTrees == 1, set to "all"; - if numTrees > 1 (forest) set to "sqrt". - :param impurity: Criterion used for information gain - calculation. + Supported: "auto" (default), "all", "sqrt", "log2", "onethird". + If "auto" is set, this parameter is set based on numTrees: + if numTrees == 1, set to "all"; + if numTrees > 1 (forest) set to "sqrt". + :param impurity: Criterion used for information gain calculation. Supported values: "gini" (recommended) or "entropy". :param maxDepth: Maximum depth of the tree. E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes. (default: 4) :param maxBins: maximum number of bins used for splitting features - (default: 100) + (default: 100) :param seed: Random seed for bootstrapping and choosing feature subsets. :return: RandomForestModel that can be used for prediction @@ -365,13 +362,10 @@ def trainRegressor(cls, data, categoricalFeaturesInfo, numTrees, featureSubsetSt :param numTrees: Number of trees in the random forest. :param featureSubsetStrategy: Number of features to consider for splits at each node. - Supported: "auto" (default), "all", "sqrt", "log2", - "onethird". - If "auto" is set, this parameter is set based on - numTrees: - if numTrees == 1, set to "all"; - if numTrees > 1 (forest) set to "onethird" for - regression. + Supported: "auto" (default), "all", "sqrt", "log2", "onethird". + If "auto" is set, this parameter is set based on numTrees: + if numTrees == 1, set to "all"; + if numTrees > 1 (forest) set to "onethird" for regression. :param impurity: Criterion used for information gain calculation. Supported values: "variance". diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index c337a43c8a7fc..2d05611321ed6 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -113,6 +113,7 @@ def _parse_memory(s): def _load_from_socket(port, serializer): sock = socket.socket() + sock.settimeout(3) try: sock.connect(("localhost", port)) rf = sock.makefile("rb", 65536) diff --git a/python/pyspark/sql/__init__.py b/python/pyspark/sql/__init__.py index b9ffd6945ea7e..65abb24eed823 100644 --- a/python/pyspark/sql/__init__.py +++ b/python/pyspark/sql/__init__.py @@ -16,26 +16,32 @@ # """ -public classes of Spark SQL: +Important classes of Spark SQL and DataFrames: - L{SQLContext} - Main entry point for SQL functionality. + Main entry point for :class:`DataFrame` and SQL functionality. - L{DataFrame} - A Resilient Distributed Dataset (RDD) with Schema information for the data contained. In - addition to normal RDD operations, DataFrames also support SQL. - - L{GroupedData} + A distributed collection of data grouped into named columns. - L{Column} - Column is a DataFrame with a single column. + A column expression in a :class:`DataFrame`. - L{Row} - A Row of data returned by a Spark SQL query. + A row of data in a :class:`DataFrame`. - L{HiveContext} - Main entry point for accessing data stored in Apache Hive.. + Main entry point for accessing data stored in Apache Hive. + - L{GroupedData} + Aggregation methods, returned by :func:`DataFrame.groupBy`. + - L{DataFrameNaFunctions} + Methods for handling missing data (null values). + - L{functions} + List of built-in functions available for :class:`DataFrame`. + - L{types} + List of data types available. """ from pyspark.sql.context import SQLContext, HiveContext from pyspark.sql.types import Row -from pyspark.sql.dataframe import DataFrame, GroupedData, Column, SchemaRDD +from pyspark.sql.dataframe import DataFrame, GroupedData, Column, SchemaRDD, DataFrameNaFunctions __all__ = [ - 'SQLContext', 'HiveContext', 'DataFrame', 'GroupedData', 'Column', 'Row', + 'SQLContext', 'HiveContext', 'DataFrame', 'GroupedData', 'Column', 'Row', 'DataFrameNaFunctions' ] diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py index 795ef0dbc4c47..c2d81ba804110 100644 --- a/python/pyspark/sql/context.py +++ b/python/pyspark/sql/context.py @@ -34,15 +34,15 @@ except ImportError: has_pandas = False -__all__ = ["SQLContext", "HiveContext"] +__all__ = ["SQLContext", "HiveContext", "UDFRegistration"] def _monkey_patch_RDD(sqlCtx): def toDF(self, schema=None, sampleRatio=None): """ - Convert current :class:`RDD` into a :class:`DataFrame` + Converts current :class:`RDD` into a :class:`DataFrame` - This is a shorthand for `sqlCtx.createDataFrame(rdd, schema, sampleRatio)` + This is a shorthand for ``sqlCtx.createDataFrame(rdd, schema, sampleRatio)`` :param schema: a StructType or list of names of columns :param samplingRatio: the sample ratio of rows used for inferring @@ -57,23 +57,22 @@ def toDF(self, schema=None, sampleRatio=None): class SQLContext(object): - """Main entry point for Spark SQL functionality. - A SQLContext can be used create L{DataFrame}, register L{DataFrame} as + A SQLContext can be used create :class:`DataFrame`, register :class:`DataFrame` as tables, execute SQL over tables, cache tables, and read parquet files. - """ - def __init__(self, sparkContext, sqlContext=None): - """Create a new SQLContext. + When created, :class:`SQLContext` adds a method called ``toDF`` to :class:`RDD`, + which could be used to convert an RDD into a DataFrame, it's a shorthand for + :func:`SQLContext.createDataFrame`. - It will add a method called `toDF` to :class:`RDD`, which could be - used to convert an RDD into a DataFrame, it's a shorthand for - :func:`SQLContext.createDataFrame`. - - :param sparkContext: The SparkContext to wrap. - :param sqlContext: An optional JVM Scala SQLContext. If set, we do not instatiate a new + :param sparkContext: The :class:`SparkContext` backing this SQLContext. + :param sqlContext: An optional JVM Scala SQLContext. If set, we do not instantiate a new SQLContext in the JVM, instead we make all calls to this object. + """ + + def __init__(self, sparkContext, sqlContext=None): + """Creates a new SQLContext. >>> from datetime import datetime >>> sqlCtx = SQLContext(sc) @@ -118,6 +117,11 @@ def getConf(self, key, defaultValue): """ return self._ssql_ctx.getConf(key, defaultValue) + @property + def udf(self): + """Returns a :class:`UDFRegistration` for UDF registration.""" + return UDFRegistration(self) + def registerFunction(self, name, f, returnType=StringType()): """Registers a lambda function as a UDF so it can be used in SQL statements. @@ -125,6 +129,10 @@ def registerFunction(self, name, f, returnType=StringType()): When the return type is not given it default to a string and conversion will automatically be done. For any other return type, the produced object must match the specified type. + :param name: name of the UDF + :param samplingRatio: lambda function + :param returnType: a :class:`DataType` object + >>> sqlCtx.registerFunction("stringLengthString", lambda x: len(x)) >>> sqlCtx.sql("SELECT stringLengthString('test')").collect() [Row(c0=u'4')] @@ -133,6 +141,11 @@ def registerFunction(self, name, f, returnType=StringType()): >>> sqlCtx.registerFunction("stringLengthInt", lambda x: len(x), IntegerType()) >>> sqlCtx.sql("SELECT stringLengthInt('test')").collect() [Row(c0=4)] + + >>> from pyspark.sql.types import IntegerType + >>> sqlCtx.udf.register("stringLengthInt", lambda x: len(x), IntegerType()) + >>> sqlCtx.sql("SELECT stringLengthInt('test')").collect() + [Row(c0=4)] """ func = lambda _, it: imap(lambda x: f(*x), it) ser = AutoBatchedSerializer(PickleSerializer()) @@ -173,63 +186,19 @@ def _inferSchema(self, rdd, samplingRatio=None): return schema def inferSchema(self, rdd, samplingRatio=None): - """Infer and apply a schema to an RDD of L{Row}. - - ::note: - Deprecated in 1.3, use :func:`createDataFrame` instead - - When samplingRatio is specified, the schema is inferred by looking - at the types of each row in the sampled dataset. Otherwise, the - first 100 rows of the RDD are inspected. Nested collections are - supported, which can include array, dict, list, Row, tuple, - namedtuple, or object. - - Each row could be L{pyspark.sql.Row} object or namedtuple or objects. - Using top level dicts is deprecated, as dict is used to represent Maps. - - If a single column has multiple distinct inferred types, it may cause - runtime exceptions. - - >>> rdd = sc.parallelize( - ... [Row(field1=1, field2="row1"), - ... Row(field1=2, field2="row2"), - ... Row(field1=3, field2="row3")]) - >>> df = sqlCtx.inferSchema(rdd) - >>> df.collect()[0] - Row(field1=1, field2=u'row1') + """::note: Deprecated in 1.3, use :func:`createDataFrame` instead. """ + warnings.warn("inferSchema is deprecated, please use createDataFrame instead") if isinstance(rdd, DataFrame): raise TypeError("Cannot apply schema to DataFrame") - schema = self._inferSchema(rdd, samplingRatio) - converter = _create_converter(schema) - rdd = rdd.map(converter) - return self.applySchema(rdd, schema) + return self.createDataFrame(rdd, None, samplingRatio) def applySchema(self, rdd, schema): + """::note: Deprecated in 1.3, use :func:`createDataFrame` instead. """ - Applies the given schema to the given RDD of L{tuple} or L{list}. - - ::note: - Deprecated in 1.3, use :func:`createDataFrame` instead - - These tuples or lists can contain complex nested structures like - lists, maps or nested rows. - - The schema should be a StructType. - - It is important that the schema matches the types of the objects - in each row or exceptions could be thrown at runtime. - - >>> from pyspark.sql.types import * - >>> rdd2 = sc.parallelize([(1, "row1"), (2, "row2"), (3, "row3")]) - >>> schema = StructType([StructField("field1", IntegerType(), False), - ... StructField("field2", StringType(), False)]) - >>> df = sqlCtx.applySchema(rdd2, schema) - >>> df.collect() - [Row(field1=1, field2=u'row1'),..., Row(field1=3, field2=u'row3')] - """ + warnings.warn("applySchema is deprecated, please use createDataFrame instead") if isinstance(rdd, DataFrame): raise TypeError("Cannot apply schema to DataFrame") @@ -237,45 +206,27 @@ def applySchema(self, rdd, schema): if not isinstance(schema, StructType): raise TypeError("schema should be StructType, but got %s" % schema) - # take the first few rows to verify schema - rows = rdd.take(10) - # Row() cannot been deserialized by Pyrolite - if rows and isinstance(rows[0], tuple) and rows[0].__class__.__name__ == 'Row': - rdd = rdd.map(tuple) - rows = rdd.take(10) - - for row in rows: - _verify_type(row, schema) - - # convert python objects to sql data - converter = _python_to_sql_converter(schema) - rdd = rdd.map(converter) - - jrdd = self._jvm.SerDeUtil.toJavaArray(rdd._to_java_object_rdd()) - df = self._ssql_ctx.applySchemaToPythonRDD(jrdd.rdd(), schema.json()) - return DataFrame(df, self) + return self.createDataFrame(rdd, schema) def createDataFrame(self, data, schema=None, samplingRatio=None): """ - Create a DataFrame from an RDD of tuple/list, list or pandas.DataFrame. - - `schema` could be :class:`StructType` or a list of column names. + Creates a :class:`DataFrame` from an :class:`RDD` of :class:`tuple`/:class:`list`, + list or :class:`pandas.DataFrame`. - When `schema` is a list of column names, the type of each column - will be inferred from `rdd`. + When ``schema`` is a list of column names, the type of each column + will be inferred from ``data``. - When `schema` is None, it will try to infer the column name and type - from `rdd`, which should be an RDD of :class:`Row`, or namedtuple, - or dict. + When ``schema`` is ``None``, it will try to infer the schema (column names and types) + from ``data``, which should be an RDD of :class:`Row`, + or :class:`namedtuple`, or :class:`dict`. - If referring needed, `samplingRatio` is used to determined how many - rows will be used to do referring. The first row will be used if - `samplingRatio` is None. + If schema inference is needed, ``samplingRatio`` is used to determined the ratio of + rows used for schema inference. The first row will be used if ``samplingRatio`` is ``None``. - :param data: an RDD of Row/tuple/list/dict, list, or pandas.DataFrame - :param schema: a StructType or list of names of columns + :param data: an RDD of :class:`Row`/:class:`tuple`/:class:`list`/:class:`dict`, + :class:`list`, or :class:`pandas.DataFrame`. + :param schema: a :class:`StructType` or list of column names. default None. :param samplingRatio: the sample ratio of rows used for inferring - :return: a DataFrame >>> l = [('Alice', 1)] >>> sqlCtx.createDataFrame(l).collect() @@ -323,39 +274,57 @@ def createDataFrame(self, data, schema=None, samplingRatio=None): if not isinstance(data, RDD): try: # data could be list, tuple, generator ... - data = self._sc.parallelize(data) + rdd = self._sc.parallelize(data) except Exception: raise ValueError("cannot create an RDD from type: %s" % type(data)) + else: + rdd = data if schema is None: - return self.inferSchema(data, samplingRatio) + schema = self._inferSchema(rdd, samplingRatio) + converter = _create_converter(schema) + rdd = rdd.map(converter) if isinstance(schema, (list, tuple)): - first = data.first() + first = rdd.first() if not isinstance(first, (list, tuple)): raise ValueError("each row in `rdd` should be list or tuple, " "but got %r" % type(first)) row_cls = Row(*schema) - schema = self._inferSchema(data.map(lambda r: row_cls(*r)), samplingRatio) + schema = self._inferSchema(rdd.map(lambda r: row_cls(*r)), samplingRatio) - return self.applySchema(data, schema) + # take the first few rows to verify schema + rows = rdd.take(10) + # Row() cannot been deserialized by Pyrolite + if rows and isinstance(rows[0], tuple) and rows[0].__class__.__name__ == 'Row': + rdd = rdd.map(tuple) + rows = rdd.take(10) - def registerDataFrameAsTable(self, rdd, tableName): - """Registers the given RDD as a temporary table in the catalog. + for row in rows: + _verify_type(row, schema) - Temporary tables exist only during the lifetime of this instance of - SQLContext. + # convert python objects to sql data + converter = _python_to_sql_converter(schema) + rdd = rdd.map(converter) + + jrdd = self._jvm.SerDeUtil.toJavaArray(rdd._to_java_object_rdd()) + df = self._ssql_ctx.applySchemaToPythonRDD(jrdd.rdd(), schema.json()) + return DataFrame(df, self) + + def registerDataFrameAsTable(self, df, tableName): + """Registers the given :class:`DataFrame` as a temporary table in the catalog. + + Temporary tables exist only during the lifetime of this instance of :class:`SQLContext`. >>> sqlCtx.registerDataFrameAsTable(df, "table1") """ - if (rdd.__class__ is DataFrame): - df = rdd._jdf - self._ssql_ctx.registerDataFrameAsTable(df, tableName) + if (df.__class__ is DataFrame): + self._ssql_ctx.registerDataFrameAsTable(df._jdf, tableName) else: raise ValueError("Can only register DataFrame as table") def parquetFile(self, *paths): - """Loads a Parquet file, returning the result as a L{DataFrame}. + """Loads a Parquet file, returning the result as a :class:`DataFrame`. >>> import tempfile, shutil >>> parquetFile = tempfile.mkdtemp() @@ -373,15 +342,10 @@ def parquetFile(self, *paths): return DataFrame(jdf, self) def jsonFile(self, path, schema=None, samplingRatio=1.0): - """ - Loads a text file storing one JSON object per line as a - L{DataFrame}. + """Loads a text file storing one JSON object per line as a :class:`DataFrame`. - If the schema is provided, applies the given schema to this - JSON dataset. - - Otherwise, it samples the dataset with ratio `samplingRatio` to - determine the schema. + If the schema is provided, applies the given schema to this JSON dataset. + Otherwise, it samples the dataset with ratio ``samplingRatio`` to determine the schema. >>> import tempfile, shutil >>> jsonFile = tempfile.mkdtemp() @@ -417,13 +381,10 @@ def jsonFile(self, path, schema=None, samplingRatio=1.0): return DataFrame(df, self) def jsonRDD(self, rdd, schema=None, samplingRatio=1.0): - """Loads an RDD storing one JSON object per string as a L{DataFrame}. - - If the schema is provided, applies the given schema to this - JSON dataset. + """Loads an RDD storing one JSON object per string as a :class:`DataFrame`. - Otherwise, it samples the dataset with ratio `samplingRatio` to - determine the schema. + If the schema is provided, applies the given schema to this JSON dataset. + Otherwise, it samples the dataset with ratio ``samplingRatio`` to determine the schema. >>> df1 = sqlCtx.jsonRDD(json) >>> df1.first() @@ -442,7 +403,6 @@ def jsonRDD(self, rdd, schema=None, samplingRatio=1.0): >>> df3 = sqlCtx.jsonRDD(json, schema) >>> df3.first() Row(field2=u'row1', field3=Row(field5=None)) - """ def func(iterator): @@ -463,11 +423,11 @@ def func(iterator): return DataFrame(df, self) def load(self, path=None, source=None, schema=None, **options): - """Returns the dataset in a data source as a DataFrame. + """Returns the dataset in a data source as a :class:`DataFrame`. - The data source is specified by the `source` and a set of `options`. - If `source` is not specified, the default data source configured by - spark.sql.sources.default will be used. + The data source is specified by the ``source`` and a set of ``options``. + If ``source`` is not specified, the default data source configured by + ``spark.sql.sources.default`` will be used. Optionally, a schema can be provided as the schema of the returned DataFrame. """ @@ -493,11 +453,11 @@ def createExternalTable(self, tableName, path=None, source=None, It returns the DataFrame associated with the external table. - The data source is specified by the `source` and a set of `options`. - If `source` is not specified, the default data source configured by - spark.sql.sources.default will be used. + The data source is specified by the ``source`` and a set of ``options``. + If ``source`` is not specified, the default data source configured by + ``spark.sql.sources.default`` will be used. - Optionally, a schema can be provided as the schema of the returned DataFrame and + Optionally, a schema can be provided as the schema of the returned :class:`DataFrame` and created external table. """ if path is not None: @@ -518,7 +478,7 @@ def createExternalTable(self, tableName, path=None, source=None, return DataFrame(df, self) def sql(self, sqlQuery): - """Return a L{DataFrame} representing the result of the given query. + """Returns a :class:`DataFrame` representing the result of the given query. >>> sqlCtx.registerDataFrameAsTable(df, "table1") >>> df2 = sqlCtx.sql("SELECT field1 AS f1, field2 as f2 from table1") @@ -528,7 +488,7 @@ def sql(self, sqlQuery): return DataFrame(self._ssql_ctx.sql(sqlQuery), self) def table(self, tableName): - """Returns the specified table as a L{DataFrame}. + """Returns the specified table as a :class:`DataFrame`. >>> sqlCtx.registerDataFrameAsTable(df, "table1") >>> df2 = sqlCtx.table("table1") @@ -538,12 +498,12 @@ def table(self, tableName): return DataFrame(self._ssql_ctx.table(tableName), self) def tables(self, dbName=None): - """Returns a DataFrame containing names of tables in the given database. + """Returns a :class:`DataFrame` containing names of tables in the given database. - If `dbName` is not specified, the current database will be used. + If ``dbName`` is not specified, the current database will be used. - The returned DataFrame has two columns, tableName and isTemporary - (a column with BooleanType indicating if a table is a temporary one or not). + The returned DataFrame has two columns: ``tableName`` and ``isTemporary`` + (a column with :class:`BooleanType` indicating if a table is a temporary one or not). >>> sqlCtx.registerDataFrameAsTable(df, "table1") >>> df2 = sqlCtx.tables() @@ -556,9 +516,9 @@ def tables(self, dbName=None): return DataFrame(self._ssql_ctx.tables(dbName), self) def tableNames(self, dbName=None): - """Returns a list of names of tables in the database `dbName`. + """Returns a list of names of tables in the database ``dbName``. - If `dbName` is not specified, the current database will be used. + If ``dbName`` is not specified, the current database will be used. >>> sqlCtx.registerDataFrameAsTable(df, "table1") >>> "table1" in sqlCtx.tableNames() @@ -585,22 +545,18 @@ def clearCache(self): class HiveContext(SQLContext): - """A variant of Spark SQL that integrates with data stored in Hive. - Configuration for Hive is read from hive-site.xml on the classpath. + Configuration for Hive is read from ``hive-site.xml`` on the classpath. It supports running both SQL and HiveQL commands. + + :param sparkContext: The SparkContext to wrap. + :param hiveContext: An optional JVM Scala HiveContext. If set, we do not instantiate a new + :class:`HiveContext` in the JVM, instead we make all calls to this object. """ def __init__(self, sparkContext, hiveContext=None): - """Create a new HiveContext. - - :param sparkContext: The SparkContext to wrap. - :param hiveContext: An optional JVM Scala HiveContext. If set, we do not instatiate a new - HiveContext in the JVM, instead we make all calls to this object. - """ SQLContext.__init__(self, sparkContext) - if hiveContext: self._scala_HiveContext = hiveContext @@ -619,6 +575,18 @@ def _get_hive_ctx(self): return self._jvm.HiveContext(self._jsc.sc()) +class UDFRegistration(object): + """Wrapper for user-defined function registration.""" + + def __init__(self, sqlCtx): + self.sqlCtx = sqlCtx + + def register(self, name, f, returnType=StringType()): + return self.sqlCtx.registerFunction(name, f, returnType) + + register.__doc__ = SQLContext.registerFunction.__doc__ + + def _test(): import doctest from pyspark.context import SparkContext diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 5cb89da7a8ed5..c30326ebd133e 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -31,12 +31,11 @@ from pyspark.sql.types import _create_cls, _parse_datatype_json_string -__all__ = ["DataFrame", "GroupedData", "Column", "SchemaRDD"] +__all__ = ["DataFrame", "GroupedData", "Column", "SchemaRDD", "DataFrameNaFunctions"] class DataFrame(object): - - """A collection of rows that have the same columns. + """A distributed collection of data grouped into named columns. A :class:`DataFrame` is equivalent to a relational table in Spark SQL, and can be created using various functions in :class:`SQLContext`:: @@ -50,13 +49,6 @@ class DataFrame(object): ageCol = people.age - Note that the :class:`Column` type can also be manipulated - through its various functions:: - - # The following creates a new column that increases everybody's age by 10. - people.age + 10 - - A more concrete example:: # To create DataFrame using SQLContext @@ -76,9 +68,7 @@ def __init__(self, jdf, sql_ctx): @property def rdd(self): - """ - Return the content of the :class:`DataFrame` as an :class:`RDD` - of :class:`Row` s. + """Returns the content as an :class:`pyspark.RDD` of :class:`Row`. """ if not hasattr(self, '_lazy_rdd'): jrdd = self._jdf.javaToPython() @@ -93,8 +83,16 @@ def applySchema(it): return self._lazy_rdd + @property + def na(self): + """Returns a :class:`DataFrameNaFunctions` for handling missing values. + """ + return DataFrameNaFunctions(self) + def toJSON(self, use_unicode=False): - """Convert a :class:`DataFrame` into a MappedRDD of JSON documents; one document per row. + """Converts a :class:`DataFrame` into a :class:`RDD` of string. + + Each row is turned into a JSON document as one element in the returned RDD. >>> df.toJSON().first() '{"age":2,"name":"Alice"}' @@ -103,10 +101,10 @@ def toJSON(self, use_unicode=False): return RDD(rdd.toJavaRDD(), self._sc, UTF8Deserializer(use_unicode)) def saveAsParquetFile(self, path): - """Save the contents as a Parquet file, preserving the schema. + """Saves the contents as a Parquet file, preserving the schema. Files that are written out using this method can be read back in as - a :class:`DataFrame` using the L{SQLContext.parquetFile} method. + a :class:`DataFrame` using :func:`SQLContext.parquetFile`. >>> import tempfile, shutil >>> parquetFile = tempfile.mkdtemp() @@ -121,8 +119,8 @@ def saveAsParquetFile(self, path): def registerTempTable(self, name): """Registers this RDD as a temporary table using the given name. - The lifetime of this temporary table is tied to the L{SQLContext} - that was used to create this DataFrame. + The lifetime of this temporary table is tied to the :class:`SQLContext` + that was used to create this :class:`DataFrame`. >>> df.registerTempTable("people") >>> df2 = sqlCtx.sql("select * from people") @@ -132,7 +130,7 @@ def registerTempTable(self, name): self._jdf.registerTempTable(name) def registerAsTable(self, name): - """DEPRECATED: use registerTempTable() instead""" + """DEPRECATED: use :func:`registerTempTable` instead""" warnings.warn("Use registerTempTable instead of registerAsTable.", DeprecationWarning) self.registerTempTable(name) @@ -163,22 +161,19 @@ def _java_save_mode(self, mode): return jmode def saveAsTable(self, tableName, source=None, mode="error", **options): - """Saves the contents of the :class:`DataFrame` to a data source as a table. + """Saves the contents of this :class:`DataFrame` to a data source as a table. - The data source is specified by the `source` and a set of `options`. - If `source` is not specified, the default data source configured by - spark.sql.sources.default will be used. + The data source is specified by the ``source`` and a set of ``options``. + If ``source`` is not specified, the default data source configured by + ``spark.sql.sources.default`` will be used. Additionally, mode is used to specify the behavior of the saveAsTable operation when table already exists in the data source. There are four modes: - * append: Contents of this :class:`DataFrame` are expected to be appended \ - to existing table. - * overwrite: Data in the existing table is expected to be overwritten by \ - the contents of this DataFrame. - * error: An exception is expected to be thrown. - * ignore: The save operation is expected to not save the contents of the \ - :class:`DataFrame` and to not change the existing table. + * `append`: Append contents of this :class:`DataFrame` to existing data. + * `overwrite`: Overwrite existing data. + * `error`: Throw an exception if data already exists. + * `ignore`: Silently ignore this operation if data already exists. """ if source is None: source = self.sql_ctx.getConf("spark.sql.sources.default", @@ -191,18 +186,17 @@ def saveAsTable(self, tableName, source=None, mode="error", **options): def save(self, path=None, source=None, mode="error", **options): """Saves the contents of the :class:`DataFrame` to a data source. - The data source is specified by the `source` and a set of `options`. - If `source` is not specified, the default data source configured by - spark.sql.sources.default will be used. + The data source is specified by the ``source`` and a set of ``options``. + If ``source`` is not specified, the default data source configured by + ``spark.sql.sources.default`` will be used. Additionally, mode is used to specify the behavior of the save operation when data already exists in the data source. There are four modes: - * append: Contents of this :class:`DataFrame` are expected to be appended to existing data. - * overwrite: Existing data is expected to be overwritten by the contents of this DataFrame. - * error: An exception is expected to be thrown. - * ignore: The save operation is expected to not save the contents of \ - the :class:`DataFrame` and to not change the existing data. + * `append`: Append contents of this :class:`DataFrame` to existing data. + * `overwrite`: Overwrite existing data. + * `error`: Throw an exception if data already exists. + * `ignore`: Silently ignore this operation if data already exists. """ if path is not None: options["path"] = path @@ -216,8 +210,7 @@ def save(self, path=None, source=None, mode="error", **options): @property def schema(self): - """Returns the schema of this :class:`DataFrame` (represented by - a L{StructType}). + """Returns the schema of this :class:`DataFrame` as a :class:`types.StructType`. >>> df.schema StructType(List(StructField(age,IntegerType,true),StructField(name,StringType,true))) @@ -238,11 +231,9 @@ def printSchema(self): print (self._jdf.schema().treeString()) def explain(self, extended=False): - """ - Prints the plans (logical and physical) to the console for - debugging purpose. + """Prints the (logical and physical) plans to the console for debugging purpose. - If extended is False, only prints the physical plan. + :param extended: boolean, default ``False``. If ``False``, prints only the physical plan. >>> df.explain() PhysicalRDD [age#0,name#1], MapPartitionsRDD[...] at mapPartitions at SQLContext.scala:... @@ -264,15 +255,13 @@ def explain(self, extended=False): print self._jdf.queryExecution().executedPlan().toString() def isLocal(self): - """ - Returns True if the `collect` and `take` methods can be run locally + """Returns ``True`` if the :func:`collect` and :func:`take` methods can be run locally (without any Spark executors). """ return self._jdf.isLocal() def show(self, n=20): - """ - Print the first n rows. + """Prints the first ``n`` rows to the console. >>> df DataFrame[age: int, name: string] @@ -287,11 +276,7 @@ def __repr__(self): return "DataFrame[%s]" % (", ".join("%s: %s" % c for c in self.dtypes)) def count(self): - """Return the number of elements in this RDD. - - Unlike the base RDD implementation of count, this implementation - leverages the query optimizer to compute the count on the DataFrame, - which supports features such as filter pushdown. + """Returns the number of rows in this :class:`DataFrame`. >>> df.count() 2L @@ -299,10 +284,7 @@ def count(self): return self._jdf.count() def collect(self): - """Return a list that contains all of the rows. - - Each object in the list is a Row, the fields can be accessed as - attributes. + """Returns all the records as a list of :class:`Row`. >>> df.collect() [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')] @@ -314,7 +296,7 @@ def collect(self): return [cls(r) for r in rs] def limit(self, num): - """Limit the result count to the number specified. + """Limits the result count to the number specified. >>> df.limit(1).collect() [Row(age=2, name=u'Alice')] @@ -325,10 +307,7 @@ def limit(self, num): return DataFrame(jdf, self.sql_ctx) def take(self, num): - """Take the first num rows of the RDD. - - Each object in the list is a Row, the fields can be accessed as - attributes. + """Returns the first ``num`` rows as a :class:`list` of :class:`Row`. >>> df.take(2) [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')] @@ -336,9 +315,9 @@ def take(self, num): return self.limit(num).collect() def map(self, f): - """ Return a new RDD by applying a function to each Row + """ Returns a new :class:`RDD` by applying a the ``f`` function to each :class:`Row`. - It's a shorthand for df.rdd.map() + This is a shorthand for ``df.rdd.map()``. >>> df.map(lambda p: p.name).collect() [u'Alice', u'Bob'] @@ -346,10 +325,10 @@ def map(self, f): return self.rdd.map(f) def flatMap(self, f): - """ Return a new RDD by first applying a function to all elements of this, + """ Returns a new :class:`RDD` by first applying the ``f`` function to each :class:`Row`, and then flattening the results. - It's a shorthand for df.rdd.flatMap() + This is a shorthand for ``df.rdd.flatMap()``. >>> df.flatMap(lambda p: p.name).collect() [u'A', u'l', u'i', u'c', u'e', u'B', u'o', u'b'] @@ -357,10 +336,9 @@ def flatMap(self, f): return self.rdd.flatMap(f) def mapPartitions(self, f, preservesPartitioning=False): - """ - Return a new RDD by applying a function to each partition. + """Returns a new :class:`RDD` by applying the ``f`` function to each partition. - It's a shorthand for df.rdd.mapPartitions() + This is a shorthand for ``df.rdd.mapPartitions()``. >>> rdd = sc.parallelize([1, 2, 3, 4], 4) >>> def f(iterator): yield 1 @@ -370,10 +348,9 @@ def mapPartitions(self, f, preservesPartitioning=False): return self.rdd.mapPartitions(f, preservesPartitioning) def foreach(self, f): - """ - Applies a function to all rows of this DataFrame. + """Applies the ``f`` function to all :class:`Row` of this :class:`DataFrame`. - It's a shorthand for df.rdd.foreach() + This is a shorthand for ``df.rdd.foreach()``. >>> def f(person): ... print person.name @@ -382,10 +359,9 @@ def foreach(self, f): return self.rdd.foreach(f) def foreachPartition(self, f): - """ - Applies a function to each partition of this DataFrame. + """Applies the ``f`` function to each partition of this :class:`DataFrame`. - It's a shorthand for df.rdd.foreachPartition() + This a shorthand for ``df.rdd.foreachPartition()``. >>> def f(people): ... for person in people: @@ -395,14 +371,14 @@ def foreachPartition(self, f): return self.rdd.foreachPartition(f) def cache(self): - """ Persist with the default storage level (C{MEMORY_ONLY_SER}). + """ Persists with the default storage level (C{MEMORY_ONLY_SER}). """ self.is_cached = True self._jdf.cache() return self def persist(self, storageLevel=StorageLevel.MEMORY_ONLY_SER): - """ Set the storage level to persist its values across operations + """Sets the storage level to persist its values across operations after the first time it is computed. This can only be used to assign a new storage level if the RDD does not have a storage level set yet. If no storage level is specified defaults to (C{MEMORY_ONLY_SER}). @@ -413,7 +389,7 @@ def persist(self, storageLevel=StorageLevel.MEMORY_ONLY_SER): return self def unpersist(self, blocking=True): - """ Mark it as non-persistent, and remove all blocks for it from + """Marks the :class:`DataFrame` as non-persistent, and remove all blocks for it from memory and disk. """ self.is_cached = False @@ -425,8 +401,7 @@ def unpersist(self, blocking=True): # return DataFrame(rdd, self.sql_ctx) def repartition(self, numPartitions): - """ Return a new :class:`DataFrame` that has exactly `numPartitions` - partitions. + """Returns a new :class:`DataFrame` that has exactly ``numPartitions`` partitions. >>> df.repartition(10).rdd.getNumPartitions() 10 @@ -434,8 +409,7 @@ def repartition(self, numPartitions): return DataFrame(self._jdf.repartition(numPartitions), self.sql_ctx) def distinct(self): - """ - Return a new :class:`DataFrame` containing the distinct rows in this DataFrame. + """Returns a new :class:`DataFrame` containing the distinct rows in this :class:`DataFrame`. >>> df.distinct().count() 2L @@ -443,8 +417,7 @@ def distinct(self): return DataFrame(self._jdf.distinct(), self.sql_ctx) def sample(self, withReplacement, fraction, seed=None): - """ - Return a sampled subset of this DataFrame. + """Returns a sampled subset of this :class:`DataFrame`. >>> df.sample(False, 0.5, 97).count() 1L @@ -456,7 +429,7 @@ def sample(self, withReplacement, fraction, seed=None): @property def dtypes(self): - """Return all column names and their data types as a list. + """Returns all column names and their data types as a list. >>> df.dtypes [('age', 'int'), ('name', 'string')] @@ -465,7 +438,7 @@ def dtypes(self): @property def columns(self): - """ Return all column names as a list. + """Returns all column names as a list. >>> df.columns [u'age', u'name'] @@ -473,13 +446,14 @@ def columns(self): return [f.name for f in self.schema.fields] def join(self, other, joinExprs=None, joinType=None): - """ - Join with another :class:`DataFrame`, using the given join expression. - The following performs a full outer join between `df1` and `df2`. + """Joins with another :class:`DataFrame`, using the given join expression. + + The following performs a full outer join between ``df1`` and ``df2``. :param other: Right side of the join :param joinExprs: Join expression - :param joinType: One of `inner`, `outer`, `left_outer`, `right_outer`, `semijoin`. + :param joinType: str, default 'inner'. + One of `inner`, `outer`, `left_outer`, `right_outer`, `semijoin`. >>> df.join(df2, df.name == df2.name, 'outer').select(df.name, df2.height).collect() [Row(name=None, height=80), Row(name=u'Bob', height=85), Row(name=u'Alice', height=None)] @@ -497,9 +471,9 @@ def join(self, other, joinExprs=None, joinType=None): return DataFrame(jdf, self.sql_ctx) def sort(self, *cols): - """ Return a new :class:`DataFrame` sorted by the specified column(s). + """Returns a new :class:`DataFrame` sorted by the specified column(s). - :param cols: The columns or expressions used for sorting + :param cols: list of :class:`Column` to sort by. >>> df.sort(df.age.desc()).collect() [Row(age=5, name=u'Bob'), Row(age=2, name=u'Alice')] @@ -520,8 +494,29 @@ def sort(self, *cols): orderBy = sort + def describe(self, *cols): + """Computes statistics for numeric columns. + + This include count, mean, stddev, min, and max. If no columns are + given, this function computes statistics for all numerical columns. + + >>> df.describe().show() + summary age + count 2 + mean 3.5 + stddev 1.5 + min 2 + max 5 + """ + cols = ListConverter().convert(cols, + self.sql_ctx._sc._gateway._gateway_client) + jdf = self._jdf.describe(self.sql_ctx._sc._jvm.PythonUtils.toSeq(cols)) + return DataFrame(jdf, self.sql_ctx) + def head(self, n=None): - """ Return the first `n` rows or the first row if n is None. + """ + Returns the first ``n`` rows as a list of :class:`Row`, + or the first :class:`Row` if ``n`` is ``None.`` >>> df.head() Row(age=2, name=u'Alice') @@ -534,7 +529,7 @@ def head(self, n=None): return self.take(n) def first(self): - """ Return the first row. + """Returns the first row as a :class:`Row`. >>> df.first() Row(age=2, name=u'Alice') @@ -542,7 +537,7 @@ def first(self): return self.head() def __getitem__(self, item): - """ Return the column by given name + """Returns the column as a :class:`Column`. >>> df.select(df['age']).collect() [Row(age=2), Row(age=5)] @@ -562,7 +557,7 @@ def __getitem__(self, item): raise IndexError("unexpected index: %s" % item) def __getattr__(self, name): - """ Return the column by given name + """Returns the :class:`Column` denoted by ``name``. >>> df.select(df.age).collect() [Row(age=2), Row(age=5)] @@ -573,7 +568,11 @@ def __getattr__(self, name): return Column(jc) def select(self, *cols): - """ Selecting a set of expressions. + """Projects a set of expressions and returns a new :class:`DataFrame`. + + :param cols: list of column names (string) or expressions (:class:`Column`). + If one of the column names is '*', that column is expanded to include all columns + in the current DataFrame. >>> df.select('*').collect() [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')] @@ -588,9 +587,9 @@ def select(self, *cols): return DataFrame(jdf, self.sql_ctx) def selectExpr(self, *expr): - """ - Selects a set of SQL expressions. This is a variant of - `select` that accepts SQL expressions. + """Projects a set of SQL expressions and returns a new :class:`DataFrame`. + + This is a variant of :func:`select` that accepts SQL expressions. >>> df.selectExpr("age * 2", "abs(age)").collect() [Row((age * 2)=4, Abs(age)=2), Row((age * 2)=10, Abs(age)=5)] @@ -600,10 +599,12 @@ def selectExpr(self, *expr): return DataFrame(jdf, self.sql_ctx) def filter(self, condition): - """ Filtering rows using the given condition, which could be - :class:`Column` expression or string of SQL expression. + """Filters rows using the given condition. + + :func:`where` is an alias for :func:`filter`. - where() is an alias for filter(). + :param condition: a :class:`Column` of :class:`types.BooleanType` + or a string of SQL expression. >>> df.filter(df.age > 3).collect() [Row(age=5, name=u'Bob')] @@ -626,10 +627,13 @@ def filter(self, condition): where = filter def groupBy(self, *cols): - """ Group the :class:`DataFrame` using the specified columns, + """Groups the :class:`DataFrame` using the specified columns, so we can run aggregation on them. See :class:`GroupedData` for all the available aggregate functions. + :param cols: list of columns to group by. + Each element should be a column name (string) or an expression (:class:`Column`). + >>> df.groupBy().avg().collect() [Row(AVG(age)=3.5)] >>> df.groupBy('name').agg({'age': 'mean'}).collect() @@ -644,7 +648,7 @@ def groupBy(self, *cols): def agg(self, *exprs): """ Aggregate on the entire :class:`DataFrame` without groups - (shorthand for df.groupBy.agg()). + (shorthand for ``df.groupBy.agg()``). >>> df.agg({"age": "max"}).collect() [Row(MAX(age)=5)] @@ -678,8 +682,104 @@ def subtract(self, other): """ return DataFrame(getattr(self._jdf, "except")(other._jdf), self.sql_ctx) + def dropna(self, how='any', thresh=None, subset=None): + """Returns a new :class:`DataFrame` omitting rows with null values. + + This is an alias for ``na.drop()``. + + :param how: 'any' or 'all'. + If 'any', drop a row if it contains any nulls. + If 'all', drop a row only if all its values are null. + :param thresh: int, default None + If specified, drop rows that have less than `thresh` non-null values. + This overwrites the `how` parameter. + :param subset: optional list of column names to consider. + + >>> df4.dropna().show() + age height name + 10 80 Alice + + >>> df4.na.drop().show() + age height name + 10 80 Alice + """ + if how is not None and how not in ['any', 'all']: + raise ValueError("how ('" + how + "') should be 'any' or 'all'") + + if subset is None: + subset = self.columns + elif isinstance(subset, basestring): + subset = [subset] + elif not isinstance(subset, (list, tuple)): + raise ValueError("subset should be a list or tuple of column names") + + if thresh is None: + thresh = len(subset) if how == 'any' else 1 + + cols = ListConverter().convert(subset, self.sql_ctx._sc._gateway._gateway_client) + cols = self.sql_ctx._sc._jvm.PythonUtils.toSeq(cols) + return DataFrame(self._jdf.na().drop(thresh, cols), self.sql_ctx) + + def fillna(self, value, subset=None): + """Replace null values, alias for ``na.fill()``. + + :param value: int, long, float, string, or dict. + Value to replace null values with. + If the value is a dict, then `subset` is ignored and `value` must be a mapping + from column name (string) to replacement value. The replacement value must be + an int, long, float, or string. + :param subset: optional list of column names to consider. + Columns specified in subset that do not have matching data type are ignored. + For example, if `value` is a string, and subset contains a non-string column, + then the non-string column is simply ignored. + + >>> df4.fillna(50).show() + age height name + 10 80 Alice + 5 50 Bob + 50 50 Tom + 50 50 null + + >>> df4.fillna({'age': 50, 'name': 'unknown'}).show() + age height name + 10 80 Alice + 5 null Bob + 50 null Tom + 50 null unknown + + >>> df4.na.fill({'age': 50, 'name': 'unknown'}).show() + age height name + 10 80 Alice + 5 null Bob + 50 null Tom + 50 null unknown + """ + if not isinstance(value, (float, int, long, basestring, dict)): + raise ValueError("value should be a float, int, long, string, or dict") + + if isinstance(value, (int, long)): + value = float(value) + + if isinstance(value, dict): + value = MapConverter().convert(value, self.sql_ctx._sc._gateway._gateway_client) + return DataFrame(self._jdf.na().fill(value), self.sql_ctx) + elif subset is None: + return DataFrame(self._jdf.na().fill(value), self.sql_ctx) + else: + if isinstance(subset, basestring): + subset = [subset] + elif not isinstance(subset, (list, tuple)): + raise ValueError("subset should be a list or tuple of column names") + + cols = ListConverter().convert(subset, self.sql_ctx._sc._gateway._gateway_client) + cols = self.sql_ctx._sc._jvm.PythonUtils.toSeq(cols) + return DataFrame(self._jdf.na().fill(value, cols), self.sql_ctx) + def withColumn(self, colName, col): - """ Return a new :class:`DataFrame` by adding a column. + """Returns a new :class:`DataFrame` by adding a column. + + :param colName: string, name of the new column. + :param col: a :class:`Column` expression for the new column. >>> df.withColumn('age2', df.age + 2).collect() [Row(age=2, name=u'Alice', age2=4), Row(age=5, name=u'Bob', age2=7)] @@ -687,7 +787,10 @@ def withColumn(self, colName, col): return self.select('*', col.alias(colName)) def withColumnRenamed(self, existing, new): - """ Rename an existing column to a new name + """REturns a new :class:`DataFrame` by renaming an existing column. + + :param existing: string, name of the existing column to rename. + :param col: string, new name of the column. >>> df.withColumnRenamed('age', 'age2').collect() [Row(age2=2, name=u'Alice'), Row(age2=5, name=u'Bob')] @@ -698,8 +801,9 @@ def withColumnRenamed(self, existing, new): return self.select(*cols) def toPandas(self): - """ - Collect all the rows and return a `pandas.DataFrame`. + """Returns the contents of this :class:`DataFrame` as Pandas ``pandas.DataFrame``. + + This is only available if Pandas is installed and available. >>> df.toPandas() # doctest: +SKIP age name @@ -712,8 +816,7 @@ def toPandas(self): # Having SchemaRDD for backward compatibility (for docs) class SchemaRDD(DataFrame): - """ - SchemaRDD is deprecated, please use DataFrame + """SchemaRDD is deprecated, please use :class:`DataFrame`. """ @@ -740,10 +843,9 @@ def _api(self, *args): class GroupedData(object): - """ A set of methods for aggregations on a :class:`DataFrame`, - created by DataFrame.groupBy(). + created by :func:`DataFrame.groupBy`. """ def __init__(self, jdf, sql_ctx): @@ -751,14 +853,17 @@ def __init__(self, jdf, sql_ctx): self.sql_ctx = sql_ctx def agg(self, *exprs): - """ Compute aggregates by specifying a map from column name - to aggregate methods. + """Compute aggregates and returns the result as a :class:`DataFrame`. + + The available aggregate functions are `avg`, `max`, `min`, `sum`, `count`. - The available aggregate methods are `avg`, `max`, `min`, - `sum`, `count`. + If ``exprs`` is a single :class:`dict` mapping from string to string, then the key + is the column to perform aggregation on, and the value is the aggregate function. - :param exprs: list or aggregate columns or a map from column - name to aggregate methods. + Alternatively, ``exprs`` can also be a list of aggregate :class:`Column` expressions. + + :param exprs: a dict mapping from column name (string) to aggregate functions (string), + or a list of :class:`Column`. >>> gdf = df.groupBy(df.name) >>> gdf.agg({"*": "count"}).collect() @@ -783,7 +888,7 @@ def agg(self, *exprs): @dfapi def count(self): - """ Count the number of rows for each group. + """Counts the number of records for each group. >>> df.groupBy(df.age).count().collect() [Row(age=2, count=1), Row(age=5, count=1)] @@ -791,8 +896,11 @@ def count(self): @df_varargs_api def mean(self, *cols): - """Compute the average value for each numeric columns - for each group. This is an alias for `avg`. + """Computes average values for each numeric columns for each group. + + :func:`mean` is an alias for :func:`avg`. + + :param cols: list of column names (string). Non-numeric columns are ignored. >>> df.groupBy().mean('age').collect() [Row(AVG(age)=3.5)] @@ -802,8 +910,11 @@ def mean(self, *cols): @df_varargs_api def avg(self, *cols): - """Compute the average value for each numeric columns - for each group. + """Computes average values for each numeric columns for each group. + + :func:`mean` is an alias for :func:`avg`. + + :param cols: list of column names (string). Non-numeric columns are ignored. >>> df.groupBy().avg('age').collect() [Row(AVG(age)=3.5)] @@ -813,8 +924,7 @@ def avg(self, *cols): @df_varargs_api def max(self, *cols): - """Compute the max value for each numeric columns for - each group. + """Computes the max value for each numeric columns for each group. >>> df.groupBy().max('age').collect() [Row(MAX(age)=5)] @@ -824,8 +934,9 @@ def max(self, *cols): @df_varargs_api def min(self, *cols): - """Compute the min value for each numeric column for - each group. + """Computes the min value for each numeric column for each group. + + :param cols: list of column names (string). Non-numeric columns are ignored. >>> df.groupBy().min('age').collect() [Row(MIN(age)=2)] @@ -835,8 +946,9 @@ def min(self, *cols): @df_varargs_api def sum(self, *cols): - """Compute the sum for each numeric columns for each - group. + """Compute the sum for each numeric columns for each group. + + :param cols: list of column names (string). Non-numeric columns are ignored. >>> df.groupBy().sum('age').collect() [Row(SUM(age)=7)] @@ -985,6 +1097,23 @@ def substr(self, startPos, length): __getslice__ = substr + def inSet(self, *cols): + """ A boolean expression that is evaluated to true if the value of this + expression is contained by the evaluated values of the arguments. + + >>> df[df.name.inSet("Bob", "Mike")].collect() + [Row(age=5, name=u'Bob')] + >>> df[df.age.inSet([1, 2, 3])].collect() + [Row(age=2, name=u'Alice')] + """ + if len(cols) == 1 and isinstance(cols[0], (list, set)): + cols = cols[0] + cols = [c._jc if isinstance(c, Column) else _create_column_from_literal(c) for c in cols] + sc = SparkContext._active_spark_context + jcols = ListConverter().convert(cols, sc._gateway._gateway_client) + jc = getattr(self._jc, "in")(sc._jvm.PythonUtils.toSeq(jcols)) + return Column(jc) + # order asc = _unary_op("asc", "Returns a sort expression based on the" " ascending order of the given column name.") @@ -1025,6 +1154,24 @@ def __repr__(self): return 'Column<%s>' % self._jc.toString().encode('utf8') +class DataFrameNaFunctions(object): + """Functionality for working with missing data in :class:`DataFrame`. + """ + + def __init__(self, df): + self.df = df + + def drop(self, how='any', thresh=None, subset=None): + return self.df.dropna(how=how, thresh=thresh, subset=subset) + + drop.__doc__ = DataFrame.dropna.__doc__ + + def fill(self, value, subset=None): + return self.df.fillna(value=value, subset=subset) + + fill.__doc__ = DataFrame.fillna.__doc__ + + def _test(): import doctest from pyspark.context import SparkContext @@ -1040,6 +1187,12 @@ def _test(): globs['df2'] = sc.parallelize([Row(name='Tom', height=80), Row(name='Bob', height=85)]).toDF() globs['df3'] = sc.parallelize([Row(name='Alice', age=2, height=80), Row(name='Bob', age=5, height=85)]).toDF() + + globs['df4'] = sc.parallelize([Row(name='Alice', age=10, height=80), + Row(name='Bob', age=5, height=None), + Row(name='Tom', age=None, height=None), + Row(name=None, age=None, height=None)]).toDF() + (failure_count, test_count) = doctest.testmod( pyspark.sql.dataframe, globs=globs, optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE | doctest.REPORT_NDIFF) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 5873f09ae3275..146ba6f3e0d98 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -76,7 +76,7 @@ def _(col): def countDistinct(col, *cols): - """ Return a new Column for distinct count of `col` or `cols` + """Returns a new :class:`Column` for distinct count of ``col`` or ``cols``. >>> df.agg(countDistinct(df.age, df.name).alias('c')).collect() [Row(c=2)] @@ -91,7 +91,7 @@ def countDistinct(col, *cols): def approxCountDistinct(col, rsd=None): - """ Return a new Column for approximate distinct count of `col` + """Returns a new :class:`Column` for approximate distinct count of ``col``. >>> df.agg(approxCountDistinct(df.age).alias('c')).collect() [Row(c=2)] @@ -123,7 +123,8 @@ def _create_judf(self): pickled_command, broadcast_vars, env, includes = _prepare_for_python_RDD(sc, command, self) ssql_ctx = sc._jvm.SQLContext(sc._jsc.sc()) jdt = ssql_ctx.parseDataType(self.returnType.json()) - judf = sc._jvm.UserDefinedPythonFunction(f.__name__, bytearray(pickled_command), env, + fname = f.__name__ if hasattr(f, '__name__') else f.__class__.__name__ + judf = sc._jvm.UserDefinedPythonFunction(fname, bytearray(pickled_command), env, includes, sc.pythonExec, broadcast_vars, sc._javaAccumulator, jdt) return judf @@ -142,7 +143,7 @@ def __call__(self, *cols): def udf(f, returnType=StringType()): - """Create a user defined function (UDF) + """Creates a :class:`Column` expression representing a user defined function (UDF). >>> from pyspark.sql.types import IntegerType >>> slen = udf(lambda s: len(s), IntegerType()) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 2720439416682..b3a6a2c6a9229 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -25,6 +25,7 @@ import shutil import tempfile import pickle +import functools import py4j @@ -41,6 +42,7 @@ from pyspark.sql.types import * from pyspark.sql.types import UserDefinedType, _infer_type from pyspark.tests import ReusedPySparkTestCase +from pyspark.sql.functions import UserDefinedFunction class ExamplePointUDT(UserDefinedType): @@ -114,6 +116,35 @@ def tearDownClass(cls): ReusedPySparkTestCase.tearDownClass() shutil.rmtree(cls.tempdir.name, ignore_errors=True) + def test_udf_with_callable(self): + d = [Row(number=i, squared=i**2) for i in range(10)] + rdd = self.sc.parallelize(d) + data = self.sqlCtx.createDataFrame(rdd) + + class PlusFour: + def __call__(self, col): + if col is not None: + return col + 4 + + call = PlusFour() + pudf = UserDefinedFunction(call, LongType()) + res = data.select(pudf(data['number']).alias('plus_four')) + self.assertEqual(res.agg({'plus_four': 'sum'}).collect()[0][0], 85) + + def test_udf_with_partial_function(self): + d = [Row(number=i, squared=i**2) for i in range(10)] + rdd = self.sc.parallelize(d) + data = self.sqlCtx.createDataFrame(rdd) + + def some_func(col, param): + if col is not None: + return col + param + + pfunc = functools.partial(some_func, param=4) + pudf = UserDefinedFunction(pfunc, LongType()) + res = data.select(pudf(data['number']).alias('plus_four')) + self.assertEqual(res.agg({'plus_four': 'sum'}).collect()[0][0], 85) + def test_udf(self): self.sqlCtx.registerFunction("twoArgs", lambda x, y: len(x) + y, IntegerType()) [row] = self.sqlCtx.sql("SELECT twoArgs('test', 1)").collect() @@ -415,6 +446,102 @@ def test_infer_long_type(self): self.assertEqual(_infer_type(2**61), LongType()) self.assertEqual(_infer_type(2**71), LongType()) + def test_dropna(self): + schema = StructType([ + StructField("name", StringType(), True), + StructField("age", IntegerType(), True), + StructField("height", DoubleType(), True)]) + + # shouldn't drop a non-null row + self.assertEqual(self.sqlCtx.createDataFrame( + [(u'Alice', 50, 80.1)], schema).dropna().count(), + 1) + + # dropping rows with a single null value + self.assertEqual(self.sqlCtx.createDataFrame( + [(u'Alice', None, 80.1)], schema).dropna().count(), + 0) + self.assertEqual(self.sqlCtx.createDataFrame( + [(u'Alice', None, 80.1)], schema).dropna(how='any').count(), + 0) + + # if how = 'all', only drop rows if all values are null + self.assertEqual(self.sqlCtx.createDataFrame( + [(u'Alice', None, 80.1)], schema).dropna(how='all').count(), + 1) + self.assertEqual(self.sqlCtx.createDataFrame( + [(None, None, None)], schema).dropna(how='all').count(), + 0) + + # how and subset + self.assertEqual(self.sqlCtx.createDataFrame( + [(u'Alice', 50, None)], schema).dropna(how='any', subset=['name', 'age']).count(), + 1) + self.assertEqual(self.sqlCtx.createDataFrame( + [(u'Alice', None, None)], schema).dropna(how='any', subset=['name', 'age']).count(), + 0) + + # threshold + self.assertEqual(self.sqlCtx.createDataFrame( + [(u'Alice', None, 80.1)], schema).dropna(thresh=2).count(), + 1) + self.assertEqual(self.sqlCtx.createDataFrame( + [(u'Alice', None, None)], schema).dropna(thresh=2).count(), + 0) + + # threshold and subset + self.assertEqual(self.sqlCtx.createDataFrame( + [(u'Alice', 50, None)], schema).dropna(thresh=2, subset=['name', 'age']).count(), + 1) + self.assertEqual(self.sqlCtx.createDataFrame( + [(u'Alice', None, 180.9)], schema).dropna(thresh=2, subset=['name', 'age']).count(), + 0) + + # thresh should take precedence over how + self.assertEqual(self.sqlCtx.createDataFrame( + [(u'Alice', 50, None)], schema).dropna( + how='any', thresh=2, subset=['name', 'age']).count(), + 1) + + def test_fillna(self): + schema = StructType([ + StructField("name", StringType(), True), + StructField("age", IntegerType(), True), + StructField("height", DoubleType(), True)]) + + # fillna shouldn't change non-null values + row = self.sqlCtx.createDataFrame([(u'Alice', 10, 80.1)], schema).fillna(50).first() + self.assertEqual(row.age, 10) + + # fillna with int + row = self.sqlCtx.createDataFrame([(u'Alice', None, None)], schema).fillna(50).first() + self.assertEqual(row.age, 50) + self.assertEqual(row.height, 50.0) + + # fillna with double + row = self.sqlCtx.createDataFrame([(u'Alice', None, None)], schema).fillna(50.1).first() + self.assertEqual(row.age, 50) + self.assertEqual(row.height, 50.1) + + # fillna with string + row = self.sqlCtx.createDataFrame([(None, None, None)], schema).fillna("hello").first() + self.assertEqual(row.name, u"hello") + self.assertEqual(row.age, None) + + # fillna with subset specified for numeric cols + row = self.sqlCtx.createDataFrame( + [(None, None, None)], schema).fillna(50, subset=['name', 'age']).first() + self.assertEqual(row.name, None) + self.assertEqual(row.age, 50) + self.assertEqual(row.height, None) + + # fillna with subset specified for numeric cols + row = self.sqlCtx.createDataFrame( + [(None, None, None)], schema).fillna("haha", subset=['name', 'age']).first() + self.assertEqual(row.name, "haha") + self.assertEqual(row.age, None) + self.assertEqual(row.height, None) + class HiveContextSQLTests(ReusedPySparkTestCase): diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index 0169028ccc4eb..45eb8b945dcb0 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -33,8 +33,7 @@ class DataType(object): - - """Spark SQL DataType""" + """Base class for data types.""" def __repr__(self): return self.__class__.__name__ @@ -67,7 +66,6 @@ def json(self): # This singleton pattern does not work with pickle, you will get # another object after pickle and unpickle class PrimitiveTypeSingleton(type): - """Metaclass for PrimitiveType""" _instances = {} @@ -79,66 +77,45 @@ def __call__(cls): class PrimitiveType(DataType): - """Spark SQL PrimitiveType""" __metaclass__ = PrimitiveTypeSingleton class NullType(PrimitiveType): + """Null type. - """Spark SQL NullType - - The data type representing None, used for the types which has not - been inferred. + The data type representing None, used for the types that cannot be inferred. """ class StringType(PrimitiveType): - - """Spark SQL StringType - - The data type representing string values. + """String data type. """ class BinaryType(PrimitiveType): - - """Spark SQL BinaryType - - The data type representing bytearray values. + """Binary (byte array) data type. """ class BooleanType(PrimitiveType): - - """Spark SQL BooleanType - - The data type representing bool values. + """Boolean data type. """ class DateType(PrimitiveType): - - """Spark SQL DateType - - The data type representing datetime.date values. + """Date (datetime.date) data type. """ class TimestampType(PrimitiveType): - - """Spark SQL TimestampType - - The data type representing datetime.datetime values. + """Timestamp (datetime.datetime) data type. """ class DecimalType(DataType): - - """Spark SQL DecimalType - - The data type representing decimal.Decimal values. + """Decimal (decimal.Decimal) data type. """ def __init__(self, precision=None, scale=None): @@ -166,80 +143,55 @@ def __repr__(self): class DoubleType(PrimitiveType): - - """Spark SQL DoubleType - - The data type representing float values. + """Double data type, representing double precision floats. """ class FloatType(PrimitiveType): - - """Spark SQL FloatType - - The data type representing single precision floating-point values. + """Float data type, representing single precision floats. """ class ByteType(PrimitiveType): - - """Spark SQL ByteType - - The data type representing int values with 1 singed byte. + """Byte data type, i.e. a signed integer in a single byte. """ def simpleString(self): return 'tinyint' class IntegerType(PrimitiveType): - - """Spark SQL IntegerType - - The data type representing int values. + """Int data type, i.e. a signed 32-bit integer. """ def simpleString(self): return 'int' class LongType(PrimitiveType): + """Long data type, i.e. a signed 64-bit integer. - """Spark SQL LongType - - The data type representing long values. If the any value is - beyond the range of [-9223372036854775808, 9223372036854775807], - please use DecimalType. + If the values are beyond the range of [-9223372036854775808, 9223372036854775807], + please use :class:`DecimalType`. """ def simpleString(self): return 'bigint' class ShortType(PrimitiveType): - - """Spark SQL ShortType - - The data type representing int values with 2 signed bytes. + """Short data type, i.e. a signed 16-bit integer. """ def simpleString(self): return 'smallint' class ArrayType(DataType): + """Array data type. - """Spark SQL ArrayType - - The data type representing list values. An ArrayType object - comprises two fields, elementType (a DataType) and containsNull (a bool). - The field of elementType is used to specify the type of array elements. - The field of containsNull is used to specify if the array has None values. - + :param elementType: :class:`DataType` of each element in the array. + :param containsNull: boolean, whether the array can contain null (None) values. """ def __init__(self, elementType, containsNull=True): - """Creates an ArrayType - - :param elementType: the data type of elements. - :param containsNull: indicates whether the list contains None values. - + """ >>> ArrayType(StringType()) == ArrayType(StringType(), True) True >>> ArrayType(StringType(), False) == ArrayType(StringType()) @@ -268,29 +220,17 @@ def fromJson(cls, json): class MapType(DataType): + """Map data type. - """Spark SQL MapType - - The data type representing dict values. A MapType object comprises - three fields, keyType (a DataType), valueType (a DataType) and - valueContainsNull (a bool). - - The field of keyType is used to specify the type of keys in the map. - The field of valueType is used to specify the type of values in the map. - The field of valueContainsNull is used to specify if values of this - map has None values. - - For values of a MapType column, keys are not allowed to have None values. + :param keyType: :class:`DataType` of the keys in the map. + :param valueType: :class:`DataType` of the values in the map. + :param valueContainsNull: indicates whether values can contain null (None) values. + Keys in a map data type are not allowed to be null (None). """ def __init__(self, keyType, valueType, valueContainsNull=True): - """Creates a MapType - :param keyType: the data type of keys. - :param valueType: the data type of values. - :param valueContainsNull: indicates whether values contains - null values. - + """ >>> (MapType(StringType(), IntegerType()) ... == MapType(StringType(), IntegerType(), True)) True @@ -325,30 +265,16 @@ def fromJson(cls, json): class StructField(DataType): + """A field in :class:`StructType`. - """Spark SQL StructField - - Represents a field in a StructType. - A StructField object comprises three fields, name (a string), - dataType (a DataType) and nullable (a bool). The field of name - is the name of a StructField. The field of dataType specifies - the data type of a StructField. - - The field of nullable specifies if values of a StructField can - contain None values. - + :param name: string, name of the field. + :param dataType: :class:`DataType` of the field. + :param nullable: boolean, whether the field can be null (None) or not. + :param metadata: a dict from string to simple type that can be serialized to JSON automatically """ def __init__(self, name, dataType, nullable=True, metadata=None): - """Creates a StructField - :param name: the name of this field. - :param dataType: the data type of this field. - :param nullable: indicates whether values of this field - can be null. - :param metadata: metadata of this field, which is a map from string - to simple type that can be serialized to JSON - automatically - + """ >>> (StructField("f1", StringType(), True) ... == StructField("f1", StringType(), True)) True @@ -384,17 +310,13 @@ def fromJson(cls, json): class StructType(DataType): + """Struct type, consisting of a list of :class:`StructField`. - """Spark SQL StructType - - The data type representing rows. - A StructType object comprises a list of L{StructField}. - + This is the data type representing a :class:`Row`. """ def __init__(self, fields): - """Creates a StructType - + """ >>> struct1 = StructType([StructField("f1", StringType(), True)]) >>> struct2 = StructType([StructField("f1", StringType(), True)]) >>> struct1 == struct2 @@ -425,9 +347,9 @@ def fromJson(cls, json): class UserDefinedType(DataType): - """ + """User-defined type (UDT). + .. note:: WARN: Spark Internal Use Only - SQL User-Defined Type (UDT). """ @classmethod diff --git a/sbin/start-slave.sh b/sbin/start-slave.sh index 2fc35309f4ca5..5a6de11afdd3d 100755 --- a/sbin/start-slave.sh +++ b/sbin/start-slave.sh @@ -17,8 +17,14 @@ # limitations under the License. # -# Usage: start-slave.sh -# where is like "spark://localhost:7077" +# Starts a slave on the machine this script is executed on. + +usage="Usage: start-slave.sh where is like spark://localhost:7077" + +if [ $# -lt 2 ]; then + echo $usage + exit 1 +fi sbin="`dirname "$0"`" sbin="`cd "$sbin"; pwd`" diff --git a/sbin/start-slaves.sh b/sbin/start-slaves.sh index 76316a3067c93..4356c03657109 100755 --- a/sbin/start-slaves.sh +++ b/sbin/start-slaves.sh @@ -17,6 +17,8 @@ # limitations under the License. # +# Starts a slave instance on each machine specified in the conf/slaves file. + sbin="`dirname "$0"`" sbin="`cd "$sbin"; pwd`" diff --git a/scalastyle-config.xml b/scalastyle-config.xml index 459a5035d4984..7168d5b2a8e26 100644 --- a/scalastyle-config.xml +++ b/scalastyle-config.xml @@ -137,7 +137,7 @@ - + diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/AnalysisException.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/AnalysisException.scala index 34fedead44db3..f9992185a4563 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/AnalysisException.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/AnalysisException.scala @@ -30,7 +30,7 @@ class AnalysisException protected[sql] ( val startPosition: Option[Int] = None) extends Exception with Serializable { - def withPosition(line: Option[Int], startPosition: Option[Int]) = { + def withPosition(line: Option[Int], startPosition: Option[Int]): AnalysisException = { val newException = new AnalysisException(message, line, startPosition) newException.setStackTrace(getStackTrace) newException diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index d6126c24fc50d..8bfd0471d9c7a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -72,6 +72,11 @@ trait ScalaReflection { case (d: BigDecimal, _) => Decimal(d) case (d: java.math.BigDecimal, _) => Decimal(d) case (d: java.sql.Date, _) => DateUtils.fromJavaDate(d) + case (r: Row, structType: StructType) => + new GenericRow( + r.toSeq.zip(structType.fields).map { case (elem, field) => + convertToCatalyst(elem, field.dataType) + }.toArray) case (other, _) => other } @@ -179,6 +184,8 @@ trait ScalaReflection { case t if t <:< definitions.ShortTpe => Schema(ShortType, nullable = false) case t if t <:< definitions.ByteTpe => Schema(ByteType, nullable = false) case t if t <:< definitions.BooleanTpe => Schema(BooleanType, nullable = false) + case other => + throw new UnsupportedOperationException(s"Schema for type $other is not supported") } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala index ea7d44a3723d1..89f4a19add1c6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala @@ -139,7 +139,7 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser { sortType.? ~ (LIMIT ~> expression).? ^^ { case d ~ p ~ r ~ f ~ g ~ h ~ o ~ l => - val base = r.getOrElse(NoRelation) + val base = r.getOrElse(OneRowRelation) val withFilter = f.map(Filter(_, base)).getOrElse(base) val withProjection = g .map(Aggregate(_, assignAliases(p), withFilter)) @@ -316,13 +316,13 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser { protected lazy val literal: Parser[Literal] = ( numericLiteral | booleanLiteral - | stringLit ^^ {case s => Literal(s, StringType) } - | NULL ^^^ Literal(null, NullType) + | stringLit ^^ {case s => Literal.create(s, StringType) } + | NULL ^^^ Literal.create(null, NullType) ) protected lazy val booleanLiteral: Parser[Literal] = - ( TRUE ^^^ Literal(true, BooleanType) - | FALSE ^^^ Literal(false, BooleanType) + ( TRUE ^^^ Literal.create(true, BooleanType) + | FALSE ^^^ Literal.create(false, BooleanType) ) protected lazy val numericLiteral: Parser[Literal] = diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 44eceb0b372e6..119cb9c3a4400 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -37,11 +37,12 @@ object SimpleAnalyzer extends Analyzer(EmptyCatalog, EmptyFunctionRegistry, true * [[UnresolvedRelation]]s into fully typed objects using information in a schema [[Catalog]] and * a [[FunctionRegistry]]. */ -class Analyzer(catalog: Catalog, - registry: FunctionRegistry, - caseSensitive: Boolean, - maxIterations: Int = 100) - extends RuleExecutor[LogicalPlan] with HiveTypeCoercion { +class Analyzer( + catalog: Catalog, + registry: FunctionRegistry, + caseSensitive: Boolean, + maxIterations: Int = 100) + extends RuleExecutor[LogicalPlan] with HiveTypeCoercion with CheckAnalysis { val resolver = if (caseSensitive) caseSensitiveResolution else caseInsensitiveResolution @@ -139,10 +140,10 @@ class Analyzer(catalog: Catalog, case x: Expression if nonSelectedGroupExprSet.contains(x) => // if the input attribute in the Invalid Grouping Expression set of for this group // replace it with constant null - Literal(null, expr.dataType) + Literal.create(null, expr.dataType) case x if x == g.gid => // replace the groupingId with concrete value (the bit mask) - Literal(bitmask, IntegerType) + Literal.create(bitmask, IntegerType) }) result += GroupExpression(substitution) @@ -212,6 +213,12 @@ class Analyzer(catalog: Catalog, case o => o :: Nil } Alias(c.copy(children = expandedArgs), name)() :: Nil + case Alias(c @ CreateStruct(args), name) if containsStar(args) => + val expandedArgs = args.flatMap { + case s: Star => s.expand(child.output, resolver) + case o => o :: Nil + } + Alias(c.copy(children = expandedArgs), name)() :: Nil case o => o :: Nil }, child) @@ -252,7 +259,15 @@ class Analyzer(catalog: Catalog, case oldVersion @ Aggregate(_, aggregateExpressions, _) if findAliases(aggregateExpressions).intersect(conflictingAttributes).nonEmpty => (oldVersion, oldVersion.copy(aggregateExpressions = newAliases(aggregateExpressions))) - }.head // Only handle first case found, others will be fixed on the next pass. + }.headOption.getOrElse { // Only handle first case, others will be fixed on the next pass. + sys.error( + s""" + |Failure when resolving conflicting references in Join: + |$plan + | + |Conflicting attributes: ${conflictingAttributes.mkString(",")} + """.stripMargin) + } val attributeRewrites = AttributeMap(oldRelation.output.zip(newRelation.output)) val newRight = right transformUp { @@ -340,19 +355,16 @@ class Analyzer(catalog: Catalog, def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { case s @ Sort(ordering, global, p @ Project(projectList, child)) if !s.resolved && p.resolved => - val unresolved = ordering.flatMap(_.collect { case UnresolvedAttribute(name) => name }) - val resolved = unresolved.flatMap(child.resolve(_, resolver)) - val requiredAttributes = - AttributeSet(resolved.flatMap(_.collect { case a: Attribute => a })) + val (resolvedOrdering, missing) = resolveAndFindMissing(ordering, p, child) - val missingInProject = requiredAttributes -- p.output - if (missingInProject.nonEmpty) { + // If this rule was not a no-op, return the transformed plan, otherwise return the original. + if (missing.nonEmpty) { // Add missing attributes and then project them away after the sort. - Project(projectList.map(_.toAttribute), - Sort(ordering, global, - Project(projectList ++ missingInProject, child))) + Project(p.output, + Sort(resolvedOrdering, global, + Project(projectList ++ missing, child))) } else { - logDebug(s"Failed to find $missingInProject in ${p.output.mkString(", ")}") + logDebug(s"Failed to find $missing in ${p.output.mkString(", ")}") s // Nothing we can do here. Return original plan. } case s @ Sort(ordering, global, a @ Aggregate(grouping, aggs, child)) @@ -364,18 +376,54 @@ class Analyzer(catalog: Catalog, grouping.collect { case ne: NamedExpression => ne.toAttribute } ) - logDebug(s"Grouping expressions: $groupingRelation") - val resolved = unresolved.flatMap(groupingRelation.resolve(_, resolver)) - val missingInAggs = resolved.filterNot(a.outputSet.contains) - logDebug(s"Resolved: $resolved Missing in aggs: $missingInAggs") - if (missingInAggs.nonEmpty) { + val (resolvedOrdering, missing) = resolveAndFindMissing(ordering, a, groupingRelation) + + if (missing.nonEmpty) { // Add missing grouping exprs and then project them away after the sort. Project(a.output, - Sort(ordering, global, Aggregate(grouping, aggs ++ missingInAggs, child))) + Sort(resolvedOrdering, global, + Aggregate(grouping, aggs ++ missing, child))) } else { s // Nothing we can do here. Return original plan. } } + + /** + * Given a child and a grandchild that are present beneath a sort operator, returns + * a resolved sort ordering and a list of attributes that are missing from the child + * but are present in the grandchild. + */ + def resolveAndFindMissing( + ordering: Seq[SortOrder], + child: LogicalPlan, + grandchild: LogicalPlan): (Seq[SortOrder], Seq[Attribute]) = { + // Find any attributes that remain unresolved in the sort. + val unresolved: Seq[String] = + ordering.flatMap(_.collect { case UnresolvedAttribute(name) => name }) + + // Create a map from name, to resolved attributes, when the desired name can be found + // prior to the projection. + val resolved: Map[String, NamedExpression] = + unresolved.flatMap(u => grandchild.resolve(u, resolver).map(a => u -> a)).toMap + + // Construct a set that contains all of the attributes that we need to evaluate the + // ordering. + val requiredAttributes = AttributeSet(resolved.values) + + // Figure out which ones are missing from the projection, so that we can add them and + // remove them after the sort. + val missingInProject = requiredAttributes -- child.output + + // Now that we have all the attributes we need, reconstruct a resolved ordering. + // It is important to do it here, instead of waiting for the standard resolved as adding + // attributes to the project below can actually introduce ambiquity that was not present + // before. + val resolvedOrdering = ordering.map(_ transform { + case u @ UnresolvedAttribute(name) => resolved.getOrElse(name, u) + }).asInstanceOf[Seq[SortOrder]] + + (resolvedOrdering, missingInProject.toSeq) + } } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 40472a1cbb3b4..fa02111385c06 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -25,7 +25,8 @@ import org.apache.spark.sql.types._ /** * Throws user facing errors when passed invalid queries that fail to analyze. */ -class CheckAnalysis { +trait CheckAnalysis { + self: Analyzer => /** * Override to provide additional checks for correct analysis. @@ -33,17 +34,22 @@ class CheckAnalysis { */ val extendedCheckRules: Seq[LogicalPlan => Unit] = Nil - def failAnalysis(msg: String): Nothing = { + protected def failAnalysis(msg: String): Nothing = { throw new AnalysisException(msg) } - def apply(plan: LogicalPlan): Unit = { + def checkAnalysis(plan: LogicalPlan): Unit = { // We transform up and order the rules so as to catch the first possible failure instead // of the result of cascading resolution failures. plan.foreachUp { case operator: LogicalPlan => operator transformExpressionsUp { case a: Attribute if !a.resolved => + if (operator.childrenResolved) { + // Throw errors for specific problems with get field. + operator.resolveChildren(a.name, resolver, throwErrors = true) + } + val from = operator.inputSet.map(_.name).mkString(", ") a.failAnalysis(s"cannot resolve '${a.prettyString}' given input columns $from") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index 34ef7d28cc7f2..3aeb964994d37 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -78,6 +78,7 @@ trait HiveTypeCoercion { FunctionArgumentConversion :: CaseWhenCoercion :: Division :: + PropagateTypes :: Nil /** @@ -114,7 +115,7 @@ trait HiveTypeCoercion { * the appropriate numeric equivalent. */ object ConvertNaNs extends Rule[LogicalPlan] { - val stringNaN = Literal("NaN", StringType) + val stringNaN = Literal.create("NaN", StringType) def apply(plan: LogicalPlan): LogicalPlan = plan transform { case q: LogicalPlan => q transformExpressions { @@ -284,6 +285,7 @@ trait HiveTypeCoercion { * Calculates and propagates precision for fixed-precision decimals. Hive has a number of * rules for this based on the SQL standard and MS SQL: * https://cwiki.apache.org/confluence/download/attachments/27362075/Hive_Decimal_Precision_Scale_Support.pdf + * https://msdn.microsoft.com/en-us/library/ms190476.aspx * * In particular, if we have expressions e1 and e2 with precision/scale p1/s2 and p2/s2 * respectively, then the following operations have the following precision / scale: @@ -295,6 +297,7 @@ trait HiveTypeCoercion { * e1 * e2 p1 + p2 + 1 s1 + s2 * e1 / e2 p1 - s1 + s2 + max(6, s1 + p2 + 1) max(6, s1 + p2 + 1) * e1 % e2 min(p1-s1, p2-s2) + max(s1, s2) max(s1, s2) + * e1 union e2 max(s1, s2) + max(p1-s1, p2-s2) max(s1, s2) * sum(e1) p1 + 10 s1 * avg(e1) p1 + 4 s1 + 4 * @@ -310,7 +313,12 @@ trait HiveTypeCoercion { * - SHORT gets turned into DECIMAL(5, 0) * - INT gets turned into DECIMAL(10, 0) * - LONG gets turned into DECIMAL(20, 0) - * - FLOAT and DOUBLE cause fixed-length decimals to turn into DOUBLE (this is the same as Hive, + * - FLOAT and DOUBLE + * 1. Union operation: + * FLOAT gets turned into DECIMAL(7, 7), DOUBLE gets turned into DECIMAL(15, 15) (this is the + * same as Hive) + * 2. Other operation: + * FLOAT and DOUBLE cause fixed-length decimals to turn into DOUBLE (this is the same as Hive, * but note that unlimited decimals are considered bigger than doubles in WidenTypes) */ // scalastyle:on @@ -327,76 +335,127 @@ trait HiveTypeCoercion { def isFloat(t: DataType): Boolean = t == FloatType || t == DoubleType - def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { - // Skip nodes whose children have not been resolved yet - case e if !e.childrenResolved => e + // Conversion rules for float and double into fixed-precision decimals + val floatTypeToFixed: Map[DataType, DecimalType] = Map( + FloatType -> DecimalType(7, 7), + DoubleType -> DecimalType(15, 15) + ) - case Add(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => - Cast( - Add(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)), - DecimalType(max(s1, s2) + max(p1 - s1, p2 - s2) + 1, max(s1, s2)) - ) - - case Subtract(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => - Cast( - Subtract(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)), - DecimalType(max(s1, s2) + max(p1 - s1, p2 - s2) + 1, max(s1, s2)) - ) - - case Multiply(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => - Cast( - Multiply(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)), - DecimalType(p1 + p2 + 1, s1 + s2) - ) - - case Divide(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => - Cast( - Divide(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)), - DecimalType(p1 - s1 + s2 + max(6, s1 + p2 + 1), max(6, s1 + p2 + 1)) - ) - - case Remainder(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => - Cast( - Remainder(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)), - DecimalType(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2)) - ) - - case LessThan(e1 @ DecimalType.Expression(p1, s1), - e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 => - LessThan(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)) - - case LessThanOrEqual(e1 @ DecimalType.Expression(p1, s1), - e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 => - LessThanOrEqual(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)) - - case GreaterThan(e1 @ DecimalType.Expression(p1, s1), - e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 => - GreaterThan(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)) - - case GreaterThanOrEqual(e1 @ DecimalType.Expression(p1, s1), - e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 => - GreaterThanOrEqual(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)) - - // Promote integers inside a binary expression with fixed-precision decimals to decimals, - // and fixed-precision decimals in an expression with floats / doubles to doubles - case b: BinaryExpression if b.left.dataType != b.right.dataType => - (b.left.dataType, b.right.dataType) match { - case (t, DecimalType.Fixed(p, s)) if intTypeToFixed.contains(t) => - b.makeCopy(Array(Cast(b.left, intTypeToFixed(t)), b.right)) - case (DecimalType.Fixed(p, s), t) if intTypeToFixed.contains(t) => - b.makeCopy(Array(b.left, Cast(b.right, intTypeToFixed(t)))) - case (t, DecimalType.Fixed(p, s)) if isFloat(t) => - b.makeCopy(Array(b.left, Cast(b.right, DoubleType))) - case (DecimalType.Fixed(p, s), t) if isFloat(t) => - b.makeCopy(Array(Cast(b.left, DoubleType), b.right)) - case _ => - b + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + // fix decimal precision for union + case u @ Union(left, right) if u.childrenResolved && !u.resolved => + val castedInput = left.output.zip(right.output).map { + case (l, r) if l.dataType != r.dataType => + (l.dataType, r.dataType) match { + case (DecimalType.Fixed(p1, s1), DecimalType.Fixed(p2, s2)) => + // Union decimals with precision/scale p1/s2 and p2/s2 will be promoted to + // DecimalType(max(s1, s2) + max(p1-s1, p2-s2), max(s1, s2)) + val fixedType = DecimalType(max(s1, s2) + max(p1 - s1, p2 - s2), max(s1, s2)) + (Alias(Cast(l, fixedType), l.name)(), Alias(Cast(r, fixedType), r.name)()) + case (t, DecimalType.Fixed(p, s)) if intTypeToFixed.contains(t) => + (Alias(Cast(l, intTypeToFixed(t)), l.name)(), r) + case (DecimalType.Fixed(p, s), t) if intTypeToFixed.contains(t) => + (l, Alias(Cast(r, intTypeToFixed(t)), r.name)()) + case (t, DecimalType.Fixed(p, s)) if floatTypeToFixed.contains(t) => + (Alias(Cast(l, floatTypeToFixed(t)), l.name)(), r) + case (DecimalType.Fixed(p, s), t) if floatTypeToFixed.contains(t) => + (l, Alias(Cast(r, floatTypeToFixed(t)), r.name)()) + case _ => (l, r) + } + case other => other } - // TODO: MaxOf, MinOf, etc might want other rules + val (castedLeft, castedRight) = castedInput.unzip + + val newLeft = + if (castedLeft.map(_.dataType) != left.output.map(_.dataType)) { + Project(castedLeft, left) + } else { + left + } + + val newRight = + if (castedRight.map(_.dataType) != right.output.map(_.dataType)) { + Project(castedRight, right) + } else { + right + } + + Union(newLeft, newRight) - // SUM and AVERAGE are handled by the implementations of those expressions + // fix decimal precision for expressions + case q => q.transformExpressions { + // Skip nodes whose children have not been resolved yet + case e if !e.childrenResolved => e + + case Add(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => + Cast( + Add(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)), + DecimalType(max(s1, s2) + max(p1 - s1, p2 - s2) + 1, max(s1, s2)) + ) + + case Subtract(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => + Cast( + Subtract(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)), + DecimalType(max(s1, s2) + max(p1 - s1, p2 - s2) + 1, max(s1, s2)) + ) + + case Multiply(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => + Cast( + Multiply(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)), + DecimalType(p1 + p2 + 1, s1 + s2) + ) + + case Divide(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => + Cast( + Divide(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)), + DecimalType(p1 - s1 + s2 + max(6, s1 + p2 + 1), max(6, s1 + p2 + 1)) + ) + + case Remainder(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => + Cast( + Remainder(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)), + DecimalType(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2)) + ) + + case LessThan(e1 @ DecimalType.Expression(p1, s1), + e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 => + LessThan(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)) + + case LessThanOrEqual(e1 @ DecimalType.Expression(p1, s1), + e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 => + LessThanOrEqual(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)) + + case GreaterThan(e1 @ DecimalType.Expression(p1, s1), + e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 => + GreaterThan(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)) + + case GreaterThanOrEqual(e1 @ DecimalType.Expression(p1, s1), + e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 => + GreaterThanOrEqual(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)) + + // Promote integers inside a binary expression with fixed-precision decimals to decimals, + // and fixed-precision decimals in an expression with floats / doubles to doubles + case b: BinaryExpression if b.left.dataType != b.right.dataType => + (b.left.dataType, b.right.dataType) match { + case (t, DecimalType.Fixed(p, s)) if intTypeToFixed.contains(t) => + b.makeCopy(Array(Cast(b.left, intTypeToFixed(t)), b.right)) + case (DecimalType.Fixed(p, s), t) if intTypeToFixed.contains(t) => + b.makeCopy(Array(b.left, Cast(b.right, intTypeToFixed(t)))) + case (t, DecimalType.Fixed(p, s)) if isFloat(t) => + b.makeCopy(Array(b.left, Cast(b.right, DoubleType))) + case (DecimalType.Fixed(p, s), t) if isFloat(t) => + b.makeCopy(Array(Cast(b.left, DoubleType), b.right)) + case _ => + b + } + + // TODO: MaxOf, MinOf, etc might want other rules + + // SUM and AVERAGE are handled by the implementations of those expressions + } } + } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/MultiInstanceRelation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/MultiInstanceRelation.scala index 894c3500cf533..35b74024a4cab 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/MultiInstanceRelation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/MultiInstanceRelation.scala @@ -30,5 +30,5 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan * of itself with globally unique expression ids. */ trait MultiInstanceRelation { - def newInstance(): this.type + def newInstance(): LogicalPlan } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/package.scala index c61c395cb4bb1..7731336d247db 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/package.scala @@ -44,7 +44,7 @@ package object analysis { } /** Catches any AnalysisExceptions thrown by `f` and attaches `t`'s position if any. */ - def withPosition[A](t: TreeNode[_])(f: => A) = { + def withPosition[A](t: TreeNode[_])(f: => A): A = { try f catch { case a: AnalysisException => throw a.withPosition(t.origin.line, t.origin.startPosition) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala index f9ae85a5cfc1b..5345696570b41 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala @@ -34,7 +34,7 @@ object AttributeSet { def apply(a: Attribute): AttributeSet = new AttributeSet(Set(new AttributeEquals(a))) /** Constructs a new [[AttributeSet]] given a sequence of [[Expression Expressions]]. */ - def apply(baseSet: Seq[Expression]): AttributeSet = { + def apply(baseSet: Iterable[Expression]): AttributeSet = { new AttributeSet( baseSet .flatMap(_.references) @@ -58,7 +58,8 @@ class AttributeSet private (val baseSet: Set[AttributeEquals]) /** Returns true if the members of this AttributeSet and other are the same. */ override def equals(other: Any): Boolean = other match { - case otherSet: AttributeSet => baseSet.map(_.a).forall(otherSet.contains) + case otherSet: AttributeSet => + otherSet.size == baseSet.size && baseSet.map(_.a).forall(otherSet.contains) case _ => false } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala index 30da4faa3f1c6..406de38d1c483 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala @@ -505,7 +505,8 @@ case class AverageFunction(expr: Expression, base: AggregateExpression) private var count: Long = _ private val sum = MutableLiteral(zero.eval(null), calcType) - private def addFunction(value: Any) = Add(sum, Cast(Literal(value, expr.dataType), calcType)) + private def addFunction(value: Any) = Add(sum, + Cast(Literal.create(value, expr.dataType), calcType)) override def eval(input: Row): Any = { if (count == 0L) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala index 3fd78db297462..3b2b9211268a9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala @@ -120,7 +120,7 @@ case class ArrayGetField(child: Expression, field: StructField, ordinal: Int, co case class CreateArray(children: Seq[Expression]) extends Expression { override type EvaluatedType = Any - override def foldable: Boolean = !children.exists(!_.foldable) + override def foldable: Boolean = children.forall(_.foldable) lazy val childTypes = children.map(_.dataType).distinct @@ -142,3 +142,30 @@ case class CreateArray(children: Seq[Expression]) extends Expression { override def toString: String = s"Array(${children.mkString(",")})" } + +/** + * Returns a Row containing the evaluation of all children expressions. + * TODO: [[CreateStruct]] does not support codegen. + */ +case class CreateStruct(children: Seq[NamedExpression]) extends Expression { + override type EvaluatedType = Row + + override def foldable: Boolean = children.forall(_.foldable) + + override lazy val resolved: Boolean = childrenResolved + + override lazy val dataType: StructType = { + assert(resolved, + s"CreateStruct contains unresolvable children: ${children.filterNot(_.resolved)}.") + val fields = children.map { child => + StructField(child.name, child.dataType, child.nullable, child.metadata) + } + StructType(fields) + } + + override def nullable: Boolean = false + + override def eval(input: Row): EvaluatedType = { + Row(children.map(_.eval(input)): _*) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala index 19f3fc9c2291a..0e2d593e94124 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala @@ -41,6 +41,8 @@ object Literal { case _ => throw new RuntimeException("Unsupported literal type " + v.getClass + " " + v) } + + def create(v: Any, dataType: DataType): Literal = Literal(v, dataType) } /** @@ -62,7 +64,10 @@ object IntegerLiteral { } } -case class Literal(value: Any, dataType: DataType) extends LeafExpression { +/** + * In order to do type checking, use Literal.create() instead of constructor + */ +case class Literal protected (value: Any, dataType: DataType) extends LeafExpression { override def foldable: Boolean = true override def nullable: Boolean = value == null diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala index d1f3d4f4ee9ee..f9161cf34f0c9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala @@ -35,7 +35,7 @@ case class Coalesce(children: Seq[Expression]) extends Expression { override def toString: String = s"Coalesce(${children.mkString(",")})" - def dataType: DataType = if (resolved) { + override def dataType: DataType = if (resolved) { children.head.dataType } else { val childTypes = children.map(c => s"$c: ${c.dataType}").mkString(", ") @@ -74,3 +74,26 @@ case class IsNotNull(child: Expression) extends Predicate with trees.UnaryNode[E child.eval(input) != null } } + +/** + * A predicate that is evaluated to be true if there are at least `n` non-null values. + */ +case class AtLeastNNonNulls(n: Int, children: Seq[Expression]) extends Predicate { + override def nullable: Boolean = false + override def foldable: Boolean = false + override def toString: String = s"AtLeastNNulls(n, ${children.mkString(",")})" + + private[this] val childrenArray = children.toArray + + override def eval(input: Row): Boolean = { + var numNonNulls = 0 + var i = 0 + while (i < childrenArray.length && numNonNulls < n) { + if (childrenArray(i).eval(input) != null) { + numNonNulls += 1 + } + i += 1 + } + numNonNulls >= n + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala index 8bba26bc4cf7f..0a275b84086cf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala @@ -66,7 +66,7 @@ object EmptyRow extends Row { */ class GenericRow(protected[sql] val values: Array[Any]) extends Row { /** No-arg constructor for serialization. */ - def this() = this(null) + protected def this() = this(null) def this(size: Int) = this(new Array[Any](size)) @@ -172,11 +172,14 @@ class GenericRow(protected[sql] val values: Array[Any]) extends Row { class GenericRowWithSchema(values: Array[Any], override val schema: StructType) extends GenericRow(values) { + + /** No-arg constructor for serialization. */ + protected def this() = this(null, null) } class GenericMutableRow(v: Array[Any]) extends GenericRow(v) with MutableRow { /** No-arg constructor for serialization. */ - def this() = this(null) + protected def this() = this(null) def this(size: Int) = this(new Array[Any](size)) @@ -221,6 +224,7 @@ class RowOrdering(ordering: Seq[SortOrder]) extends Ordering[Row] { n.ordering.asInstanceOf[Ordering[Any]].compare(left, right) case n: NativeType if order.direction == Descending => n.ordering.asInstanceOf[Ordering[Any]].reverse.compare(left, right) + case other => sys.error(s"Type $other does not support ordered operations") } if (comparison != 0) return comparison } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala index 3cdca4e9dd2d1..acfbbace608ef 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala @@ -156,12 +156,11 @@ case class Lower(child: Expression) extends UnaryExpression with CaseConversionE /** A base trait for functions that compare two strings, returning a boolean. */ trait StringComparison { - self: BinaryExpression => + self: BinaryPredicate => - type EvaluatedType = Any + override type EvaluatedType = Any override def nullable: Boolean = left.nullable || right.nullable - override def dataType: DataType = BooleanType def compare(l: String, r: String): Boolean @@ -184,7 +183,7 @@ trait StringComparison { * A function that returns true if the string `left` contains the string `right`. */ case class Contains(left: Expression, right: Expression) - extends BinaryExpression with StringComparison { + extends BinaryPredicate with StringComparison { override def compare(l: String, r: String): Boolean = l.contains(r) } @@ -192,7 +191,7 @@ case class Contains(left: Expression, right: Expression) * A function that returns true if the string `left` starts with the string `right`. */ case class StartsWith(left: Expression, right: Expression) - extends BinaryExpression with StringComparison { + extends BinaryPredicate with StringComparison { override def compare(l: String, r: String): Boolean = l.startsWith(r) } @@ -200,7 +199,7 @@ case class StartsWith(left: Expression, right: Expression) * A function that returns true if the string `left` ends with the string `right`. */ case class EndsWith(left: Expression, right: Expression) - extends BinaryExpression with StringComparison { + extends BinaryPredicate with StringComparison { override def compare(l: String, r: String): Boolean = l.endsWith(r) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index c23d3b61887c6..93e69d409cb91 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -218,12 +218,12 @@ object NullPropagation extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { case q: LogicalPlan => q transformExpressionsUp { case e @ Count(Literal(null, _)) => Cast(Literal(0L), e.dataType) - case e @ IsNull(c) if !c.nullable => Literal(false, BooleanType) - case e @ IsNotNull(c) if !c.nullable => Literal(true, BooleanType) - case e @ GetItem(Literal(null, _), _) => Literal(null, e.dataType) - case e @ GetItem(_, Literal(null, _)) => Literal(null, e.dataType) - case e @ StructGetField(Literal(null, _), _, _) => Literal(null, e.dataType) - case e @ ArrayGetField(Literal(null, _), _, _, _) => Literal(null, e.dataType) + case e @ IsNull(c) if !c.nullable => Literal.create(false, BooleanType) + case e @ IsNotNull(c) if !c.nullable => Literal.create(true, BooleanType) + case e @ GetItem(Literal(null, _), _) => Literal.create(null, e.dataType) + case e @ GetItem(_, Literal(null, _)) => Literal.create(null, e.dataType) + case e @ StructGetField(Literal(null, _), _, _) => Literal.create(null, e.dataType) + case e @ ArrayGetField(Literal(null, _), _, _, _) => Literal.create(null, e.dataType) case e @ EqualNullSafe(Literal(null, _), r) => IsNull(r) case e @ EqualNullSafe(l, Literal(null, _)) => IsNull(l) case e @ Count(expr) if !expr.nullable => Count(Literal(1)) @@ -235,36 +235,36 @@ object NullPropagation extends Rule[LogicalPlan] { case _ => true } if (newChildren.length == 0) { - Literal(null, e.dataType) + Literal.create(null, e.dataType) } else if (newChildren.length == 1) { newChildren(0) } else { Coalesce(newChildren) } - case e @ Substring(Literal(null, _), _, _) => Literal(null, e.dataType) - case e @ Substring(_, Literal(null, _), _) => Literal(null, e.dataType) - case e @ Substring(_, _, Literal(null, _)) => Literal(null, e.dataType) + case e @ Substring(Literal(null, _), _, _) => Literal.create(null, e.dataType) + case e @ Substring(_, Literal(null, _), _) => Literal.create(null, e.dataType) + case e @ Substring(_, _, Literal(null, _)) => Literal.create(null, e.dataType) // Put exceptional cases above if any case e: BinaryArithmetic => e.children match { - case Literal(null, _) :: right :: Nil => Literal(null, e.dataType) - case left :: Literal(null, _) :: Nil => Literal(null, e.dataType) + case Literal(null, _) :: right :: Nil => Literal.create(null, e.dataType) + case left :: Literal(null, _) :: Nil => Literal.create(null, e.dataType) case _ => e } case e: BinaryComparison => e.children match { - case Literal(null, _) :: right :: Nil => Literal(null, e.dataType) - case left :: Literal(null, _) :: Nil => Literal(null, e.dataType) + case Literal(null, _) :: right :: Nil => Literal.create(null, e.dataType) + case left :: Literal(null, _) :: Nil => Literal.create(null, e.dataType) case _ => e } case e: StringRegexExpression => e.children match { - case Literal(null, _) :: right :: Nil => Literal(null, e.dataType) - case left :: Literal(null, _) :: Nil => Literal(null, e.dataType) + case Literal(null, _) :: right :: Nil => Literal.create(null, e.dataType) + case left :: Literal(null, _) :: Nil => Literal.create(null, e.dataType) case _ => e } case e: StringComparison => e.children match { - case Literal(null, _) :: right :: Nil => Literal(null, e.dataType) - case left :: Literal(null, _) :: Nil => Literal(null, e.dataType) + case Literal(null, _) :: right :: Nil => Literal.create(null, e.dataType) + case left :: Literal(null, _) :: Nil => Literal.create(null, e.dataType) case _ => e } } @@ -284,13 +284,13 @@ object ConstantFolding extends Rule[LogicalPlan] { case l: Literal => l // Fold expressions that are foldable. - case e if e.foldable => Literal(e.eval(null), e.dataType) + case e if e.foldable => Literal.create(e.eval(null), e.dataType) // Fold "literal in (item1, item2, ..., literal, ...)" into true directly. case In(Literal(v, _), list) if list.exists { case Literal(candidate, _) if candidate == v => true case _ => false - } => Literal(true, BooleanType) + } => Literal.create(true, BooleanType) } } } @@ -647,7 +647,7 @@ object DecimalAggregates extends Rule[LogicalPlan] { case Average(e @ DecimalType.Expression(prec, scale)) if prec + 4 <= MAX_DOUBLE_DIGITS => Cast( - Divide(Average(UnscaledValue(e)), Literal(math.pow(10.0, scale), DoubleType)), + Divide(Average(UnscaledValue(e)), Literal.create(math.pow(10.0, scale), DoubleType)), DecimalType(prec + 4, scale + 4)) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index 02f7c26a8ab6e..7967189cacb24 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -150,7 +150,7 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy }.toSeq } - def schema: StructType = StructType.fromAttributes(output) + lazy val schema: StructType = StructType.fromAttributes(output) /** Returns the output schema in the tree format. */ def schemaString: String = schema.treeString diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index b01a61d7bf8d6..2e9f3aa4ec4ad 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.trees.TreeNode import org.apache.spark.sql.catalyst.trees +import org.apache.spark.sql.types.{ArrayType, StructType, StructField} abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging { @@ -109,16 +110,22 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging { * nodes of this LogicalPlan. The attribute is expressed as * as string in the following form: `[scope].AttributeName.[nested].[fields]...`. */ - def resolveChildren(name: String, resolver: Resolver): Option[NamedExpression] = - resolve(name, children.flatMap(_.output), resolver) + def resolveChildren( + name: String, + resolver: Resolver, + throwErrors: Boolean = false): Option[NamedExpression] = + resolve(name, children.flatMap(_.output), resolver, throwErrors) /** * Optionally resolves the given string to a [[NamedExpression]] based on the output of this * LogicalPlan. The attribute is expressed as string in the following form: * `[scope].AttributeName.[nested].[fields]...`. */ - def resolve(name: String, resolver: Resolver): Option[NamedExpression] = - resolve(name, output, resolver) + def resolve( + name: String, + resolver: Resolver, + throwErrors: Boolean = false): Option[NamedExpression] = + resolve(name, output, resolver, throwErrors) /** * Resolve the given `name` string against the given attribute, returning either 0 or 1 match. @@ -162,7 +169,8 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging { protected def resolve( name: String, input: Seq[Attribute], - resolver: Resolver): Option[NamedExpression] = { + resolver: Resolver, + throwErrors: Boolean): Option[NamedExpression] = { val parts = name.split("\\.") @@ -196,14 +204,19 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging { // One match, but we also need to extract the requested nested field. case Seq((a, nestedFields)) => - // The foldLeft adds UnresolvedGetField for every remaining parts of the name, - // and aliased it with the last part of the name. - // For example, consider name "a.b.c", where "a" is resolved to an existing attribute. - // Then this will add UnresolvedGetField("b") and UnresolvedGetField("c"), and alias - // the final expression as "c". - val fieldExprs = nestedFields.foldLeft(a: Expression)(UnresolvedGetField) - val aliasName = nestedFields.last - Some(Alias(fieldExprs, aliasName)()) + try { + + // The foldLeft adds UnresolvedGetField for every remaining parts of the name, + // and aliased it with the last part of the name. + // For example, consider name "a.b.c", where "a" is resolved to an existing attribute. + // Then this will add UnresolvedGetField("b") and UnresolvedGetField("c"), and alias + // the final expression as "c". + val fieldExprs = nestedFields.foldLeft(a: Expression)(resolveGetField(_, _, resolver)) + val aliasName = nestedFields.last + Some(Alias(fieldExprs, aliasName)()) + } catch { + case a: AnalysisException if !throwErrors => None + } // No matches. case Seq() => @@ -212,11 +225,46 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging { // More than one match. case ambiguousReferences => - val referenceNames = ambiguousReferences.map(_._1.qualifiedName).mkString(", ") + val referenceNames = ambiguousReferences.map(_._1).mkString(", ") throw new AnalysisException( s"Reference '$name' is ambiguous, could be: $referenceNames.") } } + + /** + * Returns the resolved `GetField`, and report error if no desired field or over one + * desired fields are found. + * + * TODO: this code is duplicated from Analyzer and should be refactored to avoid this. + */ + protected def resolveGetField( + expr: Expression, + fieldName: String, + resolver: Resolver): Expression = { + def findField(fields: Array[StructField]): Int = { + val checkField = (f: StructField) => resolver(f.name, fieldName) + val ordinal = fields.indexWhere(checkField) + if (ordinal == -1) { + throw new AnalysisException( + s"No such struct field $fieldName in ${fields.map(_.name).mkString(", ")}") + } else if (fields.indexWhere(checkField, ordinal + 1) != -1) { + throw new AnalysisException( + s"Ambiguous reference to fields ${fields.filter(checkField).mkString(", ")}") + } else { + ordinal + } + } + expr.dataType match { + case StructType(fields) => + val ordinal = findField(fields) + StructGetField(expr, fields(ordinal), ordinal) + case ArrayType(StructType(fields), containsNull) => + val ordinal = findField(fields) + ArrayGetField(expr, fields(ordinal), ordinal, containsNull) + case otherType => + throw new AnalysisException(s"GetField is not valid on fields of type $otherType") + } + } } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index 4d9e41a2b5d85..8633e06093cf3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -80,7 +80,7 @@ case class Union(left: LogicalPlan, right: LogicalPlan) extends BinaryNode { override lazy val resolved: Boolean = childrenResolved && - !left.output.zip(right.output).exists { case (l,r) => l.dataType != r.dataType } + left.output.zip(right.output).forall { case (l,r) => l.dataType == r.dataType } override def statistics: Statistics = { val sizeInBytes = left.statistics.sizeInBytes + right.statistics.sizeInBytes @@ -287,7 +287,10 @@ case class Distinct(child: LogicalPlan) extends UnaryNode { override def output: Seq[Attribute] = child.output } -case object NoRelation extends LeafNode { +/** + * A relation with one row. This is used in "SELECT ..." without a from clause. + */ +case object OneRowRelation extends LeafNode { override def output: Seq[Attribute] = Nil /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataTypeConversions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataTypeConversions.scala index c243be07a91b6..a9d63e784963d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataTypeConversions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataTypeConversions.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.expressions.GenericMutableRow -protected[sql] object DataTypeConversions { +private[sql] object DataTypeConversions { def productToRow(product: Product, schema: StructType): Row = { val mutableRow = new GenericMutableRow(product.productArity) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataTypeParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataTypeParser.scala index 89278f7dbc806..34270d0ca7cd7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataTypeParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataTypeParser.scala @@ -112,4 +112,4 @@ private[sql] object DataTypeParser { } /** The exception thrown from the [[DataTypeParser]]. */ -protected[sql] class DataTypeException(message: String) extends Exception(message) +private[sql] class DataTypeException(message: String) extends Exception(message) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Metadata.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Metadata.scala index e50e9761431f5..6ee24ee0c1913 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Metadata.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Metadata.scala @@ -41,6 +41,9 @@ import org.apache.spark.annotation.DeveloperApi sealed class Metadata private[types] (private[types] val map: Map[String, Any]) extends Serializable { + /** No-arg constructor for kryo. */ + protected def this() = this(null) + /** Tests whether this Metadata contains a binding for a key. */ def contains(key: String): Boolean = map.contains(key) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala index d973144de3468..cdf2bc68d9c5e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.types import java.sql.Timestamp import scala.collection.mutable.ArrayBuffer +import scala.math._ import scala.math.Numeric.{FloatAsIfIntegral, DoubleAsIfIntegral} import scala.reflect.ClassTag import scala.reflect.runtime.universe.{TypeTag, runtimeMirror, typeTag} @@ -670,6 +671,10 @@ case class PrecisionInfo(precision: Int, scale: Int) */ @DeveloperApi case class DecimalType(precisionInfo: Option[PrecisionInfo]) extends FractionalType { + + /** No-arg constructor for kryo. */ + protected def this() = this(null) + private[sql] type JvmType = Decimal @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] } private[sql] val numeric = Decimal.DecimalIsFractional @@ -819,6 +824,10 @@ object ArrayType { */ @DeveloperApi case class ArrayType(elementType: DataType, containsNull: Boolean) extends DataType { + + /** No-arg constructor for kryo. */ + protected def this() = this(null, false) + private[sql] def buildFormattedString(prefix: String, builder: StringBuilder): Unit = { builder.append( s"$prefix-- element: ${elementType.typeName} (containsNull = $containsNull)\n") @@ -857,6 +866,9 @@ case class StructField( nullable: Boolean = true, metadata: Metadata = Metadata.empty) { + /** No-arg constructor for kryo. */ + protected def this() = this(null, null) + private[sql] def buildFormattedString(prefix: String, builder: StringBuilder): Unit = { builder.append(s"$prefix-- $name: ${dataType.typeName} (nullable = $nullable)\n") DataType.buildFormattedString(dataType, s"$prefix |", builder) @@ -923,7 +935,9 @@ object StructType { case (DecimalType.Fixed(leftPrecision, leftScale), DecimalType.Fixed(rightPrecision, rightScale)) => - DecimalType(leftPrecision.max(rightPrecision), leftScale.max(rightScale)) + DecimalType( + max(leftScale, rightScale) + max(leftPrecision - leftScale, rightPrecision - rightScale), + max(leftScale, rightScale)) case (leftUdt: UserDefinedType[_], rightUdt: UserDefinedType[_]) if leftUdt.userClass == rightUdt.userClass => leftUdt @@ -1003,6 +1017,9 @@ object StructType { @DeveloperApi case class StructType(fields: Array[StructField]) extends DataType with Seq[StructField] { + /** No-arg constructor for kryo. */ + protected def this() = this(null) + /** Returns all field names in an array. */ def fieldNames: Array[String] = fields.map(_.name) @@ -1121,6 +1138,10 @@ case class MapType( keyType: DataType, valueType: DataType, valueContainsNull: Boolean) extends DataType { + + /** No-arg constructor for kryo. */ + def this() = this(null, null, false) + private[sql] def buildFormattedString(prefix: String, builder: StringBuilder): Unit = { builder.append(s"$prefix-- key: ${keyType.typeName}\n") builder.append(s"$prefix-- value: ${valueType.typeName} " + diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index 756cd36f05c8c..ee7b14c7a157c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -40,14 +40,12 @@ class AnalysisSuite extends FunSuite with BeforeAndAfter { override val extendedResolutionRules = EliminateSubQueries :: Nil } - val checkAnalysis = new CheckAnalysis - def caseSensitiveAnalyze(plan: LogicalPlan) = - checkAnalysis(caseSensitiveAnalyzer(plan)) + caseSensitiveAnalyzer.checkAnalysis(caseSensitiveAnalyzer(plan)) def caseInsensitiveAnalyze(plan: LogicalPlan) = - checkAnalysis(caseInsensitiveAnalyzer(plan)) + caseInsensitiveAnalyzer.checkAnalysis(caseInsensitiveAnalyzer(plan)) val testRelation = LocalRelation(AttributeReference("a", IntegerType, nullable = true)()) val testRelation2 = LocalRelation( @@ -57,6 +55,21 @@ class AnalysisSuite extends FunSuite with BeforeAndAfter { AttributeReference("d", DecimalType.Unlimited)(), AttributeReference("e", ShortType)()) + val nestedRelation = LocalRelation( + AttributeReference("top", StructType( + StructField("duplicateField", StringType) :: + StructField("duplicateField", StringType) :: + StructField("differentCase", StringType) :: + StructField("differentcase", StringType) :: Nil + ))()) + + val nestedRelation2 = LocalRelation( + AttributeReference("top", StructType( + StructField("aField", StringType) :: + StructField("bField", StringType) :: + StructField("cField", StringType) :: Nil + ))()) + before { caseSensitiveCatalog.registerTable(Seq("TaBlE"), testRelation) caseInsensitiveCatalog.registerTable(Seq("TaBlE"), testRelation) @@ -169,6 +182,24 @@ class AnalysisSuite extends FunSuite with BeforeAndAfter { "'b'" :: "group by" :: Nil ) + errorTest( + "ambiguous field", + nestedRelation.select($"top.duplicateField"), + "Ambiguous reference to fields" :: "duplicateField" :: Nil, + caseSensitive = false) + + errorTest( + "ambiguous field due to case insensitivity", + nestedRelation.select($"top.differentCase"), + "Ambiguous reference to fields" :: "differentCase" :: "differentcase" :: Nil, + caseSensitive = false) + + errorTest( + "missing field", + nestedRelation2.select($"top.c"), + "No such struct field" :: "aField" :: "bField" :: "cField" :: Nil, + caseSensitive = false) + case class UnresolvedTestPlan() extends LeafNode { override lazy val resolved = false override def output = Nil diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala index bc2ec754d5865..67bec999dfbd1 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.logical.{Project, LocalRelation} +import org.apache.spark.sql.catalyst.plans.logical.{Union, Project, LocalRelation} import org.apache.spark.sql.types._ import org.scalatest.{BeforeAndAfter, FunSuite} @@ -31,7 +31,8 @@ class DecimalPrecisionSuite extends FunSuite with BeforeAndAfter { AttributeReference("d1", DecimalType(2, 1))(), AttributeReference("d2", DecimalType(5, 2))(), AttributeReference("u", DecimalType.Unlimited)(), - AttributeReference("f", FloatType)() + AttributeReference("f", FloatType)(), + AttributeReference("b", DoubleType)() ) val i: Expression = UnresolvedAttribute("i") @@ -39,6 +40,7 @@ class DecimalPrecisionSuite extends FunSuite with BeforeAndAfter { val d2: Expression = UnresolvedAttribute("d2") val u: Expression = UnresolvedAttribute("u") val f: Expression = UnresolvedAttribute("f") + val b: Expression = UnresolvedAttribute("b") before { catalog.registerTable(Seq("table"), relation) @@ -58,6 +60,17 @@ class DecimalPrecisionSuite extends FunSuite with BeforeAndAfter { assert(comparison.right.dataType === expectedType) } + private def checkUnion(left: Expression, right: Expression, expectedType: DataType): Unit = { + val plan = + Union(Project(Seq(Alias(left, "l")()), relation), + Project(Seq(Alias(right, "r")()), relation)) + val (l, r) = analyzer(plan).collect { + case Union(left, right) => (left.output.head, right.output.head) + }.head + assert(l.dataType === expectedType) + assert(r.dataType === expectedType) + } + test("basic operations") { checkType(Add(d1, d2), DecimalType(6, 2)) checkType(Subtract(d1, d2), DecimalType(6, 2)) @@ -82,6 +95,19 @@ class DecimalPrecisionSuite extends FunSuite with BeforeAndAfter { checkComparison(GreaterThan(d2, d2), DecimalType(5, 2)) } + test("decimal precision for union") { + checkUnion(d1, i, DecimalType(11, 1)) + checkUnion(i, d2, DecimalType(12, 2)) + checkUnion(d1, d2, DecimalType(5, 2)) + checkUnion(d2, d1, DecimalType(5, 2)) + checkUnion(d1, f, DecimalType(8, 7)) + checkUnion(f, d2, DecimalType(10, 7)) + checkUnion(d1, b, DecimalType(16, 15)) + checkUnion(b, d2, DecimalType(18, 15)) + checkUnion(d1, u, DecimalType.Unlimited) + checkUnion(u, d2, DecimalType.Unlimited) + } + test("bringing in primitive types") { checkType(Add(d1, i), DecimalType(12, 1)) checkType(Add(d1, f), DoubleType) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala index ecbb54218d457..70aef1cac421a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala @@ -127,11 +127,11 @@ class HiveTypeCoercionSuite extends PlanTest { ruleTest( Coalesce(Literal(1.0) :: Literal(1) - :: Literal(1.0, FloatType) + :: Literal.create(1.0, FloatType) :: Nil), Coalesce(Cast(Literal(1.0), DoubleType) :: Cast(Literal(1), DoubleType) - :: Cast(Literal(1.0, FloatType), DoubleType) + :: Cast(Literal.create(1.0, FloatType), DoubleType) :: Nil)) ruleTest( Coalesce(Literal(1L) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/AttributeSetSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/AttributeSetSuite.scala new file mode 100644 index 0000000000000..f2f3a84d19380 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/AttributeSetSuite.scala @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import org.scalatest.FunSuite + +import org.apache.spark.sql.types.IntegerType + +class AttributeSetSuite extends FunSuite { + + val aUpper = AttributeReference("A", IntegerType)(exprId = ExprId(1)) + val aLower = AttributeReference("a", IntegerType)(exprId = ExprId(1)) + val fakeA = AttributeReference("a", IntegerType)(exprId = ExprId(3)) + val aSet = AttributeSet(aLower :: Nil) + + val bUpper = AttributeReference("B", IntegerType)(exprId = ExprId(2)) + val bLower = AttributeReference("b", IntegerType)(exprId = ExprId(2)) + val bSet = AttributeSet(bUpper :: Nil) + + val aAndBSet = AttributeSet(aUpper :: bUpper :: Nil) + + test("sanity check") { + assert(aUpper != aLower) + assert(bUpper != bLower) + } + + test("checks by id not name") { + assert(aSet.contains(aUpper) === true) + assert(aSet.contains(aLower) === true) + assert(aSet.contains(fakeA) === false) + + assert(aSet.contains(bUpper) === false) + assert(aSet.contains(bLower) === false) + } + + test("++ preserves AttributeSet") { + assert((aSet ++ bSet).contains(aUpper) === true) + assert((aSet ++ bSet).contains(aLower) === true) + } + + test("extracts all references references") { + val addSet = AttributeSet(Add(aUpper, Alias(bUpper, "test")()):: Nil) + assert(addSet.contains(aUpper)) + assert(addSet.contains(aLower)) + assert(addSet.contains(bUpper)) + assert(addSet.contains(bLower)) + } + + test("dedups attributes") { + assert(AttributeSet(aUpper :: aLower :: Nil).size === 1) + } + + test("subset") { + assert(aSet.subsetOf(aAndBSet) === true) + assert(aAndBSet.subsetOf(aSet) === false) + } + + test("equality") { + assert(aSet != aAndBSet) + assert(aAndBSet != aSet) + assert(aSet != bSet) + assert(bSet != aSet) + + assert(aSet == aSet) + assert(aSet == AttributeSet(aUpper :: Nil)) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala index dcfd8b28cb02a..3dbefa40d2808 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala @@ -30,7 +30,34 @@ import org.apache.spark.sql.catalyst.analysis.UnresolvedGetField import org.apache.spark.sql.types._ -class ExpressionEvaluationSuite extends FunSuite { +class ExpressionEvaluationBaseSuite extends FunSuite { + + def evaluate(expression: Expression, inputRow: Row = EmptyRow): Any = { + expression.eval(inputRow) + } + + def checkEvaluation(expression: Expression, expected: Any, inputRow: Row = EmptyRow): Unit = { + val actual = try evaluate(expression, inputRow) catch { + case e: Exception => fail(s"Exception evaluating $expression", e) + } + if(actual != expected) { + val input = if(inputRow == EmptyRow) "" else s", input: $inputRow" + fail(s"Incorrect Evaluation: $expression, actual: $actual, expected: $expected$input") + } + } + + def checkDoubleEvaluation( + expression: Expression, + expected: Spread[Double], + inputRow: Row = EmptyRow): Unit = { + val actual = try evaluate(expression, inputRow) catch { + case e: Exception => fail(s"Exception evaluating $expression", e) + } + actual.asInstanceOf[Double] shouldBe expected + } +} + +class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite { test("literals") { checkEvaluation(Literal(1), 1) @@ -84,7 +111,7 @@ class ExpressionEvaluationSuite extends FunSuite { test("3VL Not") { notTrueTable.foreach { case (v, answer) => - checkEvaluation(!Literal(v, BooleanType), answer) + checkEvaluation(!Literal.create(v, BooleanType), answer) } } @@ -128,33 +155,12 @@ class ExpressionEvaluationSuite extends FunSuite { test(s"3VL $name") { truthTable.foreach { case (l,r,answer) => - val expr = op(Literal(l, BooleanType), Literal(r, BooleanType)) + val expr = op(Literal.create(l, BooleanType), Literal.create(r, BooleanType)) checkEvaluation(expr, answer) } } } - def evaluate(expression: Expression, inputRow: Row = EmptyRow): Any = { - expression.eval(inputRow) - } - - def checkEvaluation(expression: Expression, expected: Any, inputRow: Row = EmptyRow): Unit = { - val actual = try evaluate(expression, inputRow) catch { - case e: Exception => fail(s"Exception evaluating $expression", e) - } - if(actual != expected) { - val input = if(inputRow == EmptyRow) "" else s", input: $inputRow" - fail(s"Incorrect Evaluation: $expression, actual: $actual, expected: $expected$input") - } - } - - def checkDoubleEvaluation(expression: Expression, expected: Spread[Double], inputRow: Row = EmptyRow): Unit = { - val actual = try evaluate(expression, inputRow) catch { - case e: Exception => fail(s"Exception evaluating $expression", e) - } - actual.asInstanceOf[Double] shouldBe expected - } - test("IN") { checkEvaluation(In(Literal(1), Seq(Literal(1), Literal(2))), true) checkEvaluation(In(Literal(2), Seq(Literal(1), Literal(2))), true) @@ -169,12 +175,12 @@ class ExpressionEvaluationSuite extends FunSuite { checkEvaluation(Divide(Literal(1), Literal(0)), null) checkEvaluation(Divide(Literal(1.0), Literal(0.0)), null) checkEvaluation(Divide(Literal(0.0), Literal(0.0)), null) - checkEvaluation(Divide(Literal(0), Literal(null, IntegerType)), null) - checkEvaluation(Divide(Literal(1), Literal(null, IntegerType)), null) - checkEvaluation(Divide(Literal(null, IntegerType), Literal(0)), null) - checkEvaluation(Divide(Literal(null, DoubleType), Literal(0.0)), null) - checkEvaluation(Divide(Literal(null, IntegerType), Literal(1)), null) - checkEvaluation(Divide(Literal(null, IntegerType), Literal(null, IntegerType)), null) + checkEvaluation(Divide(Literal(0), Literal.create(null, IntegerType)), null) + checkEvaluation(Divide(Literal(1), Literal.create(null, IntegerType)), null) + checkEvaluation(Divide(Literal.create(null, IntegerType), Literal(0)), null) + checkEvaluation(Divide(Literal.create(null, DoubleType), Literal(0.0)), null) + checkEvaluation(Divide(Literal.create(null, IntegerType), Literal(1)), null) + checkEvaluation(Divide(Literal.create(null, IntegerType), Literal.create(null, IntegerType)), null) } test("Remainder") { @@ -184,12 +190,12 @@ class ExpressionEvaluationSuite extends FunSuite { checkEvaluation(Remainder(Literal(1), Literal(0)), null) checkEvaluation(Remainder(Literal(1.0), Literal(0.0)), null) checkEvaluation(Remainder(Literal(0.0), Literal(0.0)), null) - checkEvaluation(Remainder(Literal(0), Literal(null, IntegerType)), null) - checkEvaluation(Remainder(Literal(1), Literal(null, IntegerType)), null) - checkEvaluation(Remainder(Literal(null, IntegerType), Literal(0)), null) - checkEvaluation(Remainder(Literal(null, DoubleType), Literal(0.0)), null) - checkEvaluation(Remainder(Literal(null, IntegerType), Literal(1)), null) - checkEvaluation(Remainder(Literal(null, IntegerType), Literal(null, IntegerType)), null) + checkEvaluation(Remainder(Literal(0), Literal.create(null, IntegerType)), null) + checkEvaluation(Remainder(Literal(1), Literal.create(null, IntegerType)), null) + checkEvaluation(Remainder(Literal.create(null, IntegerType), Literal(0)), null) + checkEvaluation(Remainder(Literal.create(null, DoubleType), Literal(0.0)), null) + checkEvaluation(Remainder(Literal.create(null, IntegerType), Literal(1)), null) + checkEvaluation(Remainder(Literal.create(null, IntegerType), Literal.create(null, IntegerType)), null) } test("INSET") { @@ -216,14 +222,14 @@ class ExpressionEvaluationSuite extends FunSuite { checkEvaluation(MaxOf(1L, 2L), 2L) checkEvaluation(MaxOf(2L, 1L), 2L) - checkEvaluation(MaxOf(Literal(null, IntegerType), 2), 2) - checkEvaluation(MaxOf(2, Literal(null, IntegerType)), 2) + checkEvaluation(MaxOf(Literal.create(null, IntegerType), 2), 2) + checkEvaluation(MaxOf(2, Literal.create(null, IntegerType)), 2) } test("LIKE literal Regular Expression") { - checkEvaluation(Literal(null, StringType).like("a"), null) - checkEvaluation(Literal("a", StringType).like(Literal(null, StringType)), null) - checkEvaluation(Literal(null, StringType).like(Literal(null, StringType)), null) + checkEvaluation(Literal.create(null, StringType).like("a"), null) + checkEvaluation(Literal.create("a", StringType).like(Literal.create(null, StringType)), null) + checkEvaluation(Literal.create(null, StringType).like(Literal.create(null, StringType)), null) checkEvaluation("abdef" like "abdef", true) checkEvaluation("a_%b" like "a\\__b", true) checkEvaluation("addb" like "a_%b", true) @@ -258,13 +264,13 @@ class ExpressionEvaluationSuite extends FunSuite { checkEvaluation("ab" like regEx, true, new GenericRow(Array[Any]("a%b"))) checkEvaluation("a\nb" like regEx, true, new GenericRow(Array[Any]("a%b"))) - checkEvaluation(Literal(null, StringType) like regEx, null, new GenericRow(Array[Any]("bc%"))) + checkEvaluation(Literal.create(null, StringType) like regEx, null, new GenericRow(Array[Any]("bc%"))) } test("RLIKE literal Regular Expression") { - checkEvaluation(Literal(null, StringType) rlike "abdef", null) - checkEvaluation("abdef" rlike Literal(null, StringType), null) - checkEvaluation(Literal(null, StringType) rlike Literal(null, StringType), null) + checkEvaluation(Literal.create(null, StringType) rlike "abdef", null) + checkEvaluation("abdef" rlike Literal.create(null, StringType), null) + checkEvaluation(Literal.create(null, StringType) rlike Literal.create(null, StringType), null) checkEvaluation("abdef" rlike "abdef", true) checkEvaluation("abbbbc" rlike "a.*c", true) @@ -375,7 +381,7 @@ class ExpressionEvaluationSuite extends FunSuite { assert(("abcdef" cast DoubleType).nullable === true) assert(("abcdef" cast FloatType).nullable === true) - checkEvaluation(Cast(Literal(null, IntegerType), ShortType), null) + checkEvaluation(Cast(Literal.create(null, IntegerType), ShortType), null) } test("date") { @@ -501,8 +507,8 @@ class ExpressionEvaluationSuite extends FunSuite { } test("array casting") { - val array = Literal(Seq("123", "abc", "", null), ArrayType(StringType, containsNull = true)) - val array_notNull = Literal(Seq("123", "abc", ""), ArrayType(StringType, containsNull = false)) + val array = Literal.create(Seq("123", "abc", "", null), ArrayType(StringType, containsNull = true)) + val array_notNull = Literal.create(Seq("123", "abc", ""), ArrayType(StringType, containsNull = false)) { val cast = Cast(array, ArrayType(IntegerType, containsNull = true)) @@ -550,10 +556,10 @@ class ExpressionEvaluationSuite extends FunSuite { } test("map casting") { - val map = Literal( + val map = Literal.create( Map("a" -> "123", "b" -> "abc", "c" -> "", "d" -> null), MapType(StringType, StringType, valueContainsNull = true)) - val map_notNull = Literal( + val map_notNull = Literal.create( Map("a" -> "123", "b" -> "abc", "c" -> ""), MapType(StringType, StringType, valueContainsNull = false)) @@ -611,14 +617,14 @@ class ExpressionEvaluationSuite extends FunSuite { } test("struct casting") { - val struct = Literal( + val struct = Literal.create( Row("123", "abc", "", null), StructType(Seq( StructField("a", StringType, nullable = true), StructField("b", StringType, nullable = true), StructField("c", StringType, nullable = true), StructField("d", StringType, nullable = true)))) - val struct_notNull = Literal( + val struct_notNull = Literal.create( Row("123", "abc", ""), StructType(Seq( StructField("a", StringType, nullable = false), @@ -706,7 +712,7 @@ class ExpressionEvaluationSuite extends FunSuite { } test("complex casting") { - val complex = Literal( + val complex = Literal.create( Row( Seq("123", "abc", ""), Map("a" -> "123", "b" -> "abc", "c" -> ""), @@ -749,30 +755,30 @@ class ExpressionEvaluationSuite extends FunSuite { checkEvaluation(c2.isNull, true, row) checkEvaluation(c2.isNotNull, false, row) - checkEvaluation(Literal(1, ShortType).isNull, false) - checkEvaluation(Literal(1, ShortType).isNotNull, true) + checkEvaluation(Literal.create(1, ShortType).isNull, false) + checkEvaluation(Literal.create(1, ShortType).isNotNull, true) - checkEvaluation(Literal(null, ShortType).isNull, true) - checkEvaluation(Literal(null, ShortType).isNotNull, false) + checkEvaluation(Literal.create(null, ShortType).isNull, true) + checkEvaluation(Literal.create(null, ShortType).isNotNull, false) checkEvaluation(Coalesce(c1 :: c2 :: Nil), "^Ba*n", row) - checkEvaluation(Coalesce(Literal(null, StringType) :: Nil), null, row) - checkEvaluation(Coalesce(Literal(null, StringType) :: c1 :: c2 :: Nil), "^Ba*n", row) + checkEvaluation(Coalesce(Literal.create(null, StringType) :: Nil), null, row) + checkEvaluation(Coalesce(Literal.create(null, StringType) :: c1 :: c2 :: Nil), "^Ba*n", row) - checkEvaluation(If(c3, Literal("a", StringType), Literal("b", StringType)), "a", row) + checkEvaluation(If(c3, Literal.create("a", StringType), Literal.create("b", StringType)), "a", row) checkEvaluation(If(c3, c1, c2), "^Ba*n", row) checkEvaluation(If(c4, c2, c1), "^Ba*n", row) - checkEvaluation(If(Literal(null, BooleanType), c2, c1), "^Ba*n", row) - checkEvaluation(If(Literal(true, BooleanType), c1, c2), "^Ba*n", row) - checkEvaluation(If(Literal(false, BooleanType), c2, c1), "^Ba*n", row) - checkEvaluation(If(Literal(false, BooleanType), - Literal("a", StringType), Literal("b", StringType)), "b", row) + checkEvaluation(If(Literal.create(null, BooleanType), c2, c1), "^Ba*n", row) + checkEvaluation(If(Literal.create(true, BooleanType), c1, c2), "^Ba*n", row) + checkEvaluation(If(Literal.create(false, BooleanType), c2, c1), "^Ba*n", row) + checkEvaluation(If(Literal.create(false, BooleanType), + Literal.create("a", StringType), Literal.create("b", StringType)), "b", row) checkEvaluation(c1 in (c1, c2), true, row) checkEvaluation( - Literal("^Ba*n", StringType) in (Literal("^Ba*n", StringType)), true, row) + Literal.create("^Ba*n", StringType) in (Literal.create("^Ba*n", StringType)), true, row) checkEvaluation( - Literal("^Ba*n", StringType) in (Literal("^Ba*n", StringType), c2), true, row) + Literal.create("^Ba*n", StringType) in (Literal.create("^Ba*n", StringType), c2), true, row) } test("case when") { @@ -787,9 +793,9 @@ class ExpressionEvaluationSuite extends FunSuite { checkEvaluation(CaseWhen(Seq(c1, c4, c6)), "c", row) checkEvaluation(CaseWhen(Seq(c2, c4, c6)), "c", row) checkEvaluation(CaseWhen(Seq(c3, c4, c6)), "a", row) - checkEvaluation(CaseWhen(Seq(Literal(null, BooleanType), c4, c6)), "c", row) - checkEvaluation(CaseWhen(Seq(Literal(false, BooleanType), c4, c6)), "c", row) - checkEvaluation(CaseWhen(Seq(Literal(true, BooleanType), c4, c6)), "a", row) + checkEvaluation(CaseWhen(Seq(Literal.create(null, BooleanType), c4, c6)), "c", row) + checkEvaluation(CaseWhen(Seq(Literal.create(false, BooleanType), c4, c6)), "c", row) + checkEvaluation(CaseWhen(Seq(Literal.create(true, BooleanType), c4, c6)), "a", row) checkEvaluation(CaseWhen(Seq(c3, c4, c2, c5, c6)), "a", row) checkEvaluation(CaseWhen(Seq(c2, c4, c3, c5, c6)), "b", row) @@ -835,17 +841,17 @@ class ExpressionEvaluationSuite extends FunSuite { checkEvaluation(GetItem(BoundReference(3, typeMap, true), Literal("aa")), "bb", row) - checkEvaluation(GetItem(Literal(null, typeMap), Literal("aa")), null, row) - checkEvaluation(GetItem(Literal(null, typeMap), Literal(null, StringType)), null, row) + checkEvaluation(GetItem(Literal.create(null, typeMap), Literal("aa")), null, row) + checkEvaluation(GetItem(Literal.create(null, typeMap), Literal.create(null, StringType)), null, row) checkEvaluation(GetItem(BoundReference(3, typeMap, true), - Literal(null, StringType)), null, row) + Literal.create(null, StringType)), null, row) checkEvaluation(GetItem(BoundReference(4, typeArray, true), Literal(1)), "bb", row) - checkEvaluation(GetItem(Literal(null, typeArray), Literal(1)), null, row) - checkEvaluation(GetItem(Literal(null, typeArray), Literal(null, IntegerType)), null, row) + checkEvaluation(GetItem(Literal.create(null, typeArray), Literal(1)), null, row) + checkEvaluation(GetItem(Literal.create(null, typeArray), Literal.create(null, IntegerType)), null, row) checkEvaluation(GetItem(BoundReference(4, typeArray, true), - Literal(null, IntegerType)), null, row) + Literal.create(null, IntegerType)), null, row) def quickBuildGetField(expr: Expression, fieldName: String) = { expr.dataType match { @@ -858,7 +864,7 @@ class ExpressionEvaluationSuite extends FunSuite { def quickResolve(u: UnresolvedGetField) = quickBuildGetField(u.child, u.fieldName) checkEvaluation(quickBuildGetField(BoundReference(2, typeS, nullable = true), "a"), "aa", row) - checkEvaluation(quickBuildGetField(Literal(null, typeS), "a"), null, row) + checkEvaluation(quickBuildGetField(Literal.create(null, typeS), "a"), null, row) val typeS_notNullable = StructType( StructField("a", StringType, nullable = false) @@ -868,8 +874,8 @@ class ExpressionEvaluationSuite extends FunSuite { assert(quickBuildGetField(BoundReference(2,typeS, nullable = true), "a").nullable === true) assert(quickBuildGetField(BoundReference(2, typeS_notNullable, nullable = false), "a").nullable === false) - assert(quickBuildGetField(Literal(null, typeS), "a").nullable === true) - assert(quickBuildGetField(Literal(null, typeS_notNullable), "a").nullable === true) + assert(quickBuildGetField(Literal.create(null, typeS), "a").nullable === true) + assert(quickBuildGetField(Literal.create(null, typeS_notNullable), "a").nullable === true) checkEvaluation('c.map(typeMap).at(3).getItem("aa"), "bb", row) checkEvaluation('c.array(typeArray.elementType).at(4).getItem(1), "bb", row) @@ -884,13 +890,13 @@ class ExpressionEvaluationSuite extends FunSuite { val c4 = 'a.int.at(3) checkEvaluation(UnaryMinus(c1), -1, row) - checkEvaluation(UnaryMinus(Literal(100, IntegerType)), -100) + checkEvaluation(UnaryMinus(Literal.create(100, IntegerType)), -100) checkEvaluation(Add(c1, c4), null, row) checkEvaluation(Add(c1, c2), 3, row) - checkEvaluation(Add(c1, Literal(null, IntegerType)), null, row) - checkEvaluation(Add(Literal(null, IntegerType), c2), null, row) - checkEvaluation(Add(Literal(null, IntegerType), Literal(null, IntegerType)), null, row) + checkEvaluation(Add(c1, Literal.create(null, IntegerType)), null, row) + checkEvaluation(Add(Literal.create(null, IntegerType), c2), null, row) + checkEvaluation(Add(Literal.create(null, IntegerType), Literal.create(null, IntegerType)), null, row) checkEvaluation(-c1, -1, row) checkEvaluation(c1 + c2, 3, row) @@ -908,12 +914,12 @@ class ExpressionEvaluationSuite extends FunSuite { val c4 = 'a.double.at(3) checkEvaluation(UnaryMinus(c1), -1.1, row) - checkEvaluation(UnaryMinus(Literal(100.0, DoubleType)), -100.0) + checkEvaluation(UnaryMinus(Literal.create(100.0, DoubleType)), -100.0) checkEvaluation(Add(c1, c4), null, row) checkEvaluation(Add(c1, c2), 3.1, row) - checkEvaluation(Add(c1, Literal(null, DoubleType)), null, row) - checkEvaluation(Add(Literal(null, DoubleType), c2), null, row) - checkEvaluation(Add(Literal(null, DoubleType), Literal(null, DoubleType)), null, row) + checkEvaluation(Add(c1, Literal.create(null, DoubleType)), null, row) + checkEvaluation(Add(Literal.create(null, DoubleType), c2), null, row) + checkEvaluation(Add(Literal.create(null, DoubleType), Literal.create(null, DoubleType)), null, row) checkEvaluation(-c1, -1.1, row) checkEvaluation(c1 + c2, 3.1, row) @@ -934,9 +940,9 @@ class ExpressionEvaluationSuite extends FunSuite { checkEvaluation(LessThan(c1, c4), null, row) checkEvaluation(LessThan(c1, c2), true, row) - checkEvaluation(LessThan(c1, Literal(null, IntegerType)), null, row) - checkEvaluation(LessThan(Literal(null, IntegerType), c2), null, row) - checkEvaluation(LessThan(Literal(null, IntegerType), Literal(null, IntegerType)), null, row) + checkEvaluation(LessThan(c1, Literal.create(null, IntegerType)), null, row) + checkEvaluation(LessThan(Literal.create(null, IntegerType), c2), null, row) + checkEvaluation(LessThan(Literal.create(null, IntegerType), Literal.create(null, IntegerType)), null, row) checkEvaluation(c1 < c2, true, row) checkEvaluation(c1 <= c2, true, row) @@ -948,8 +954,8 @@ class ExpressionEvaluationSuite extends FunSuite { checkEvaluation(c1 <=> c4, false, row) checkEvaluation(c4 <=> c6, true, row) checkEvaluation(c3 <=> c5, true, row) - checkEvaluation(Literal(true) <=> Literal(null, BooleanType), false, row) - checkEvaluation(Literal(null, BooleanType) <=> Literal(true), false, row) + checkEvaluation(Literal(true) <=> Literal.create(null, BooleanType), false, row) + checkEvaluation(Literal.create(null, BooleanType) <=> Literal(true), false, row) } test("StringComparison") { @@ -960,17 +966,17 @@ class ExpressionEvaluationSuite extends FunSuite { checkEvaluation(c1 contains "b", true, row) checkEvaluation(c1 contains "x", false, row) checkEvaluation(c2 contains "b", null, row) - checkEvaluation(c1 contains Literal(null, StringType), null, row) + checkEvaluation(c1 contains Literal.create(null, StringType), null, row) checkEvaluation(c1 startsWith "a", true, row) checkEvaluation(c1 startsWith "b", false, row) checkEvaluation(c2 startsWith "a", null, row) - checkEvaluation(c1 startsWith Literal(null, StringType), null, row) + checkEvaluation(c1 startsWith Literal.create(null, StringType), null, row) checkEvaluation(c1 endsWith "c", true, row) checkEvaluation(c1 endsWith "b", false, row) checkEvaluation(c2 endsWith "b", null, row) - checkEvaluation(c1 endsWith Literal(null, StringType), null, row) + checkEvaluation(c1 endsWith Literal.create(null, StringType), null, row) } test("Substring") { @@ -979,54 +985,54 @@ class ExpressionEvaluationSuite extends FunSuite { val s = 'a.string.at(0) // substring from zero position with less-than-full length - checkEvaluation(Substring(s, Literal(0, IntegerType), Literal(2, IntegerType)), "ex", row) - checkEvaluation(Substring(s, Literal(1, IntegerType), Literal(2, IntegerType)), "ex", row) + checkEvaluation(Substring(s, Literal.create(0, IntegerType), Literal.create(2, IntegerType)), "ex", row) + checkEvaluation(Substring(s, Literal.create(1, IntegerType), Literal.create(2, IntegerType)), "ex", row) // substring from zero position with full length - checkEvaluation(Substring(s, Literal(0, IntegerType), Literal(7, IntegerType)), "example", row) - checkEvaluation(Substring(s, Literal(1, IntegerType), Literal(7, IntegerType)), "example", row) + checkEvaluation(Substring(s, Literal.create(0, IntegerType), Literal.create(7, IntegerType)), "example", row) + checkEvaluation(Substring(s, Literal.create(1, IntegerType), Literal.create(7, IntegerType)), "example", row) // substring from zero position with greater-than-full length - checkEvaluation(Substring(s, Literal(0, IntegerType), Literal(100, IntegerType)), "example", row) - checkEvaluation(Substring(s, Literal(1, IntegerType), Literal(100, IntegerType)), "example", row) + checkEvaluation(Substring(s, Literal.create(0, IntegerType), Literal.create(100, IntegerType)), "example", row) + checkEvaluation(Substring(s, Literal.create(1, IntegerType), Literal.create(100, IntegerType)), "example", row) // substring from nonzero position with less-than-full length - checkEvaluation(Substring(s, Literal(2, IntegerType), Literal(2, IntegerType)), "xa", row) + checkEvaluation(Substring(s, Literal.create(2, IntegerType), Literal.create(2, IntegerType)), "xa", row) // substring from nonzero position with full length - checkEvaluation(Substring(s, Literal(2, IntegerType), Literal(6, IntegerType)), "xample", row) + checkEvaluation(Substring(s, Literal.create(2, IntegerType), Literal.create(6, IntegerType)), "xample", row) // substring from nonzero position with greater-than-full length - checkEvaluation(Substring(s, Literal(2, IntegerType), Literal(100, IntegerType)), "xample", row) + checkEvaluation(Substring(s, Literal.create(2, IntegerType), Literal.create(100, IntegerType)), "xample", row) // zero-length substring (within string bounds) - checkEvaluation(Substring(s, Literal(0, IntegerType), Literal(0, IntegerType)), "", row) + checkEvaluation(Substring(s, Literal.create(0, IntegerType), Literal.create(0, IntegerType)), "", row) // zero-length substring (beyond string bounds) - checkEvaluation(Substring(s, Literal(100, IntegerType), Literal(4, IntegerType)), "", row) + checkEvaluation(Substring(s, Literal.create(100, IntegerType), Literal.create(4, IntegerType)), "", row) // substring(null, _, _) -> null - checkEvaluation(Substring(s, Literal(100, IntegerType), Literal(4, IntegerType)), null, new GenericRow(Array[Any](null))) + checkEvaluation(Substring(s, Literal.create(100, IntegerType), Literal.create(4, IntegerType)), null, new GenericRow(Array[Any](null))) // substring(_, null, _) -> null - checkEvaluation(Substring(s, Literal(null, IntegerType), Literal(4, IntegerType)), null, row) + checkEvaluation(Substring(s, Literal.create(null, IntegerType), Literal.create(4, IntegerType)), null, row) // substring(_, _, null) -> null - checkEvaluation(Substring(s, Literal(100, IntegerType), Literal(null, IntegerType)), null, row) + checkEvaluation(Substring(s, Literal.create(100, IntegerType), Literal.create(null, IntegerType)), null, row) // 2-arg substring from zero position - checkEvaluation(Substring(s, Literal(0, IntegerType), Literal(Integer.MAX_VALUE, IntegerType)), "example", row) - checkEvaluation(Substring(s, Literal(1, IntegerType), Literal(Integer.MAX_VALUE, IntegerType)), "example", row) + checkEvaluation(Substring(s, Literal.create(0, IntegerType), Literal.create(Integer.MAX_VALUE, IntegerType)), "example", row) + checkEvaluation(Substring(s, Literal.create(1, IntegerType), Literal.create(Integer.MAX_VALUE, IntegerType)), "example", row) // 2-arg substring from nonzero position - checkEvaluation(Substring(s, Literal(2, IntegerType), Literal(Integer.MAX_VALUE, IntegerType)), "xample", row) + checkEvaluation(Substring(s, Literal.create(2, IntegerType), Literal.create(Integer.MAX_VALUE, IntegerType)), "xample", row) val s_notNull = 'a.string.notNull.at(0) - assert(Substring(s, Literal(0, IntegerType), Literal(2, IntegerType)).nullable === true) - assert(Substring(s_notNull, Literal(0, IntegerType), Literal(2, IntegerType)).nullable === false) - assert(Substring(s_notNull, Literal(null, IntegerType), Literal(2, IntegerType)).nullable === true) - assert(Substring(s_notNull, Literal(0, IntegerType), Literal(null, IntegerType)).nullable === true) + assert(Substring(s, Literal.create(0, IntegerType), Literal.create(2, IntegerType)).nullable === true) + assert(Substring(s_notNull, Literal.create(0, IntegerType), Literal.create(2, IntegerType)).nullable === false) + assert(Substring(s_notNull, Literal.create(null, IntegerType), Literal.create(2, IntegerType)).nullable === true) + assert(Substring(s_notNull, Literal.create(0, IntegerType), Literal.create(null, IntegerType)).nullable === true) checkEvaluation(s.substr(0, 2), "ex", row) checkEvaluation(s.substr(0), "example", row) @@ -1044,7 +1050,7 @@ class ExpressionEvaluationSuite extends FunSuite { checkEvaluation(Sqrt(d), expected, row) } - checkEvaluation(Sqrt(Literal(null, DoubleType)), null, new GenericRow(Array[Any](null))) + checkEvaluation(Sqrt(Literal.create(null, DoubleType)), null, new GenericRow(Array[Any](null))) checkEvaluation(Sqrt(-1), null, EmptyRow) checkEvaluation(Sqrt(-1.5), null, EmptyRow) } @@ -1058,22 +1064,22 @@ class ExpressionEvaluationSuite extends FunSuite { checkEvaluation(BitwiseAnd(c1, c4), null, row) checkEvaluation(BitwiseAnd(c1, c2), 0, row) - checkEvaluation(BitwiseAnd(c1, Literal(null, IntegerType)), null, row) - checkEvaluation(BitwiseAnd(Literal(null, IntegerType), Literal(null, IntegerType)), null, row) + checkEvaluation(BitwiseAnd(c1, Literal.create(null, IntegerType)), null, row) + checkEvaluation(BitwiseAnd(Literal.create(null, IntegerType), Literal.create(null, IntegerType)), null, row) checkEvaluation(BitwiseOr(c1, c4), null, row) checkEvaluation(BitwiseOr(c1, c2), 3, row) - checkEvaluation(BitwiseOr(c1, Literal(null, IntegerType)), null, row) - checkEvaluation(BitwiseOr(Literal(null, IntegerType), Literal(null, IntegerType)), null, row) + checkEvaluation(BitwiseOr(c1, Literal.create(null, IntegerType)), null, row) + checkEvaluation(BitwiseOr(Literal.create(null, IntegerType), Literal.create(null, IntegerType)), null, row) checkEvaluation(BitwiseXor(c1, c4), null, row) checkEvaluation(BitwiseXor(c1, c2), 3, row) - checkEvaluation(BitwiseXor(c1, Literal(null, IntegerType)), null, row) - checkEvaluation(BitwiseXor(Literal(null, IntegerType), Literal(null, IntegerType)), null, row) + checkEvaluation(BitwiseXor(c1, Literal.create(null, IntegerType)), null, row) + checkEvaluation(BitwiseXor(Literal.create(null, IntegerType), Literal.create(null, IntegerType)), null, row) checkEvaluation(BitwiseNot(c4), null, row) checkEvaluation(BitwiseNot(c1), -2, row) - checkEvaluation(BitwiseNot(Literal(null, IntegerType)), null, row) + checkEvaluation(BitwiseNot(Literal.create(null, IntegerType)), null, row) checkEvaluation(c1 & c2, 0, row) checkEvaluation(c1 | c2, 3, row) @@ -1081,3 +1087,14 @@ class ExpressionEvaluationSuite extends FunSuite { checkEvaluation(~c1, -2, row) } } + +// TODO: Make the tests work with codegen. +class ExpressionEvaluationWithoutCodeGenSuite extends ExpressionEvaluationBaseSuite { + + test("CreateStruct") { + val row = Row(1, 2, 3) + val c1 = 'a.int.at(0).as("a") + val c3 = 'c.int.at(2).as("c") + checkEvaluation(CreateStruct(Seq(c1, c3)), Row(1, 3), row) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala index ef10c0aece716..a0efe9e2e7f6b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala @@ -182,33 +182,33 @@ class ConstantFoldingSuite extends PlanTest { IsNull(Literal(null)) as 'c1, IsNotNull(Literal(null)) as 'c2, - GetItem(Literal(null, ArrayType(IntegerType)), 1) as 'c3, - GetItem(Literal(Seq(1), ArrayType(IntegerType)), Literal(null, IntegerType)) as 'c4, + GetItem(Literal.create(null, ArrayType(IntegerType)), 1) as 'c3, + GetItem(Literal.create(Seq(1), ArrayType(IntegerType)), Literal.create(null, IntegerType)) as 'c4, UnresolvedGetField( - Literal(null, StructType(Seq(StructField("a", IntegerType, true)))), + Literal.create(null, StructType(Seq(StructField("a", IntegerType, true)))), "a") as 'c5, - UnaryMinus(Literal(null, IntegerType)) as 'c6, + UnaryMinus(Literal.create(null, IntegerType)) as 'c6, Cast(Literal(null), IntegerType) as 'c7, - Not(Literal(null, BooleanType)) as 'c8, + Not(Literal.create(null, BooleanType)) as 'c8, - Add(Literal(null, IntegerType), 1) as 'c9, - Add(1, Literal(null, IntegerType)) as 'c10, + Add(Literal.create(null, IntegerType), 1) as 'c9, + Add(1, Literal.create(null, IntegerType)) as 'c10, - EqualTo(Literal(null, IntegerType), 1) as 'c11, - EqualTo(1, Literal(null, IntegerType)) as 'c12, + EqualTo(Literal.create(null, IntegerType), 1) as 'c11, + EqualTo(1, Literal.create(null, IntegerType)) as 'c12, - Like(Literal(null, StringType), "abc") as 'c13, - Like("abc", Literal(null, StringType)) as 'c14, + Like(Literal.create(null, StringType), "abc") as 'c13, + Like("abc", Literal.create(null, StringType)) as 'c14, - Upper(Literal(null, StringType)) as 'c15, + Upper(Literal.create(null, StringType)) as 'c15, - Substring(Literal(null, StringType), 0, 1) as 'c16, - Substring("abc", Literal(null, IntegerType), 1) as 'c17, - Substring("abc", 0, Literal(null, IntegerType)) as 'c18, + Substring(Literal.create(null, StringType), 0, 1) as 'c16, + Substring("abc", Literal.create(null, IntegerType), 1) as 'c17, + Substring("abc", 0, Literal.create(null, IntegerType)) as 'c18, - Contains(Literal(null, StringType), "abc") as 'c19, - Contains("abc", Literal(null, StringType)) as 'c20 + Contains(Literal.create(null, StringType), "abc") as 'c19, + Contains("abc", Literal.create(null, StringType)) as 'c20 ) val optimized = Optimize(originalQuery.analyze) @@ -219,31 +219,31 @@ class ConstantFoldingSuite extends PlanTest { Literal(true) as 'c1, Literal(false) as 'c2, - Literal(null, IntegerType) as 'c3, - Literal(null, IntegerType) as 'c4, - Literal(null, IntegerType) as 'c5, + Literal.create(null, IntegerType) as 'c3, + Literal.create(null, IntegerType) as 'c4, + Literal.create(null, IntegerType) as 'c5, - Literal(null, IntegerType) as 'c6, - Literal(null, IntegerType) as 'c7, - Literal(null, BooleanType) as 'c8, + Literal.create(null, IntegerType) as 'c6, + Literal.create(null, IntegerType) as 'c7, + Literal.create(null, BooleanType) as 'c8, - Literal(null, IntegerType) as 'c9, - Literal(null, IntegerType) as 'c10, + Literal.create(null, IntegerType) as 'c9, + Literal.create(null, IntegerType) as 'c10, - Literal(null, BooleanType) as 'c11, - Literal(null, BooleanType) as 'c12, + Literal.create(null, BooleanType) as 'c11, + Literal.create(null, BooleanType) as 'c12, - Literal(null, BooleanType) as 'c13, - Literal(null, BooleanType) as 'c14, + Literal.create(null, BooleanType) as 'c13, + Literal.create(null, BooleanType) as 'c14, - Literal(null, StringType) as 'c15, + Literal.create(null, StringType) as 'c15, - Literal(null, StringType) as 'c16, - Literal(null, StringType) as 'c17, - Literal(null, StringType) as 'c18, + Literal.create(null, StringType) as 'c16, + Literal.create(null, StringType) as 'c17, + Literal.create(null, StringType) as 'c18, - Literal(null, BooleanType) as 'c19, - Literal(null, BooleanType) as 'c20 + Literal.create(null, BooleanType) as 'c19, + Literal.create(null, BooleanType) as 'c20 ).analyze comparePlans(optimized, correctAnswer) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ExpressionOptimizationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ExpressionOptimizationSuite.scala index ae99a3f9ba287..2f3704be59a9d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ExpressionOptimizationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ExpressionOptimizationSuite.scala @@ -29,7 +29,7 @@ class ExpressionOptimizationSuite extends ExpressionEvaluationSuite { expression: Expression, expected: Any, inputRow: Row = EmptyRow): Unit = { - val plan = Project(Alias(expression, s"Optimized($expression)")() :: Nil, NoRelation) + val plan = Project(Alias(expression, s"Optimized($expression)")() :: Nil, OneRowRelation) val optimizedPlan = DefaultOptimizer(plan) super.checkEvaluation(optimizedPlan.expressions.head, expected, inputRow) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala index 48884040bfce7..129d091ca03e3 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.plans import org.scalatest.FunSuite import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.logical.{NoRelation, Filter, LogicalPlan} +import org.apache.spark.sql.catalyst.plans.logical.{OneRowRelation, Filter, LogicalPlan} import org.apache.spark.sql.catalyst.util._ /** @@ -55,6 +55,6 @@ class PlanTest extends FunSuite { /** Fails the test if the two expressions do not match */ protected def compareExpressions(e1: Expression, e2: Expression): Unit = { - comparePlans(Filter(e1, NoRelation), Filter(e2, NoRelation)) + comparePlans(Filter(e1, OneRowRelation), Filter(e2, OneRowRelation)) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala index e7ce92a2160b6..274f3ede0045c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala @@ -90,7 +90,7 @@ class TreeNodeSuite extends FunSuite { } test("transform works on nodes with Option children") { - val dummy1 = Dummy(Some(Literal("1", StringType))) + val dummy1 = Dummy(Some(Literal.create("1", StringType))) val dummy2 = Dummy(None) val toZero: PartialFunction[Expression, Expression] = { case Literal(_, _) => Literal(0) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 5aece166aad22..5c6016a4a2ce2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -33,7 +33,7 @@ import org.apache.spark.api.java.JavaRDD import org.apache.spark.api.python.SerDeUtil import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel -import org.apache.spark.sql.catalyst.{expressions, ScalaReflection, SqlParser} +import org.apache.spark.sql.catalyst.{ScalaReflection, SqlParser} import org.apache.spark.sql.catalyst.analysis.{UnresolvedRelation, ResolvedStar} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.{JoinType, Inner} @@ -41,7 +41,7 @@ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.execution.{EvaluatePython, ExplainCommand, LogicalRDD} import org.apache.spark.sql.jdbc.JDBCWriteDetails import org.apache.spark.sql.json.JsonRDD -import org.apache.spark.sql.types.{NumericType, StructType} +import org.apache.spark.sql.types._ import org.apache.spark.sql.sources.{ResolvedDataSource, CreateTableUsingAsSelect} import org.apache.spark.util.Utils @@ -146,7 +146,7 @@ class DataFrame private[sql]( _: WriteToFile => LogicalRDD(queryExecution.analyzed.output, queryExecution.toRdd)(sqlContext) case _ => - queryExecution.logical + queryExecution.analyzed } /** @@ -237,11 +237,11 @@ class DataFrame private[sql]( def toDF(colNames: String*): DataFrame = { require(schema.size == colNames.size, "The number of columns doesn't match.\n" + - "Old column names: " + schema.fields.map(_.name).mkString(", ") + "\n" + - "New column names: " + colNames.mkString(", ")) + s"Old column names (${schema.size}): " + schema.fields.map(_.name).mkString(", ") + "\n" + + s"New column names (${colNames.size}): " + colNames.mkString(", ")) - val newCols = schema.fieldNames.zip(colNames).map { case (oldName, newName) => - apply(oldName).as(newName) + val newCols = logicalPlan.output.zip(colNames).map { case (oldAttribute, newName) => + Column(oldAttribute).as(newName) } select(newCols :_*) } @@ -273,7 +273,7 @@ class DataFrame private[sql]( def printSchema(): Unit = println(schema.treeString) /** - * Prints the plans (logical and physical) to the console for debugging purpose. + * Prints the plans (logical and physical) to the console for debugging purposes. * @group basic */ def explain(extended: Boolean): Unit = { @@ -285,7 +285,7 @@ class DataFrame private[sql]( } /** - * Only prints the physical plan to the console for debugging purpose. + * Only prints the physical plan to the console for debugging purposes. * @group basic */ def explain(): Unit = explain(extended = false) @@ -319,6 +319,17 @@ class DataFrame private[sql]( */ def show(): Unit = show(20) + /** + * Returns a [[DataFrameNaFunctions]] for working with missing data. + * {{{ + * // Dropping rows containing any null values. + * df.na.drop() + * }}} + * + * @group dfops + */ + def na: DataFrameNaFunctions = new DataFrameNaFunctions(this) + /** * Cartesian join with another [[DataFrame]]. * @@ -751,6 +762,67 @@ class DataFrame private[sql]( select(colNames :_*) } + /** + * Computes statistics for numeric columns, including count, mean, stddev, min, and max. + * If no columns are given, this function computes statistics for all numerical columns. + * + * This function is meant for exploratory data analysis, as we make no guarantee about the + * backward compatibility of the schema of the resulting [[DataFrame]]. If you want to + * programmatically compute summary statistics, use the `agg` function instead. + * + * {{{ + * df.describe("age", "height").show() + * + * // output: + * // summary age height + * // count 10.0 10.0 + * // mean 53.3 178.05 + * // stddev 11.6 15.7 + * // min 18.0 163.0 + * // max 92.0 192.0 + * }}} + * + * @group action + */ + @scala.annotation.varargs + def describe(cols: String*): DataFrame = { + + // TODO: Add stddev as an expression, and remove it from here. + def stddevExpr(expr: Expression): Expression = + Sqrt(Subtract(Average(Multiply(expr, expr)), Multiply(Average(expr), Average(expr)))) + + // The list of summary statistics to compute, in the form of expressions. + val statistics = List[(String, Expression => Expression)]( + "count" -> Count, + "mean" -> Average, + "stddev" -> stddevExpr, + "min" -> Min, + "max" -> Max) + + val outputCols = (if (cols.isEmpty) numericColumns.map(_.prettyString) else cols).toList + + val ret: Seq[Row] = if (outputCols.nonEmpty) { + val aggExprs = statistics.flatMap { case (_, colToAgg) => + outputCols.map(c => Column(colToAgg(Column(c).expr)).as(c)) + } + + val row = agg(aggExprs.head, aggExprs.tail: _*).head().toSeq + + // Pivot the data so each summary is one row + row.grouped(outputCols.size).toSeq.zip(statistics).map { + case (aggregation, (statistic, _)) => Row(statistic :: aggregation.toList: _*) + } + } else { + // If there are no output columns, just output a single column that contains the stats. + statistics.map { case (name, _) => Row(name) } + } + + // The first column is string type, and the rest are double type. + val schema = StructType( + StructField("summary", StringType) :: outputCols.map(StructField(_, DoubleType))).toAttributes + LocalRelation(schema, ret) + } + /** * Returns the first `n` rows. * @group action @@ -832,7 +904,8 @@ class DataFrame private[sql]( */ override def repartition(numPartitions: Int): DataFrame = { sqlContext.createDataFrame( - queryExecution.toRdd.map(_.copy()).repartition(numPartitions), schema) + queryExecution.toRdd.map(_.copy()).repartition(numPartitions), + schema, needsConversion = false) } /** @@ -880,10 +953,12 @@ class DataFrame private[sql]( ///////////////////////////////////////////////////////////////////////////// /** - * Returns the content of the [[DataFrame]] as an [[RDD]] of [[Row]]s. + * Represents the content of the [[DataFrame]] as an [[RDD]] of [[Row]]s. Note that the RDD is + * memoized. Once called, it won't change even if you change any query planning related Spark SQL + * configurations (e.g. `spark.sql.shuffle.partitions`). * @group rdd */ - def rdd: RDD[Row] = { + lazy val rdd: RDD[Row] = { // use a local variable to make sure the map closure doesn't capture the whole DataFrame val schema = this.schema queryExecution.executedPlan.execute().map(ScalaReflection.convertRowToScala(_, schema)) @@ -902,8 +977,8 @@ class DataFrame private[sql]( def javaRDD: JavaRDD[Row] = toJavaRDD /** - * Registers this RDD as a temporary table using the given name. The lifetime of this temporary - * table is tied to the [[SQLContext]] that was used to create this DataFrame. + * Registers this [[DataFrame]] as a temporary table using the given name. The lifetime of this + * temporary table is tied to the [[SQLContext]] that was used to create this DataFrame. * * @group basic */ @@ -1178,7 +1253,7 @@ class DataFrame private[sql]( //////////////////////////////////////////////////////////////////////////// /** - * Save this RDD to a JDBC database at `url` under the table name `table`. + * Save this [[DataFrame]] to a JDBC database at `url` under the table name `table`. * This will run a `CREATE TABLE` and a bunch of `INSERT INTO` statements. * If you pass `true` for `allowExisting`, it will drop any table with the * given name; if you pass `false`, it will throw if the table already @@ -1202,7 +1277,7 @@ class DataFrame private[sql]( } /** - * Save this RDD to a JDBC database at `url` under the table name `table`. + * Save this [[DataFrame]] to a JDBC database at `url` under the table name `table`. * Assumes the table already exists and has a compatible schema. If you * pass `true` for `overwrite`, it will `TRUNCATE` the table before * performing the `INSERT`s. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala new file mode 100644 index 0000000000000..bf3c3fe876873 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala @@ -0,0 +1,231 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql + +import java.{lang => jl} + +import scala.collection.JavaConversions._ + +import org.apache.spark.annotation.Experimental +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types._ + + +/** + * :: Experimental :: + * Functionality for working with missing data in [[DataFrame]]s. + */ +@Experimental +final class DataFrameNaFunctions private[sql](df: DataFrame) { + + /** + * Returns a new [[DataFrame]] that drops rows containing any null values. + */ + def drop(): DataFrame = drop("any", df.columns) + + /** + * Returns a new [[DataFrame]] that drops rows containing null values. + * + * If `how` is "any", then drop rows containing any null values. + * If `how` is "all", then drop rows only if every column is null for that row. + */ + def drop(how: String): DataFrame = drop(how, df.columns) + + /** + * Returns a new [[DataFrame]] that drops rows containing any null values + * in the specified columns. + */ + def drop(cols: Array[String]): DataFrame = drop(cols.toSeq) + + /** + * (Scala-specific) Returns a new [[DataFrame ]] that drops rows containing any null values + * in the specified columns. + */ + def drop(cols: Seq[String]): DataFrame = drop(cols.size, cols) + + /** + * Returns a new [[DataFrame]] that drops rows containing null values + * in the specified columns. + * + * If `how` is "any", then drop rows containing any null values in the specified columns. + * If `how` is "all", then drop rows only if every specified column is null for that row. + */ + def drop(how: String, cols: Array[String]): DataFrame = drop(how, cols.toSeq) + + /** + * (Scala-specific) Returns a new [[DataFrame]] that drops rows containing null values + * in the specified columns. + * + * If `how` is "any", then drop rows containing any null values in the specified columns. + * If `how` is "all", then drop rows only if every specified column is null for that row. + */ + def drop(how: String, cols: Seq[String]): DataFrame = { + how.toLowerCase match { + case "any" => drop(cols.size, cols) + case "all" => drop(1, cols) + case _ => throw new IllegalArgumentException(s"how ($how) must be 'any' or 'all'") + } + } + + /** + * Returns a new [[DataFrame]] that drops rows containing less than `minNonNulls` non-null values. + */ + def drop(minNonNulls: Int): DataFrame = drop(minNonNulls, df.columns) + + /** + * Returns a new [[DataFrame]] that drops rows containing less than `minNonNulls` non-null + * values in the specified columns. + */ + def drop(minNonNulls: Int, cols: Array[String]): DataFrame = drop(minNonNulls, cols.toSeq) + + /** + * (Scala-specific) Returns a new [[DataFrame]] that drops rows containing less than + * `minNonNulls` non-null values in the specified columns. + */ + def drop(minNonNulls: Int, cols: Seq[String]): DataFrame = { + // Filtering condition -- only keep the row if it has at least `minNonNulls` non-null values. + val predicate = AtLeastNNonNulls(minNonNulls, cols.map(name => df.resolve(name))) + df.filter(Column(predicate)) + } + + /** + * Returns a new [[DataFrame]] that replaces null values in numeric columns with `value`. + */ + def fill(value: Double): DataFrame = fill(value, df.columns) + + /** + * Returns a new [[DataFrame ]] that replaces null values in string columns with `value`. + */ + def fill(value: String): DataFrame = fill(value, df.columns) + + /** + * Returns a new [[DataFrame]] that replaces null values in specified numeric columns. + * If a specified column is not a numeric column, it is ignored. + */ + def fill(value: Double, cols: Array[String]): DataFrame = fill(value, cols.toSeq) + + /** + * (Scala-specific) Returns a new [[DataFrame]] that replaces null values in specified + * numeric columns. If a specified column is not a numeric column, it is ignored. + */ + def fill(value: Double, cols: Seq[String]): DataFrame = { + val columnEquals = df.sqlContext.analyzer.resolver + val projections = df.schema.fields.map { f => + // Only fill if the column is part of the cols list. + if (f.dataType.isInstanceOf[NumericType] && cols.exists(col => columnEquals(f.name, col))) { + fillCol[Double](f, value) + } else { + df.col(f.name) + } + } + df.select(projections : _*) + } + + /** + * Returns a new [[DataFrame]] that replaces null values in specified string columns. + * If a specified column is not a string column, it is ignored. + */ + def fill(value: String, cols: Array[String]): DataFrame = fill(value, cols.toSeq) + + /** + * (Scala-specific) Returns a new [[DataFrame]] that replaces null values in + * specified string columns. If a specified column is not a string column, it is ignored. + */ + def fill(value: String, cols: Seq[String]): DataFrame = { + val columnEquals = df.sqlContext.analyzer.resolver + val projections = df.schema.fields.map { f => + // Only fill if the column is part of the cols list. + if (f.dataType.isInstanceOf[StringType] && cols.exists(col => columnEquals(f.name, col))) { + fillCol[String](f, value) + } else { + df.col(f.name) + } + } + df.select(projections : _*) + } + + /** + * Returns a new [[DataFrame]] that replaces null values. + * + * The key of the map is the column name, and the value of the map is the replacement value. + * The value must be of the following type: `Integer`, `Long`, `Float`, `Double`, `String`. + * + * For example, the following replaces null values in column "A" with string "unknown", and + * null values in column "B" with numeric value 1.0. + * {{{ + * import com.google.common.collect.ImmutableMap; + * df.na.fill(ImmutableMap.of("A", "unknown", "B", 1.0)); + * }}} + */ + def fill(valueMap: java.util.Map[String, Any]): DataFrame = fill0(valueMap.toSeq) + + /** + * (Scala-specific) Returns a new [[DataFrame]] that replaces null values. + * + * The key of the map is the column name, and the value of the map is the replacement value. + * The value must be of the following type: `Int`, `Long`, `Float`, `Double`, `String`. + * + * For example, the following replaces null values in column "A" with string "unknown", and + * null values in column "B" with numeric value 1.0. + * {{{ + * df.na.fill(Map( + * "A" -> "unknown", + * "B" -> 1.0 + * )) + * }}} + */ + def fill(valueMap: Map[String, Any]): DataFrame = fill0(valueMap.toSeq) + + private def fill0(values: Seq[(String, Any)]): DataFrame = { + // Error handling + values.foreach { case (colName, replaceValue) => + // Check column name exists + df.resolve(colName) + + // Check data type + replaceValue match { + case _: jl.Double | _: jl.Float | _: jl.Integer | _: jl.Long | _: String => + // This is good + case _ => throw new IllegalArgumentException( + s"Unsupported value type ${replaceValue.getClass.getName} ($replaceValue).") + } + } + + val columnEquals = df.sqlContext.analyzer.resolver + val projections = df.schema.fields.map { f => + values.find { case (k, _) => columnEquals(k, f.name) }.map { case (_, v) => + v match { + case v: jl.Float => fillCol[Double](f, v.toDouble) + case v: jl.Double => fillCol[Double](f, v) + case v: jl.Long => fillCol[Double](f, v.toDouble) + case v: jl.Integer => fillCol[Double](f, v.toDouble) + case v: String => fillCol[String](f, v) + } + }.getOrElse(df.col(f.name)) + } + df.select(projections : _*) + } + + /** + * Returns a [[Column]] expression that replaces null value in `col` with `replacement`. + */ + private def fillCol[T](col: StructField, replacement: T): Column = { + coalesce(df.col(col.name), lit(replacement).cast(col.dataType)).as(col.name) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala index 45a63ae26ed71..a5e6b638d2150 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala @@ -127,10 +127,7 @@ class GroupedData protected[sql](df: DataFrame, groupingExprs: Seq[Expression]) * {{{ * // Selects the age of the oldest employee and the aggregate expense for each department * import com.google.common.collect.ImmutableMap; - * df.groupBy("department").agg(ImmutableMap.builder() - * .put("age", "max") - * .put("expense", "sum") - * .build()); + * df.groupBy("department").agg(ImmutableMap.of("age", "max", "expense", "sum")); * }}} */ def agg(exprs: java.util.Map[String, String]): DataFrame = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index e59cf9b9e037b..39dd14e796f06 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -31,7 +31,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.optimizer.{DefaultOptimizer, Optimizer} -import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, NoRelation} +import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, OneRowRelation} import org.apache.spark.sql.catalyst.rules.RuleExecutor import org.apache.spark.sql.catalyst.{ScalaReflection, expressions} import org.apache.spark.sql.execution.{Filter, _} @@ -120,6 +120,10 @@ class SQLContext(@transient val sparkContext: SparkContext) ExtractPythonUdfs :: sources.PreInsertCastAndRename :: Nil + + override val extendedCheckRules = Seq( + sources.PreWriteCheck(catalog) + ) } @transient @@ -177,7 +181,7 @@ class SQLContext(@transient val sparkContext: SparkContext) */ @Experimental @transient - lazy val emptyDataFrame = DataFrame(this, NoRelation) + lazy val emptyDataFrame: DataFrame = createDataFrame(sparkContext.emptyRDD[Row], StructType(Nil)) /** * A collection of methods for registering user-defined functions (UDF). @@ -388,9 +392,23 @@ class SQLContext(@transient val sparkContext: SparkContext) */ @DeveloperApi def createDataFrame(rowRDD: RDD[Row], schema: StructType): DataFrame = { + createDataFrame(rowRDD, schema, needsConversion = true) + } + + /** + * Creates a DataFrame from an RDD[Row]. User can specify whether the input rows should be + * converted to Catalyst rows. + */ + private[sql] + def createDataFrame(rowRDD: RDD[Row], schema: StructType, needsConversion: Boolean) = { // TODO: use MutableProjection when rowRDD is another DataFrame and the applied // schema differs from the existing schema on any field data type. - val logicalPlan = LogicalRDD(schema.toAttributes, rowRDD)(self) + val catalystRows = if (needsConversion) { + rowRDD.map(ScalaReflection.convertToCatalyst(_, schema).asInstanceOf[Row]) + } else { + rowRDD + } + val logicalPlan = LogicalRDD(schema.toAttributes, catalystRows)(self) DataFrame(this, logicalPlan) } @@ -600,7 +618,7 @@ class SQLContext(@transient val sparkContext: SparkContext) JsonRDD.nullTypeToStringType( JsonRDD.inferSchema(json, 1.0, columnNameOfCorruptJsonRecord))) val rowRDD = JsonRDD.jsonStringToRow(json, appliedSchema, columnNameOfCorruptJsonRecord) - createDataFrame(rowRDD, appliedSchema) + createDataFrame(rowRDD, appliedSchema, needsConversion = false) } /** @@ -629,7 +647,7 @@ class SQLContext(@transient val sparkContext: SparkContext) JsonRDD.nullTypeToStringType( JsonRDD.inferSchema(json, samplingRatio, columnNameOfCorruptJsonRecord)) val rowRDD = JsonRDD.jsonStringToRow(json, appliedSchema, columnNameOfCorruptJsonRecord) - createDataFrame(rowRDD, appliedSchema) + createDataFrame(rowRDD, appliedSchema, needsConversion = false) } /** @@ -1065,14 +1083,6 @@ class SQLContext(@transient val sparkContext: SparkContext) Batch("Add exchange", Once, AddExchange(self)) :: Nil } - @transient - protected[sql] lazy val checkAnalysis = new CheckAnalysis { - override val extendedCheckRules = Seq( - sources.PreWriteCheck(catalog) - ) - } - - protected[sql] def openSession(): SQLSession = { detachSession() val session = createSession() @@ -1105,7 +1115,7 @@ class SQLContext(@transient val sparkContext: SparkContext) */ @DeveloperApi protected[sql] class QueryExecution(val logical: LogicalPlan) { - def assertAnalyzed(): Unit = checkAnalysis(analyzed) + def assertAnalyzed(): Unit = analyzer.checkAnalysis(analyzed) lazy val analyzed: LogicalPlan = analyzer(logical) lazy val withCachedData: LogicalPlan = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala index 89682d25ca7dc..a8018b9213f2b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala @@ -93,7 +93,7 @@ case class GeneratedAggregate( } val currentSum = AttributeReference("currentSum", calcType, nullable = true)() - val initialValue = Literal(null, calcType) + val initialValue = Literal.create(null, calcType) // Coalasce avoids double calculation... // but really, common sub expression elimination would be better.... @@ -137,13 +137,13 @@ case class GeneratedAggregate( expr.dataType match { case DecimalType.Fixed(_, _) => If(EqualTo(currentCount, Literal(0L)), - Literal(null, a.dataType), + Literal.create(null, a.dataType), Cast(Divide( Cast(currentSum, DecimalType.Unlimited), Cast(currentCount, DecimalType.Unlimited)), a.dataType)) case _ => If(EqualTo(currentCount, Literal(0L)), - Literal(null, a.dataType), + Literal.create(null, a.dataType), Divide(Cast(currentSum, a.dataType), Cast(currentCount, a.dataType))) } @@ -156,7 +156,7 @@ case class GeneratedAggregate( case m @ Max(expr) => val currentMax = AttributeReference("currentMax", expr.dataType, nullable = true)() - val initialValue = Literal(null, expr.dataType) + val initialValue = Literal.create(null, expr.dataType) val updateMax = MaxOf(currentMax, expr) AggregateEvaluation( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala index c4534fd5f67e4..967bd76b302d8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala @@ -39,7 +39,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.{IntegerHashSet, LongHa private[sql] class SparkSqlSerializer(conf: SparkConf) extends KryoSerializer(conf) { override def newKryo(): Kryo = { - val kryo = new Kryo() + val kryo = super.newKryo() kryo.setRegistrationRequired(false) kryo.register(classOf[MutablePair[_, _]]) kryo.register(classOf[org.apache.spark.sql.catalyst.expressions.GenericRow]) @@ -57,8 +57,6 @@ private[sql] class SparkSqlSerializer(conf: SparkConf) extends KryoSerializer(co kryo.register(classOf[Decimal]) kryo.setReferences(false) - kryo.setClassLoader(Utils.getSparkClassLoader) - new AllScalaRegistrar().apply(kryo) kryo } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 2b581152e5f77..f754fa770d1b5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -296,7 +296,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { execution.Intersect(planLater(left), planLater(right)) :: Nil case logical.Generate(generator, join, outer, _, child) => execution.Generate(generator, join = join, outer = outer, planLater(child)) :: Nil - case logical.NoRelation => + case logical.OneRowRelation => execution.PhysicalRDD(Nil, singleRowRdd) :: Nil case logical.Repartition(expressions, child) => execution.Exchange(HashPartitioning(expressions, numPartitions), planLater(child)) :: Nil diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DriverQuirks.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DriverQuirks.scala index 1704be7fcbd30..0feabc4282f4a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DriverQuirks.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DriverQuirks.scala @@ -49,9 +49,9 @@ private[sql] object DriverQuirks { * Fetch the DriverQuirks class corresponding to a given database url. */ def get(url: String): DriverQuirks = { - if (url.substring(0, 10).equals("jdbc:mysql")) { + if (url.startsWith("jdbc:mysql")) { new MySQLQuirks() - } else if (url.substring(0, 15).equals("jdbc:postgresql")) { + } else if (url.startsWith("jdbc:postgresql")) { new PostgresQuirks() } else { new NoQuirks() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala index 2b0358c4e2a1e..0b770f2251943 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala @@ -49,7 +49,7 @@ private[sql] object JsonRDD extends Logging { val schemaData = if (samplingRatio > 0.99) json else json.sample(false, samplingRatio, 1) val allKeys = if (schemaData.isEmpty()) { - Set.empty[(String,DataType)] + Set.empty[(String, DataType)] } else { parseJson(schemaData, columnNameOfCorruptRecords).map(allKeysWithValueTypes).reduce(_ ++ _) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala index 5130d8ad5e003..1c868da23e060 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala @@ -19,10 +19,9 @@ package org.apache.spark.sql.parquet import java.io.IOException import java.lang.{Long => JLong} -import java.text.SimpleDateFormat -import java.text.NumberFormat +import java.text.{NumberFormat, SimpleDateFormat} import java.util.concurrent.{Callable, TimeUnit} -import java.util.{ArrayList, Collections, Date, List => JList} +import java.util.{Date, List => JList} import scala.collection.JavaConversions._ import scala.collection.mutable @@ -43,12 +42,13 @@ import parquet.io.ParquetDecodingException import parquet.schema.MessageType import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.mapred.SparkHadoopMapRedUtil import org.apache.spark.mapreduce.SparkHadoopMapReduceUtil import org.apache.spark.rdd.RDD import org.apache.spark.sql.SQLConf import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, Row, _} import org.apache.spark.sql.execution.{LeafNode, SparkPlan, UnaryNode} -import org.apache.spark.sql.types.{DataType, StructType} +import org.apache.spark.sql.types.StructType import org.apache.spark.{Logging, SerializableWritable, TaskContext} /** @@ -356,7 +356,7 @@ private[sql] case class InsertIntoParquetTable( } finally { writer.close(hadoopContext) } - committer.commitTask(hadoopContext) + SparkHadoopMapRedUtil.commitTask(committer, hadoopContext, context) 1 } val jobFormat = new AppendingParquetOutputFormat(taskIdOffset) @@ -512,6 +512,7 @@ private[parquet] class FilteringParquetRowInputFormat import parquet.filter2.compat.FilterCompat.Filter import parquet.filter2.compat.RowGroupFilter + import org.apache.spark.sql.parquet.FilteringParquetRowInputFormat.blockLocationCache val cacheMetadata = configuration.getBoolean(SQLConf.PARQUET_CACHE_METADATA, true) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala index da668f068613b..60e1bec4db8e5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala @@ -390,6 +390,7 @@ private[parquet] object ParquetTypesConverter extends Logging { def convertFromAttributes(attributes: Seq[Attribute], toThriftSchemaNames: Boolean = false): MessageType = { + checkSpecialCharacters(attributes) val fields = attributes.map( attribute => fromDataType(attribute.dataType, attribute.name, attribute.nullable, @@ -404,7 +405,20 @@ private[parquet] object ParquetTypesConverter extends Logging { } } + private def checkSpecialCharacters(schema: Seq[Attribute]) = { + // ,;{}()\n\t= and space character are special characters in Parquet schema + schema.map(_.name).foreach { name => + if (name.matches(".*[ ,;{}()\n\t=].*")) { + sys.error( + s"""Attribute name "$name" contains invalid character(s) among " ,;{}()\n\t=". + |Please use alias to rename it. + """.stripMargin.split("\n").mkString(" ")) + } + } + } + def convertToString(schema: Seq[Attribute]): String = { + checkSpecialCharacters(schema) StructType.fromAttributes(schema).json } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala index 410600b0529d3..0dce3623a66df 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala @@ -42,6 +42,7 @@ import parquet.hadoop.{ParquetInputFormat, _} import org.apache.spark.annotation.DeveloperApi import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.mapred.SparkHadoopMapRedUtil import org.apache.spark.mapreduce.SparkHadoopMapReduceUtil import org.apache.spark.rdd.{NewHadoopPartition, NewHadoopRDD, RDD} import org.apache.spark.sql.catalyst.expressions @@ -121,7 +122,8 @@ private[sql] class DefaultSource val df = sqlContext.createDataFrame( data.queryExecution.toRdd, - data.schema.asNullable) + data.schema.asNullable, + needsConversion = false) val createdRelation = createRelation(sqlContext, parameters, df.schema).asInstanceOf[ParquetRelation2] createdRelation.insert(df, overwrite = mode == SaveMode.Overwrite) @@ -266,7 +268,8 @@ private[sql] case class ParquetRelation2( // containing Parquet files (e.g. partitioned Parquet table). val baseStatuses = paths.distinct.map { p => val fs = FileSystem.get(URI.create(p), sparkContext.hadoopConfiguration) - val qualified = fs.makeQualified(new Path(p)) + val path = new Path(p) + val qualified = path.makeQualified(fs.getUri, fs.getWorkingDirectory) if (!fs.exists(qualified) && maybeSchema.isDefined) { fs.mkdirs(qualified) @@ -432,14 +435,22 @@ private[sql] case class ParquetRelation2( FileInputFormat.setInputPaths(job, selectedFiles.map(_.getPath): _*) } - // Push down filters when possible. Notice that not all filters can be converted to Parquet - // filter predicate. Here we try to convert each individual predicate and only collect those - // convertible ones. - predicates - .flatMap(ParquetFilters.createFilter) - .reduceOption(FilterApi.and) - .filter(_ => sqlContext.conf.parquetFilterPushDown) - .foreach(ParquetInputFormat.setFilterPredicate(jobConf, _)) + // Try to push down filters when filter push-down is enabled. + if (sqlContext.conf.parquetFilterPushDown) { + val partitionColNames = partitionColumns.map(_.name).toSet + predicates + // Don't push down predicates which reference partition columns + .filter { pred => + val referencedColNames = pred.references.map(_.name).toSet + referencedColNames.intersect(partitionColNames).isEmpty + } + // Collects all converted Parquet filter predicates. Notice that not all predicates can be + // converted (`ParquetFilters.createFilter` returns an `Option`). That's why a `flatMap` + // is used here. + .flatMap(ParquetFilters.createFilter) + .reduceOption(FilterApi.and) + .foreach(ParquetInputFormat.setFilterPredicate(jobConf, _)) + } if (isPartitioned) { logInfo { @@ -662,7 +673,8 @@ private[sql] case class ParquetRelation2( } finally { writer.close(hadoopContext) } - committer.commitTask(hadoopContext) + + SparkHadoopMapRedUtil.commitTask(committer, hadoopContext, context) } val jobFormat = new AppendingParquetOutputFormat(taskIdOffset) /* apparently we need a TaskAttemptID to construct an OutputCommitter; @@ -758,12 +770,15 @@ private[sql] object ParquetRelation2 extends Logging { |${parquetSchema.prettyJson} """.stripMargin - assert(metastoreSchema.size == parquetSchema.size, schemaConflictMessage) + val mergedParquetSchema = mergeMissingNullableFields(metastoreSchema, parquetSchema) + + assert(metastoreSchema.size <= mergedParquetSchema.size, schemaConflictMessage) val ordinalMap = metastoreSchema.zipWithIndex.map { case (field, index) => field.name.toLowerCase -> index }.toMap - val reorderedParquetSchema = parquetSchema.sortBy(f => ordinalMap(f.name.toLowerCase)) + val reorderedParquetSchema = mergedParquetSchema.sortBy(f => + ordinalMap.getOrElse(f.name.toLowerCase, metastoreSchema.size + 1)) StructType(metastoreSchema.zip(reorderedParquetSchema).map { // Uses Parquet field names but retains Metastore data types. @@ -774,6 +789,32 @@ private[sql] object ParquetRelation2 extends Logging { }) } + /** + * Returns the original schema from the Parquet file with any missing nullable fields from the + * Hive Metastore schema merged in. + * + * When constructing a DataFrame from a collection of structured data, the resulting object has + * a schema corresponding to the union of the fields present in each element of the collection. + * Spark SQL simply assigns a null value to any field that isn't present for a particular row. + * In some cases, it is possible that a given table partition stored as a Parquet file doesn't + * contain a particular nullable field in its schema despite that field being present in the + * table schema obtained from the Hive Metastore. This method returns a schema representing the + * Parquet file schema along with any additional nullable fields from the Metastore schema + * merged in. + */ + private[parquet] def mergeMissingNullableFields( + metastoreSchema: StructType, + parquetSchema: StructType): StructType = { + val fieldMap = metastoreSchema.map(f => f.name.toLowerCase -> f).toMap + val missingFields = metastoreSchema + .map(_.name.toLowerCase) + .diff(parquetSchema.map(_.name.toLowerCase)) + .map(fieldMap(_)) + .filter(_.nullable) + StructType(parquetSchema ++ missingFields) + } + + // TODO Data source implementations shouldn't touch Catalyst types (`Literal`). // However, we are already using Catalyst expressions for partition pruning and predicate // push-down here... @@ -834,9 +875,9 @@ private[sql] object ParquetRelation2 extends Logging { * PartitionValues( * Seq("a", "b", "c"), * Seq( - * Literal(42, IntegerType), - * Literal("hello", StringType), - * Literal(3.14, FloatType))) + * Literal.create(42, IntegerType), + * Literal.create("hello", StringType), + * Literal.create(3.14, FloatType))) * }}} */ private[parquet] def parsePartition( @@ -915,15 +956,16 @@ private[sql] object ParquetRelation2 extends Logging { raw: String, defaultPartitionName: String): Literal = { // First tries integral types - Try(Literal(Integer.parseInt(raw), IntegerType)) - .orElse(Try(Literal(JLong.parseLong(raw), LongType))) + Try(Literal.create(Integer.parseInt(raw), IntegerType)) + .orElse(Try(Literal.create(JLong.parseLong(raw), LongType))) // Then falls back to fractional types - .orElse(Try(Literal(JFloat.parseFloat(raw), FloatType))) - .orElse(Try(Literal(JDouble.parseDouble(raw), DoubleType))) - .orElse(Try(Literal(new JBigDecimal(raw), DecimalType.Unlimited))) + .orElse(Try(Literal.create(JFloat.parseFloat(raw), FloatType))) + .orElse(Try(Literal.create(JDouble.parseDouble(raw), DoubleType))) + .orElse(Try(Literal.create(new JBigDecimal(raw), DecimalType.Unlimited))) // Then falls back to string .getOrElse { - if (raw == defaultPartitionName) Literal(null, NullType) else Literal(raw, StringType) + if (raw == defaultPartitionName) Literal.create(null, NullType) + else Literal.create(raw, StringType) } } @@ -942,7 +984,7 @@ private[sql] object ParquetRelation2 extends Logging { } literals.map { case l @ Literal(_, dataType) => - Literal(Cast(l, desiredType).eval(), desiredType) + Literal.create(Cast(l, desiredType).eval(), desiredType) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala index 67f3507c61ab6..e13759b7feb7b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala @@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.types.StringType import org.apache.spark.sql.{Row, Strategy, execution, sources} /** @@ -166,6 +167,15 @@ private[sql] object DataSourceStrategy extends Strategy { case expressions.Not(child) => translate(child).map(sources.Not) + case expressions.StartsWith(a: Attribute, Literal(v: String, StringType)) => + Some(sources.StringStartsWith(a.name, v)) + + case expressions.EndsWith(a: Attribute, Literal(v: String, StringType)) => + Some(sources.StringEndsWith(a.name, v)) + + case expressions.Contains(a: Attribute, Literal(v: String, StringType)) => + Some(sources.StringContains(a.name, v)) + case _ => None } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala index 9bbe06e59ba30..dbdb0d39c26a1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala @@ -31,7 +31,8 @@ private[sql] case class InsertIntoDataSource( val relation = logicalRelation.relation.asInstanceOf[InsertableRelation] val data = DataFrame(sqlContext, query) // Apply the schema of the existing table to the new data. - val df = sqlContext.createDataFrame(data.queryExecution.toRdd, logicalRelation.schema) + val df = sqlContext.createDataFrame( + data.queryExecution.toRdd, logicalRelation.schema, needsConversion = false) relation.insert(df, overwrite) // Invalidate the cache. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala index eb46b46ca5bf4..319de710fbc3e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala @@ -204,7 +204,7 @@ private[sql] object ResolvedDataSource { provider: String, options: Map[String, String]): ResolvedDataSource = { val clazz: Class[_] = lookupDataSource(provider) - def className = clazz.getCanonicalName + def className: String = clazz.getCanonicalName val relation = userSpecifiedSchema match { case Some(schema: StructType) => clazz.newInstance() match { case dataSource: SchemaRelationProvider => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/filters.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/filters.scala index 1e4505e36d2f0..791046e0079d6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/filters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/filters.scala @@ -17,16 +17,85 @@ package org.apache.spark.sql.sources +/** + * A filter predicate for data sources. + */ abstract class Filter +/** + * A filter that evaluates to `true` iff the attribute evaluates to a value + * equal to `value`. + */ case class EqualTo(attribute: String, value: Any) extends Filter + +/** + * A filter that evaluates to `true` iff the attribute evaluates to a value + * greater than `value`. + */ case class GreaterThan(attribute: String, value: Any) extends Filter + +/** + * A filter that evaluates to `true` iff the attribute evaluates to a value + * greater than or equal to `value`. + */ case class GreaterThanOrEqual(attribute: String, value: Any) extends Filter + +/** + * A filter that evaluates to `true` iff the attribute evaluates to a value + * less than `value`. + */ case class LessThan(attribute: String, value: Any) extends Filter + +/** + * A filter that evaluates to `true` iff the attribute evaluates to a value + * less than or equal to `value`. + */ case class LessThanOrEqual(attribute: String, value: Any) extends Filter + +/** + * A filter that evaluates to `true` iff the attribute evaluates to one of the values in the array. + */ case class In(attribute: String, values: Array[Any]) extends Filter + +/** + * A filter that evaluates to `true` iff the attribute evaluates to null. + */ case class IsNull(attribute: String) extends Filter + +/** + * A filter that evaluates to `true` iff the attribute evaluates to a non-null value. + */ case class IsNotNull(attribute: String) extends Filter + +/** + * A filter that evaluates to `true` iff both `left` or `right` evaluate to `true`. + */ case class And(left: Filter, right: Filter) extends Filter + +/** + * A filter that evaluates to `true` iff at least one of `left` or `right` evaluates to `true`. + */ case class Or(left: Filter, right: Filter) extends Filter + +/** + * A filter that evaluates to `true` iff `child` is evaluated to `false`. + */ case class Not(child: Filter) extends Filter + +/** + * A filter that evaluates to `true` iff the attribute evaluates to + * a string that starts with `value`. + */ +case class StringStartsWith(attribute: String, value: String) extends Filter + +/** + * A filter that evaluates to `true` iff the attribute evaluates to + * a string that starts with `value`. + */ +case class StringEndsWith(attribute: String, value: String) extends Filter + +/** + * A filter that evaluates to `true` iff the attribute evaluates to + * a string that contains the string `value`. + */ +case class StringContains(attribute: String, value: String) extends Filter diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala index a046a48c1733d..8f9946a5a801e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala @@ -152,6 +152,9 @@ trait PrunedScan { * A BaseRelation that can eliminate unneeded columns and filter using selected * predicates before producing an RDD containing all matching tuples as Row objects. * + * The actual filter should be the conjunction of all `filters`, + * i.e. they should be "and" together. + * * The pushed down filters are currently purely an optimization as they will all be evaluated * again. This means it is safe to use them with methods that produce false positives such * as filtering partitions based on a bloom filter. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala b/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala index c11d0ae5bf1cc..2fdd798b44bb6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.types._ * @param y y coordinate */ @SQLUserDefinedType(udt = classOf[ExamplePointUDT]) -private[sql] class ExamplePoint(val x: Double, val y: Double) +private[sql] class ExamplePoint(val x: Double, val y: Double) extends Serializable /** * User-defined type for [[ExamplePoint]]. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala new file mode 100644 index 0000000000000..0896f175c056f --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala @@ -0,0 +1,157 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import scala.collection.JavaConversions._ + +import org.apache.spark.sql.test.TestSQLContext.implicits._ + + +class DataFrameNaFunctionsSuite extends QueryTest { + + def createDF(): DataFrame = { + Seq[(String, java.lang.Integer, java.lang.Double)]( + ("Bob", 16, 176.5), + ("Alice", null, 164.3), + ("David", 60, null), + ("Amy", null, null), + (null, null, null)).toDF("name", "age", "height") + } + + test("drop") { + val input = createDF() + val rows = input.collect() + + checkAnswer( + input.na.drop("name" :: Nil), + rows(0) :: rows(1) :: rows(2) :: rows(3) :: Nil) + + checkAnswer( + input.na.drop("age" :: Nil), + rows(0) :: rows(2) :: Nil) + + checkAnswer( + input.na.drop("age" :: "height" :: Nil), + rows(0) :: Nil) + + checkAnswer( + input.na.drop(), + rows(0)) + + // dropna on an a dataframe with no column should return an empty data frame. + val empty = input.sqlContext.emptyDataFrame.select() + assert(empty.na.drop().count() === 0L) + + // Make sure the columns are properly named. + assert(input.na.drop().columns.toSeq === input.columns.toSeq) + } + + test("drop with how") { + val input = createDF() + val rows = input.collect() + + checkAnswer( + input.na.drop("all"), + rows(0) :: rows(1) :: rows(2) :: rows(3) :: Nil) + + checkAnswer( + input.na.drop("any"), + rows(0) :: Nil) + + checkAnswer( + input.na.drop("any", Seq("age", "height")), + rows(0) :: Nil) + + checkAnswer( + input.na.drop("all", Seq("age", "height")), + rows(0) :: rows(1) :: rows(2) :: Nil) + } + + test("drop with threshold") { + val input = createDF() + val rows = input.collect() + + checkAnswer( + input.na.drop(2, Seq("age", "height")), + rows(0) :: Nil) + + checkAnswer( + input.na.drop(3, Seq("name", "age", "height")), + rows(0)) + + // Make sure the columns are properly named. + assert(input.na.drop(2, Seq("age", "height")).columns.toSeq === input.columns.toSeq) + } + + test("fill") { + val input = createDF() + + val fillNumeric = input.na.fill(50.6) + checkAnswer( + fillNumeric, + Row("Bob", 16, 176.5) :: + Row("Alice", 50, 164.3) :: + Row("David", 60, 50.6) :: + Row("Amy", 50, 50.6) :: + Row(null, 50, 50.6) :: Nil) + + // Make sure the columns are properly named. + assert(fillNumeric.columns.toSeq === input.columns.toSeq) + + // string + checkAnswer( + input.na.fill("unknown").select("name"), + Row("Bob") :: Row("Alice") :: Row("David") :: Row("Amy") :: Row("unknown") :: Nil) + assert(input.na.fill("unknown").columns.toSeq === input.columns.toSeq) + + // fill double with subset columns + checkAnswer( + input.na.fill(50.6, "age" :: Nil), + Row("Bob", 16, 176.5) :: + Row("Alice", 50, 164.3) :: + Row("David", 60, null) :: + Row("Amy", 50, null) :: + Row(null, 50, null) :: Nil) + + // fill string with subset columns + checkAnswer( + Seq[(String, String)]((null, null)).toDF("col1", "col2").na.fill("test", "col1" :: Nil), + Row("test", null)) + } + + test("fill with map") { + val df = Seq[(String, String, java.lang.Long, java.lang.Double)]( + (null, null, null, null)).toDF("a", "b", "c", "d") + checkAnswer( + df.na.fill(Map( + "a" -> "test", + "c" -> 1, + "d" -> 2.2 + )), + Row("test", null, 1, 2.2)) + + // Test Java version + checkAnswer( + df.na.fill(mapAsJavaMap(Map( + "a" -> "test", + "c" -> 1, + "d" -> 2.2 + ))), + Row("test", null, 1, 2.2)) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index c30ed694a62f0..1db0cf7daac03 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -21,7 +21,7 @@ import scala.language.postfixOps import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ -import org.apache.spark.sql.test.TestSQLContext +import org.apache.spark.sql.test.{ExamplePointUDT, ExamplePoint, TestSQLContext} import org.apache.spark.sql.test.TestSQLContext.logicalPlanToSparkQuery import org.apache.spark.sql.test.TestSQLContext.implicits._ import org.apache.spark.sql.test.TestSQLContext.sql @@ -60,6 +60,14 @@ class DataFrameSuite extends QueryTest { assert($"test".toString === "test") } + test("rename nested groupby") { + val df = Seq((1,(1,1))).toDF() + + checkAnswer( + df.groupBy("_1").agg(col("_1"), sum("_2._1")).toDF("key", "total"), + Row(1, 1) :: Nil) + } + test("invalid plan toString, debug mode") { val oldSetting = TestSQLContext.conf.dataFrameEagerAnalysis TestSQLContext.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, "true") @@ -84,6 +92,11 @@ class DataFrameSuite extends QueryTest { testData.collect().toSeq) } + test("empty data frame") { + assert(TestSQLContext.emptyDataFrame.columns.toSeq === Seq.empty[String]) + assert(TestSQLContext.emptyDataFrame.count() === 0) + } + test("head and take") { assert(testData.take(2) === testData.collect().take(2)) assert(testData.head(2) === testData.collect().take(2)) @@ -113,6 +126,10 @@ class DataFrameSuite extends QueryTest { checkAnswer( df.as('x).join(df.as('y), $"x.str" === $"y.str").groupBy("x.str").count(), Row("1", 1) :: Row("2", 1) :: Row("3", 1) :: Nil) + + checkAnswer( + df.as('x).join(df.as('y), $"x.str" === $"y.str").groupBy("y.str").count(), + Row("1", 1) :: Row("2", 1) :: Row("3", 1) :: Nil) } test("explode") { @@ -443,6 +460,50 @@ class DataFrameSuite extends QueryTest { assert(df.schema.map(_.name).toSeq === Seq("key", "valueRenamed", "newCol")) } + test("describe") { + val describeTestData = Seq( + ("Bob", 16, 176), + ("Alice", 32, 164), + ("David", 60, 192), + ("Amy", 24, 180)).toDF("name", "age", "height") + + val describeResult = Seq( + Row("count", 4, 4), + Row("mean", 33.0, 178.0), + Row("stddev", 16.583123951777, 10.0), + Row("min", 16, 164), + Row("max", 60, 192)) + + val emptyDescribeResult = Seq( + Row("count", 0, 0), + Row("mean", null, null), + Row("stddev", null, null), + Row("min", null, null), + Row("max", null, null)) + + def getSchemaAsSeq(df: DataFrame): Seq[String] = df.schema.map(_.name) + + val describeTwoCols = describeTestData.describe("age", "height") + assert(getSchemaAsSeq(describeTwoCols) === Seq("summary", "age", "height")) + checkAnswer(describeTwoCols, describeResult) + + val describeAllCols = describeTestData.describe() + assert(getSchemaAsSeq(describeAllCols) === Seq("summary", "age", "height")) + checkAnswer(describeAllCols, describeResult) + + val describeOneCol = describeTestData.describe("age") + assert(getSchemaAsSeq(describeOneCol) === Seq("summary", "age")) + checkAnswer(describeOneCol, describeResult.map { case Row(s, d, _) => Row(s, d)} ) + + val describeNoCol = describeTestData.select("name").describe() + assert(getSchemaAsSeq(describeNoCol) === Seq("summary")) + checkAnswer(describeNoCol, describeResult.map { case Row(s, _, _) => Row(s)} ) + + val emptyDescription = describeTestData.limit(0).describe() + assert(getSchemaAsSeq(emptyDescription) === Seq("summary", "age", "height")) + checkAnswer(emptyDescription, emptyDescribeResult) + } + test("apply on query results (SPARK-5462)") { val df = testData.sqlContext.sql("select key from testData") checkAnswer(df.select(df("key")), testData.select('key).collect().toSeq) @@ -453,4 +514,11 @@ class DataFrameSuite extends QueryTest { testData.select($"*").show() testData.select($"*").show(1000) } + + test("createDataFrame(RDD[Row], StructType) should convert UDTs (SPARK-6672)") { + val rowRDD = TestSQLContext.sparkContext.parallelize(Seq(Row(new ExamplePoint(1.0, 2.0)))) + val schema = StructType(Array(StructField("point", new ExamplePointUDT(), false))) + val df = TestSQLContext.createDataFrame(rowRDD, schema) + df.rdd.collect() + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala index f5b945f468dad..36465cc2fa11a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala @@ -17,9 +17,12 @@ package org.apache.spark.sql +import org.apache.spark.sql.execution.SparkSqlSerializer import org.scalatest.FunSuite import org.apache.spark.sql.catalyst.expressions.{GenericMutableRow, SpecificMutableRow} +import org.apache.spark.sql.test.TestSQLContext +import org.apache.spark.sql.test.TestSQLContext.implicits._ import org.apache.spark.sql.types._ class RowSuite extends FunSuite { @@ -50,4 +53,13 @@ class RowSuite extends FunSuite { row(0) = null assert(row.isNullAt(0)) } + + test("serialize w/ kryo") { + val row = Seq((1, Seq(1), Map(1 -> 1), BigDecimal(1))).toDF().first() + val serializer = new SparkSqlSerializer(TestSQLContext.sparkContext.getConf) + val instance = serializer.newInstance() + val ser = instance.serialize(row) + val de = instance.deserialize(ser).asInstanceOf[Row] + assert(de === row) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index a3c0076e16d6c..87e7cf8c8af9f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -1084,10 +1084,19 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { test("SPARK-6145: ORDER BY test for nested fields") { jsonRDD(sparkContext.makeRDD( """{"a": {"b": 1, "a": {"a": 1}}, "c": [{"d": 1}]}""" :: Nil)).registerTempTable("nestedOrder") - // These should be successfully analyzed - sql("SELECT 1 FROM nestedOrder ORDER BY a.b").queryExecution.analyzed - sql("SELECT a.b FROM nestedOrder ORDER BY a.b").queryExecution.analyzed - sql("SELECT 1 FROM nestedOrder ORDER BY a.a.a").queryExecution.analyzed - sql("SELECT 1 FROM nestedOrder ORDER BY c[0].d").queryExecution.analyzed + + checkAnswer(sql("SELECT 1 FROM nestedOrder ORDER BY a.b"), Row(1)) + checkAnswer(sql("SELECT a.b FROM nestedOrder ORDER BY a.b"), Row(1)) + checkAnswer(sql("SELECT 1 FROM nestedOrder ORDER BY a.a.a"), Row(1)) + checkAnswer(sql("SELECT a.a.a FROM nestedOrder ORDER BY a.a.a"), Row(1)) + checkAnswer(sql("SELECT 1 FROM nestedOrder ORDER BY c[0].d"), Row(1)) + checkAnswer(sql("SELECT c[0].d FROM nestedOrder ORDER BY c[0].d"), Row(1)) + } + + test("SPARK-6145: special cases") { + jsonRDD(sparkContext.makeRDD( + """{"a": {"b": [1]}, "b": [{"a": 1}], "c0": {"a": 1}}""" :: Nil)).registerTempTable("t") + checkAnswer(sql("SELECT a.b[0] FROM t ORDER BY c0.a"), Row(1)) + checkAnswer(sql("SELECT b[0].a FROM t ORDER BY c0.a"), Row(1)) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala index 4d32e84fc1115..6a2c2a7c4080a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala @@ -321,6 +321,23 @@ class ParquetDataSourceOnFilterSuite extends ParquetFilterSuiteBase with BeforeA override protected def afterAll(): Unit = { sqlContext.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, originalConf.toString) } + + test("SPARK-6554: don't push down predicates which reference partition columns") { + import sqlContext.implicits._ + + withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED -> "true") { + withTempPath { dir => + val path = s"${dir.getCanonicalPath}/part=1" + (1 to 3).map(i => (i, i.toString)).toDF("a", "b").saveAsParquetFile(path) + + // If the "part = 1" filter gets pushed down, this query will throw an exception since + // "part" is not a valid column in the actual Parquet file + checkAnswer( + sqlContext.parquetFile(path).filter("part = 1"), + (1 to 3).map(i => Row(i, i.toString, 1))) + } + } + } } class ParquetDataSourceOffFilterSuite extends ParquetFilterSuiteBase with BeforeAndAfterAll { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala index adb3c9391f6c2..b7561ce7298cb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala @@ -45,11 +45,11 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest { assert(inferPartitionColumnValue(raw, defaultPartitionName) === literal) } - check("10", Literal(10, IntegerType)) - check("1000000000000000", Literal(1000000000000000L, LongType)) - check("1.5", Literal(1.5, FloatType)) - check("hello", Literal("hello", StringType)) - check(defaultPartitionName, Literal(null, NullType)) + check("10", Literal.create(10, IntegerType)) + check("1000000000000000", Literal.create(1000000000000000L, LongType)) + check("1.5", Literal.create(1.5, FloatType)) + check("hello", Literal.create("hello", StringType)) + check(defaultPartitionName, Literal.create(null, NullType)) } test("parse partition") { @@ -75,22 +75,22 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest { "file://path/a=10", PartitionValues( ArrayBuffer("a"), - ArrayBuffer(Literal(10, IntegerType)))) + ArrayBuffer(Literal.create(10, IntegerType)))) check( "file://path/a=10/b=hello/c=1.5", PartitionValues( ArrayBuffer("a", "b", "c"), ArrayBuffer( - Literal(10, IntegerType), - Literal("hello", StringType), - Literal(1.5, FloatType)))) + Literal.create(10, IntegerType), + Literal.create("hello", StringType), + Literal.create(1.5, FloatType)))) check( "file://path/a=10/b_hello/c=1.5", PartitionValues( ArrayBuffer("c"), - ArrayBuffer(Literal(1.5, FloatType)))) + ArrayBuffer(Literal.create(1.5, FloatType)))) checkThrows[AssertionError]("file://path/=10", "Empty partition column name") checkThrows[AssertionError]("file://path/a=", "Empty partition column value") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala index 321832cd43211..61f1cf347ab0f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala @@ -212,8 +212,11 @@ class ParquetSchemaSuite extends FunSuite with ParquetTest { StructField("UPPERCase", IntegerType, nullable = true)))) } - // Conflicting field count - assert(intercept[Throwable] { + // MetaStore schema is subset of parquet schema + assertResult( + StructType(Seq( + StructField("UPPERCase", DoubleType, nullable = false)))) { + ParquetRelation2.mergeMetastoreParquetSchema( StructType(Seq( StructField("uppercase", DoubleType, nullable = false))), @@ -221,13 +224,56 @@ class ParquetSchemaSuite extends FunSuite with ParquetTest { StructType(Seq( StructField("lowerCase", BinaryType), StructField("UPPERCase", IntegerType, nullable = true)))) + } + + // Metastore schema contains additional non-nullable fields. + assert(intercept[Throwable] { + ParquetRelation2.mergeMetastoreParquetSchema( + StructType(Seq( + StructField("uppercase", DoubleType, nullable = false), + StructField("lowerCase", BinaryType, nullable = false))), + + StructType(Seq( + StructField("UPPERCase", IntegerType, nullable = true)))) }.getMessage.contains("detected conflicting schemas")) - // Conflicting field names + // Conflicting non-nullable field names intercept[Throwable] { ParquetRelation2.mergeMetastoreParquetSchema( - StructType(Seq(StructField("lower", StringType))), + StructType(Seq(StructField("lower", StringType, nullable = false))), StructType(Seq(StructField("lowerCase", BinaryType)))) } } + + test("merge missing nullable fields from Metastore schema") { + // Standard case: Metastore schema contains additional nullable fields not present + // in the Parquet file schema. + assertResult( + StructType(Seq( + StructField("firstField", StringType, nullable = true), + StructField("secondField", StringType, nullable = true), + StructField("thirdfield", StringType, nullable = true)))) { + ParquetRelation2.mergeMetastoreParquetSchema( + StructType(Seq( + StructField("firstfield", StringType, nullable = true), + StructField("secondfield", StringType, nullable = true), + StructField("thirdfield", StringType, nullable = true))), + StructType(Seq( + StructField("firstField", StringType, nullable = true), + StructField("secondField", StringType, nullable = true)))) + } + + // Merge should fail if the Metastore contains any additional fields that are not + // nullable. + assert(intercept[Throwable] { + ParquetRelation2.mergeMetastoreParquetSchema( + StructType(Seq( + StructField("firstfield", StringType, nullable = true), + StructField("secondfield", StringType, nullable = true), + StructField("thirdfield", StringType, nullable = false))), + StructType(Seq( + StructField("firstField", StringType, nullable = true), + StructField("secondField", StringType, nullable = true)))) + }.getMessage.contains("detected conflicting schemas")) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala index 91c6367371f15..33c67355967dd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala @@ -32,6 +32,10 @@ abstract class DataSourceTest extends QueryTest with BeforeAndAfter { override val extendedResolutionRules = PreInsertCastAndRename :: Nil + + override val extendedCheckRules = Seq( + sources.PreWriteCheck(catalog) + ) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala index ffeccf0b69394..773bd1602d5e5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala @@ -35,20 +35,25 @@ case class SimpleFilteredScan(from: Int, to: Int)(@transient val sqlContext: SQL extends BaseRelation with PrunedFilteredScan { - override def schema = + override def schema: StructType = StructType( StructField("a", IntegerType, nullable = false) :: - StructField("b", IntegerType, nullable = false) :: Nil) + StructField("b", IntegerType, nullable = false) :: + StructField("c", StringType, nullable = false) :: Nil) override def buildScan(requiredColumns: Array[String], filters: Array[Filter]) = { val rowBuilders = requiredColumns.map { case "a" => (i: Int) => Seq(i) case "b" => (i: Int) => Seq(i * 2) + case "c" => (i: Int) => + val c = (i - 1 + 'a').toChar.toString + Seq(c * 5 + c.toUpperCase() * 5) } FiltersPushed.list = filters - def translateFilter(filter: Filter): Int => Boolean = filter match { + // Predicate test on integer column + def translateFilterOnA(filter: Filter): Int => Boolean = filter match { case EqualTo("a", v) => (a: Int) => a == v case LessThan("a", v: Int) => (a: Int) => a < v case LessThanOrEqual("a", v: Int) => (a: Int) => a <= v @@ -57,13 +62,27 @@ case class SimpleFilteredScan(from: Int, to: Int)(@transient val sqlContext: SQL case In("a", values) => (a: Int) => values.map(_.asInstanceOf[Int]).toSet.contains(a) case IsNull("a") => (a: Int) => false // Int can't be null case IsNotNull("a") => (a: Int) => true - case Not(pred) => (a: Int) => !translateFilter(pred)(a) - case And(left, right) => (a: Int) => translateFilter(left)(a) && translateFilter(right)(a) - case Or(left, right) => (a: Int) => translateFilter(left)(a) || translateFilter(right)(a) + case Not(pred) => (a: Int) => !translateFilterOnA(pred)(a) + case And(left, right) => (a: Int) => + translateFilterOnA(left)(a) && translateFilterOnA(right)(a) + case Or(left, right) => (a: Int) => + translateFilterOnA(left)(a) || translateFilterOnA(right)(a) case _ => (a: Int) => true } - def eval(a: Int) = !filters.map(translateFilter(_)(a)).contains(false) + // Predicate test on string column + def translateFilterOnC(filter: Filter): String => Boolean = filter match { + case StringStartsWith("c", v) => _.startsWith(v) + case StringEndsWith("c", v) => _.endsWith(v) + case StringContains("c", v) => _.contains(v) + case _ => (c: String) => true + } + + def eval(a: Int) = { + val c = (a - 1 + 'a').toChar.toString * 5 + (a - 1 + 'a').toChar.toString.toUpperCase() * 5 + !filters.map(translateFilterOnA(_)(a)).contains(false) && + !filters.map(translateFilterOnC(_)(c)).contains(false) + } sqlContext.sparkContext.parallelize(from to to).filter(eval).map(i => Row.fromSeq(rowBuilders.map(_(i)).reduceOption(_ ++ _).getOrElse(Seq.empty))) @@ -93,7 +112,8 @@ class FilteredScanSuite extends DataSourceTest { sqlTest( "SELECT * FROM oneToTenFiltered", - (1 to 10).map(i => Row(i, i * 2)).toSeq) + (1 to 10).map(i => Row(i, i * 2, (i - 1 + 'a').toChar.toString * 5 + + (i - 1 + 'a').toChar.toString.toUpperCase() * 5)).toSeq) sqlTest( "SELECT a, b FROM oneToTenFiltered", @@ -128,41 +148,53 @@ class FilteredScanSuite extends DataSourceTest { (2 to 10 by 2).map(i => Row(i, i)).toSeq) sqlTest( - "SELECT * FROM oneToTenFiltered WHERE a = 1", - Seq(1).map(i => Row(i, i * 2)).toSeq) + "SELECT a, b FROM oneToTenFiltered WHERE a = 1", + Seq(1).map(i => Row(i, i * 2))) sqlTest( - "SELECT * FROM oneToTenFiltered WHERE a IN (1,3,5)", - Seq(1,3,5).map(i => Row(i, i * 2)).toSeq) + "SELECT a, b FROM oneToTenFiltered WHERE a IN (1,3,5)", + Seq(1,3,5).map(i => Row(i, i * 2))) sqlTest( - "SELECT * FROM oneToTenFiltered WHERE A = 1", - Seq(1).map(i => Row(i, i * 2)).toSeq) + "SELECT a, b FROM oneToTenFiltered WHERE A = 1", + Seq(1).map(i => Row(i, i * 2))) sqlTest( - "SELECT * FROM oneToTenFiltered WHERE b = 2", - Seq(1).map(i => Row(i, i * 2)).toSeq) + "SELECT a, b FROM oneToTenFiltered WHERE b = 2", + Seq(1).map(i => Row(i, i * 2))) sqlTest( - "SELECT * FROM oneToTenFiltered WHERE a IS NULL", + "SELECT a, b FROM oneToTenFiltered WHERE a IS NULL", Seq.empty[Row]) sqlTest( - "SELECT * FROM oneToTenFiltered WHERE a IS NOT NULL", + "SELECT a, b FROM oneToTenFiltered WHERE a IS NOT NULL", (1 to 10).map(i => Row(i, i * 2)).toSeq) sqlTest( - "SELECT * FROM oneToTenFiltered WHERE a < 5 AND a > 1", + "SELECT a, b FROM oneToTenFiltered WHERE a < 5 AND a > 1", (2 to 4).map(i => Row(i, i * 2)).toSeq) sqlTest( - "SELECT * FROM oneToTenFiltered WHERE a < 3 OR a > 8", - Seq(1, 2, 9, 10).map(i => Row(i, i * 2)).toSeq) + "SELECT a, b FROM oneToTenFiltered WHERE a < 3 OR a > 8", + Seq(1, 2, 9, 10).map(i => Row(i, i * 2))) sqlTest( - "SELECT * FROM oneToTenFiltered WHERE NOT (a < 6)", + "SELECT a, b FROM oneToTenFiltered WHERE NOT (a < 6)", (6 to 10).map(i => Row(i, i * 2)).toSeq) + sqlTest( + "SELECT a, b, c FROM oneToTenFiltered WHERE c like 'c%'", + Seq(Row(3, 3 * 2, "c" * 5 + "C" * 5))) + + sqlTest( + "SELECT a, b, c FROM oneToTenFiltered WHERE c like '%D'", + Seq(Row(4, 4 * 2, "d" * 5 + "D" * 5))) + + sqlTest( + "SELECT a, b, c FROM oneToTenFiltered WHERE c like '%eE%'", + Seq(Row(5, 5 * 2, "e" * 5 + "E" * 5))) + testPushDown("SELECT * FROM oneToTenFiltered WHERE A = 1", 1) testPushDown("SELECT a FROM oneToTenFiltered WHERE A = 1", 1) testPushDown("SELECT b FROM oneToTenFiltered WHERE A = 1", 1) @@ -193,6 +225,15 @@ class FilteredScanSuite extends DataSourceTest { testPushDown("SELECT * FROM oneToTenFiltered WHERE a < 3 OR a > 8", 4) testPushDown("SELECT * FROM oneToTenFiltered WHERE NOT (a < 6)", 5) + testPushDown("SELECT a, b, c FROM oneToTenFiltered WHERE c like 'c%'", 1) + testPushDown("SELECT a, b, c FROM oneToTenFiltered WHERE c like 'C%'", 0) + + testPushDown("SELECT a, b, c FROM oneToTenFiltered WHERE c like '%D'", 1) + testPushDown("SELECT a, b, c FROM oneToTenFiltered WHERE c like '%d'", 0) + + testPushDown("SELECT a, b, c FROM oneToTenFiltered WHERE c like '%eE%'", 1) + testPushDown("SELECT a, b, c FROM oneToTenFiltered WHERE c like '%Ee%'", 0) + def testPushDown(sqlString: String, expectedCount: Int): Unit = { test(s"PushDown Returns $expectedCount: $sqlString") { val queryExecution = sql(sqlString).queryExecution diff --git a/sql/hive/pom.xml b/sql/hive/pom.xml index a9816f6c38cd2..04440076a26a3 100644 --- a/sql/hive/pom.xml +++ b/sql/hive/pom.xml @@ -89,6 +89,20 @@ junit test + + org.apache.spark + spark-sql_${scala.binary.version} + test-jar + ${project.version} + test + + + org.apache.spark + spark-catalyst_${scala.binary.version} + test-jar + ${project.version} + test + diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index c06c2e396bbc1..7c6a7df2bd01e 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -57,6 +57,15 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { protected[sql] def convertMetastoreParquet: Boolean = getConf("spark.sql.hive.convertMetastoreParquet", "true") == "true" + /** + * When true, also tries to merge possibly different but compatible Parquet schemas in different + * Parquet data files. + * + * This configuration is only effective when "spark.sql.hive.convertMetastoreParquet" is true. + */ + protected[sql] def convertMetastoreParquetWithSchemaMerging: Boolean = + getConf("spark.sql.hive.convertMetastoreParquet.mergeSchema", "false") == "true" + /** * When true, a table created by a Hive CTAS statement (no USING clause) will be * converted to a data source table, using the data source set by spark.sql.sources.default. @@ -172,12 +181,13 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { val tableFullName = relation.hiveQlTable.getDbName + "." + relation.hiveQlTable.getTableName - catalog.client.alterTable(tableFullName, new Table(hiveTTable)) + catalog.synchronized { + catalog.client.alterTable(tableFullName, new Table(hiveTTable)) + } } case otherRelation => - throw new NotImplementedError( - s"Analyze has only implemented for Hive tables, " + - s"but $tableName is a ${otherRelation.nodeName}") + throw new UnsupportedOperationException( + s"Analyze only works for Hive tables, but $tableName is a ${otherRelation.nodeName}") } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala index 4afa2e71d77cc..921c6194c7b76 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala @@ -593,7 +593,7 @@ private[hive] trait HiveInspectors { case Literal(_, dt) => sys.error(s"Hive doesn't support the constant type [$dt].") // ideally, we don't test the foldable here(but in optimizer), however, some of the // Hive UDF / UDAF requires its argument to be constant objectinspector, we do it eagerly. - case _ if expr.foldable => toInspector(Literal(expr.eval(), expr.dataType)) + case _ if expr.foldable => toInspector(Literal.create(expr.eval(), expr.dataType)) // For those non constant expression, map to object inspector according to its data type case _ => toInspector(expr.dataType) } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index d1a99555e90c6..315fab673da5c 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.hive import java.io.IOException import java.util.{List => JList} +import com.google.common.base.Objects import com.google.common.cache.{CacheBuilder, CacheLoader, LoadingCache} import org.apache.hadoop.hive.metastore.api.{FieldSchema, Partition => TPartition, Table => TTable} import org.apache.hadoop.hive.metastore.{TableType, Warehouse} @@ -32,7 +33,7 @@ import org.apache.hadoop.util.ReflectionUtils import org.apache.spark.Logging import org.apache.spark.sql.{SaveMode, AnalysisException, SQLContext} -import org.apache.spark.sql.catalyst.analysis.{NoSuchTableException, Catalog, OverrideCatalog} +import org.apache.spark.sql.catalyst.analysis.{MultiInstanceRelation, NoSuchTableException, Catalog, OverrideCatalog} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.catalyst.plans.logical @@ -66,10 +67,11 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with val cacheLoader = new CacheLoader[QualifiedTableName, LogicalPlan]() { override def load(in: QualifiedTableName): LogicalPlan = { logDebug(s"Creating new cached data source for $in") - val table = synchronized { + val table = HiveMetastoreCatalog.this.synchronized { client.getTable(in.database, in.name) } - val userSpecifiedSchema = + + def schemaStringFromParts: Option[String] = { Option(table.getProperty("spark.sql.sources.schema.numParts")).map { numParts => val parts = (0 until numParts.toInt).map { index => val part = table.getProperty(s"spark.sql.sources.schema.part.${index}") @@ -81,10 +83,19 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with part } - // Stick all parts back to a single schema string in the JSON representation - // and convert it back to a StructType. - DataType.fromJson(parts.mkString).asInstanceOf[StructType] + // Stick all parts back to a single schema string. + parts.mkString } + } + + // Originally, we used spark.sql.sources.schema to store the schema of a data source table. + // After SPARK-6024, we removed this flag. + // Although we are not using spark.sql.sources.schema any more, we need to still support. + val schemaString = + Option(table.getProperty("spark.sql.sources.schema")).orElse(schemaStringFromParts) + + val userSpecifiedSchema = + schemaString.map(s => DataType.fromJson(s).asInstanceOf[StructType]) // It does not appear that the ql client for the metastore has a way to enumerate all the // SerDe properties directly... @@ -105,7 +116,15 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with } override def refreshTable(databaseName: String, tableName: String): Unit = { - cachedDataSourceTables.refresh(QualifiedTableName(databaseName, tableName).toLowerCase) + // refreshTable does not eagerly reload the cache. It just invalidate the cache. + // Next time when we use the table, it will be populated in the cache. + // Since we also cache ParquetRealtions converted from Hive Parquet tables and + // adding converted ParquetRealtions into the cache is not defined in the load function + // of the cache (instead, we add the cache entry in convertToParquetRelation), + // it is better at here to invalidate the cache to avoid confusing waring logs from the + // cache loader (e.g. cannot find data source provider, which is only defined for + // data source table.). + invalidateTable(databaseName, tableName) } def invalidateTable(databaseName: String, tableName: String): Unit = { @@ -172,12 +191,16 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with def lookupRelation( tableIdentifier: Seq[String], - alias: Option[String]): LogicalPlan = synchronized { + alias: Option[String]): LogicalPlan = { val tableIdent = processTableIdentifier(tableIdentifier) val databaseName = tableIdent.lift(tableIdent.size - 2).getOrElse( hive.sessionState.getCurrentDatabase) val tblName = tableIdent.last - val table = try client.getTable(databaseName, tblName) catch { + val table = try { + synchronized { + client.getTable(databaseName, tblName) + } + } catch { case te: org.apache.hadoop.hive.ql.metadata.InvalidTableException => throw new NoSuchTableException } @@ -199,7 +222,9 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with } else { val partitions: Seq[Partition] = if (table.isPartitioned) { - HiveShim.getAllPartitionsOf(client, table).toSeq + synchronized { + HiveShim.getAllPartitionsOf(client, table).toSeq + } } else { Nil } @@ -211,10 +236,49 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with private def convertToParquetRelation(metastoreRelation: MetastoreRelation): LogicalRelation = { val metastoreSchema = StructType.fromAttributes(metastoreRelation.output) + val mergeSchema = hive.convertMetastoreParquetWithSchemaMerging // NOTE: Instead of passing Metastore schema directly to `ParquetRelation2`, we have to // serialize the Metastore schema to JSON and pass it as a data source option because of the // evil case insensitivity issue, which is reconciled within `ParquetRelation2`. + val parquetOptions = Map( + ParquetRelation2.METASTORE_SCHEMA -> metastoreSchema.json, + ParquetRelation2.MERGE_SCHEMA -> mergeSchema.toString) + val tableIdentifier = + QualifiedTableName(metastoreRelation.databaseName, metastoreRelation.tableName) + + def getCached( + tableIdentifier: QualifiedTableName, + pathsInMetastore: Seq[String], + schemaInMetastore: StructType, + partitionSpecInMetastore: Option[PartitionSpec]): Option[LogicalRelation] = { + cachedDataSourceTables.getIfPresent(tableIdentifier) match { + case null => None // Cache miss + case logical@LogicalRelation(parquetRelation: ParquetRelation2) => + // If we have the same paths, same schema, and same partition spec, + // we will use the cached Parquet Relation. + val useCached = + parquetRelation.paths.toSet == pathsInMetastore.toSet && + logical.schema.sameType(metastoreSchema) && + parquetRelation.maybePartitionSpec == partitionSpecInMetastore + + if (useCached) { + Some(logical) + } else { + // If the cached relation is not updated, we invalidate it right away. + cachedDataSourceTables.invalidate(tableIdentifier) + None + } + case other => + logWarning( + s"${metastoreRelation.databaseName}.${metastoreRelation.tableName} should be stored " + + s"as Parquet. However, we are getting a ${other} from the metastore cache. " + + s"This cached entry will be invalidated.") + cachedDataSourceTables.invalidate(tableIdentifier) + None + } + } + if (metastoreRelation.hiveQlTable.isPartitioned) { val partitionSchema = StructType.fromAttributes(metastoreRelation.partitionKeys) val partitionColumnDataTypes = partitionSchema.map(_.dataType) @@ -227,18 +291,28 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with } val partitionSpec = PartitionSpec(partitionSchema, partitions) val paths = partitions.map(_.path) - LogicalRelation( - ParquetRelation2( - paths, - Map(ParquetRelation2.METASTORE_SCHEMA -> metastoreSchema.json), - None, - Some(partitionSpec))(hive)) + + val cached = getCached(tableIdentifier, paths, metastoreSchema, Some(partitionSpec)) + val parquetRelation = cached.getOrElse { + val created = + LogicalRelation(ParquetRelation2(paths, parquetOptions, None, Some(partitionSpec))(hive)) + cachedDataSourceTables.put(tableIdentifier, created) + created + } + + parquetRelation } else { val paths = Seq(metastoreRelation.hiveQlTable.getDataLocation.toString) - LogicalRelation( - ParquetRelation2( - paths, - Map(ParquetRelation2.METASTORE_SCHEMA -> metastoreSchema.json))(hive)) + + val cached = getCached(tableIdentifier, paths, metastoreSchema, None) + val parquetRelation = cached.getOrElse { + val created = + LogicalRelation(ParquetRelation2(paths, parquetOptions)(hive)) + cachedDataSourceTables.put(tableIdentifier, created) + created + } + + parquetRelation } } @@ -459,7 +533,7 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with relation.tableDesc.getSerdeClassName.toLowerCase.contains("parquet") => val parquetRelation = convertToParquetRelation(relation) val attributedRewrites = relation.output.zip(parquetRelation.output) - (relation -> relation.output, parquetRelation, attributedRewrites) + (relation, parquetRelation, attributedRewrites) // Write path case InsertIntoHiveTable(relation: MetastoreRelation, _, _, _) @@ -470,7 +544,7 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with relation.tableDesc.getSerdeClassName.toLowerCase.contains("parquet") => val parquetRelation = convertToParquetRelation(relation) val attributedRewrites = relation.output.zip(parquetRelation.output) - (relation -> relation.output, parquetRelation, attributedRewrites) + (relation, parquetRelation, attributedRewrites) // Read path case p @ PhysicalOperation(_, _, relation: MetastoreRelation) @@ -479,33 +553,28 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with relation.tableDesc.getSerdeClassName.toLowerCase.contains("parquet") => val parquetRelation = convertToParquetRelation(relation) val attributedRewrites = relation.output.zip(parquetRelation.output) - (relation -> relation.output, parquetRelation, attributedRewrites) + (relation, parquetRelation, attributedRewrites) } - // Quick fix for SPARK-6450: Notice that we're using both the MetastoreRelation instances and - // their output attributes as the key of the map. This is because MetastoreRelation.equals - // doesn't take output attributes into account, thus multiple MetastoreRelation instances - // pointing to the same table get collapsed into a single entry in the map. A proper fix for - // this should be overriding equals & hashCode in MetastoreRelation. val relationMap = toBeReplaced.map(r => (r._1, r._2)).toMap val attributedRewrites = AttributeMap(toBeReplaced.map(_._3).fold(Nil)(_ ++: _)) // Replaces all `MetastoreRelation`s with corresponding `ParquetRelation2`s, and fixes // attribute IDs referenced in other nodes. plan.transformUp { - case r: MetastoreRelation if relationMap.contains(r -> r.output) => - val parquetRelation = relationMap(r -> r.output) + case r: MetastoreRelation if relationMap.contains(r) => + val parquetRelation = relationMap(r) val alias = r.alias.getOrElse(r.tableName) Subquery(alias, parquetRelation) case InsertIntoTable(r: MetastoreRelation, partition, child, overwrite) - if relationMap.contains(r -> r.output) => - val parquetRelation = relationMap(r -> r.output) + if relationMap.contains(r) => + val parquetRelation = relationMap(r) InsertIntoTable(parquetRelation, partition, child, overwrite) case InsertIntoHiveTable(r: MetastoreRelation, partition, child, overwrite) - if relationMap.contains(r -> r.output) => - val parquetRelation = relationMap(r -> r.output) + if relationMap.contains(r) => + val parquetRelation = relationMap(r) InsertIntoTable(parquetRelation, partition, child, overwrite) case other => other.transformExpressions { @@ -697,10 +766,23 @@ private[hive] case class MetastoreRelation (databaseName: String, tableName: String, alias: Option[String]) (val table: TTable, val partitions: Seq[TPartition]) (@transient sqlContext: SQLContext) - extends LeafNode { + extends LeafNode with MultiInstanceRelation { self: Product => + override def equals(other: scala.Any): Boolean = other match { + case relation: MetastoreRelation => + databaseName == relation.databaseName && + tableName == relation.tableName && + alias == relation.alias && + output == relation.output + case _ => false + } + + override def hashCode(): Int = { + Objects.hashCode(databaseName, tableName, alias, output) + } + // TODO: Can we use org.apache.hadoop.hive.ql.metadata.Table as the type of table and // use org.apache.hadoop.hive.ql.metadata.Partition as the type of elements of partitions. // Right now, using org.apache.hadoop.hive.ql.metadata.Table and @@ -778,6 +860,10 @@ private[hive] case class MetastoreRelation /** An attribute map for determining the ordinal for non-partition columns. */ val columnOrdinals = AttributeMap(attributes.zipWithIndex) + + override def newInstance(): MetastoreRelation = { + MetastoreRelation(databaseName, tableName, alias)(table, partitions)(sqlContext) + } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index c45c4ad70fae9..077e64133faad 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -479,7 +479,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C // Just fake explain for any of the native commands. case Token("TOK_EXPLAIN", explainArgs) if noExplainCommands.contains(explainArgs.head.getText) => - ExplainCommand(NoRelation) + ExplainCommand(OneRowRelation) case Token("TOK_EXPLAIN", explainArgs) if "TOK_CREATETABLE" == explainArgs.head.getText => val Some(crtTbl) :: _ :: extended :: Nil = @@ -622,7 +622,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C val relations = fromClause match { case Some(f) => nodeToRelation(f) - case None => NoRelation + case None => OneRowRelation } val withWhere = whereClause.map { whereNode => @@ -659,7 +659,8 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C AttributeReference("value", StringType)()), true) } - def matchSerDe(clause: Seq[ASTNode]) = clause match { + def matchSerDe(clause: Seq[ASTNode]) + : (Seq[(String, String)], String, Seq[(String, String)]) = clause match { case Token("TOK_SERDEPROPS", propsClause) :: Nil => val rowFormat = propsClause.map { case Token(name, Token(value, Nil) :: Nil) => (name, value) @@ -1201,7 +1202,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C CreateArray(children.map(nodeToExpr)) case Token("TOK_FUNCTION", Token(RAND(), Nil) :: Nil) => Rand case Token("TOK_FUNCTION", Token(SUBSTR(), Nil) :: string :: pos :: Nil) => - Substring(nodeToExpr(string), nodeToExpr(pos), Literal(Integer.MAX_VALUE, IntegerType)) + Substring(nodeToExpr(string), nodeToExpr(pos), Literal.create(Integer.MAX_VALUE, IntegerType)) case Token("TOK_FUNCTION", Token(SUBSTR(), Nil) :: string :: pos :: length :: Nil) => Substring(nodeToExpr(string), nodeToExpr(pos), nodeToExpr(length)) case Token("TOK_FUNCTION", Token(COALESCE(), Nil) :: list) => Coalesce(list.map(nodeToExpr)) @@ -1213,9 +1214,9 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C UnresolvedFunction(name, UnresolvedStar(None) :: Nil) /* Literals */ - case Token("TOK_NULL", Nil) => Literal(null, NullType) - case Token(TRUE(), Nil) => Literal(true, BooleanType) - case Token(FALSE(), Nil) => Literal(false, BooleanType) + case Token("TOK_NULL", Nil) => Literal.create(null, NullType) + case Token(TRUE(), Nil) => Literal.create(true, BooleanType) + case Token(FALSE(), Nil) => Literal.create(false, BooleanType) case Token("TOK_STRINGLITERALSEQUENCE", strings) => Literal(strings.map(s => BaseSemanticAnalyzer.unescapeSQLString(s.getText)).mkString) @@ -1226,21 +1227,21 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C try { if (ast.getText.endsWith("L")) { // Literal bigint. - v = Literal(ast.getText.substring(0, ast.getText.length() - 1).toLong, LongType) + v = Literal.create(ast.getText.substring(0, ast.getText.length() - 1).toLong, LongType) } else if (ast.getText.endsWith("S")) { // Literal smallint. - v = Literal(ast.getText.substring(0, ast.getText.length() - 1).toShort, ShortType) + v = Literal.create(ast.getText.substring(0, ast.getText.length() - 1).toShort, ShortType) } else if (ast.getText.endsWith("Y")) { // Literal tinyint. - v = Literal(ast.getText.substring(0, ast.getText.length() - 1).toByte, ByteType) + v = Literal.create(ast.getText.substring(0, ast.getText.length() - 1).toByte, ByteType) } else if (ast.getText.endsWith("BD") || ast.getText.endsWith("D")) { // Literal decimal val strVal = ast.getText.stripSuffix("D").stripSuffix("B") v = Literal(Decimal(strVal)) } else { - v = Literal(ast.getText.toDouble, DoubleType) - v = Literal(ast.getText.toLong, LongType) - v = Literal(ast.getText.toInt, IntegerType) + v = Literal.create(ast.getText.toDouble, DoubleType) + v = Literal.create(ast.getText.toLong, LongType) + v = Literal.create(ast.getText.toInt, IntegerType) } } catch { case nfe: NumberFormatException => // Do nothing diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala index da53d30354551..6c96747439683 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala @@ -50,7 +50,7 @@ case class InsertIntoHiveTable( @transient val sc: HiveContext = sqlContext.asInstanceOf[HiveContext] @transient lazy val outputClass = newSerializer(table.tableDesc).getSerializedClass @transient private lazy val hiveContext = new Context(sc.hiveconf) - @transient private lazy val db = Hive.get(sc.hiveconf) + @transient private lazy val catalog = sc.catalog private def newSerializer(tableDesc: TableDesc): Serializer = { val serializer = tableDesc.getDeserializerClass.newInstance().asInstanceOf[Serializer] @@ -72,7 +72,6 @@ case class InsertIntoHiveTable( val outputFileFormatClassName = fileSinkConf.getTableInfo.getOutputFileFormatClassName assert(outputFileFormatClassName != null, "Output format class not set") conf.value.set("mapred.output.format.class", outputFileFormatClassName) - conf.value.setOutputCommitter(classOf[FileOutputCommitter]) FileOutputFormat.setOutputPath( conf.value, @@ -200,38 +199,45 @@ case class InsertIntoHiveTable( orderedPartitionSpec.put(entry.getName,partitionSpec.get(entry.getName).getOrElse("")) } val partVals = MetaStoreUtils.getPvals(table.hiveQlTable.getPartCols, partitionSpec) - db.validatePartitionNameCharacters(partVals) + catalog.synchronized { + catalog.client.validatePartitionNameCharacters(partVals) + } // inheritTableSpecs is set to true. It should be set to false for a IMPORT query // which is currently considered as a Hive native command. val inheritTableSpecs = true // TODO: Correctly set isSkewedStoreAsSubdir. val isSkewedStoreAsSubdir = false if (numDynamicPartitions > 0) { - db.loadDynamicPartitions( - outputPath, - qualifiedTableName, - orderedPartitionSpec, - overwrite, - numDynamicPartitions, - holdDDLTime, - isSkewedStoreAsSubdir - ) + catalog.synchronized { + catalog.client.loadDynamicPartitions( + outputPath, + qualifiedTableName, + orderedPartitionSpec, + overwrite, + numDynamicPartitions, + holdDDLTime, + isSkewedStoreAsSubdir) + } } else { - db.loadPartition( + catalog.synchronized { + catalog.client.loadPartition( + outputPath, + qualifiedTableName, + orderedPartitionSpec, + overwrite, + holdDDLTime, + inheritTableSpecs, + isSkewedStoreAsSubdir) + } + } + } else { + catalog.synchronized { + catalog.client.loadTable( outputPath, qualifiedTableName, - orderedPartitionSpec, overwrite, - holdDDLTime, - inheritTableSpecs, - isSkewedStoreAsSubdir) + holdDDLTime) } - } else { - db.loadTable( - outputPath, - qualifiedTableName, - overwrite, - holdDDLTime) } // Invalidate the cache. diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala index 4345ffbf30f77..99dc58646ddd6 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala @@ -58,12 +58,13 @@ case class DropTable( try { hiveContext.cacheManager.tryUncacheQuery(hiveContext.table(tableName)) } catch { - // This table's metadata is not in + // This table's metadata is not in Hive metastore (e.g. the table does not exist). case _: org.apache.hadoop.hive.ql.metadata.InvalidTableException => + case _: org.apache.spark.sql.catalyst.analysis.NoSuchTableException => // Other Throwables can be caused by users providing wrong parameters in OPTIONS // (e.g. invalid paths). We catch it and log a warning message. // Users should be able to drop such kinds of tables regardless if there is an error. - case e: Throwable => log.warn(s"${e.getMessage}") + case e: Throwable => log.warn(s"${e.getMessage}", e) } hiveContext.invalidateTable(tableName) hiveContext.runSqlHive(s"DROP TABLE $ifExistsClause$tableName") diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala index ba2bf67aed684..8398da268174d 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.hive -import java.io.IOException import java.text.NumberFormat import java.util.Date @@ -118,19 +117,7 @@ private[hive] class SparkHiveWriterContainer( } protected def commit() { - if (committer.needsTaskCommit(taskContext)) { - try { - committer.commitTask(taskContext) - logInfo (taID + ": Committed") - } catch { - case e: IOException => - logError("Error committing the output of task: " + taID.value, e) - committer.abortTask(taskContext) - throw e - } - } else { - logInfo("No need to commit output of task: " + taID.value) - } + SparkHadoopMapRedUtil.commitTask(committer, taskContext, jobID, splitID, attemptID) } private def setIDs(jobId: Int, splitId: Int, attemptId: Int) { @@ -213,7 +200,7 @@ private[spark] class SparkHiveDynamicPartitionWriterContainer( .zip(row.toSeq.takeRight(dynamicPartColNames.length)) .map { case (col, rawVal) => val string = if (rawVal == null) null else String.valueOf(rawVal) - val colString = + val colString = if (string == null || string.isEmpty) { defaultPartName } else { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/QueryTest.scala deleted file mode 100644 index 0270e63557963..0000000000000 --- a/sql/hive/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ /dev/null @@ -1,140 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql - -import scala.collection.JavaConversions._ - -import org.apache.spark.sql.catalyst.plans._ -import org.apache.spark.sql.catalyst.util._ - - -/** - * *** DUPLICATED FROM sql/core. *** - * - * It is hard to have maven allow one subproject depend on another subprojects test code. - * So, we duplicate this code here. - */ -class QueryTest extends PlanTest { - - /** - * Runs the plan and makes sure the answer contains all of the keywords, or the - * none of keywords are listed in the answer - * @param rdd the [[DataFrame]] to be executed - * @param exists true for make sure the keywords are listed in the output, otherwise - * to make sure none of the keyword are not listed in the output - * @param keywords keyword in string array - */ - def checkExistence(rdd: DataFrame, exists: Boolean, keywords: String*) { - val outputs = rdd.collect().map(_.mkString).mkString - for (key <- keywords) { - if (exists) { - assert(outputs.contains(key), s"Failed for $rdd ($key doens't exist in result)") - } else { - assert(!outputs.contains(key), s"Failed for $rdd ($key existed in the result)") - } - } - } - - /** - * Runs the plan and makes sure the answer matches the expected result. - * @param rdd the [[DataFrame]] to be executed - * @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s. - */ - protected def checkAnswer(rdd: DataFrame, expectedAnswer: Seq[Row]): Unit = { - QueryTest.checkAnswer(rdd, expectedAnswer) match { - case Some(errorMessage) => fail(errorMessage) - case None => - } - } - - protected def checkAnswer(rdd: DataFrame, expectedAnswer: Row): Unit = { - checkAnswer(rdd, Seq(expectedAnswer)) - } - - def sqlTest(sqlString: String, expectedAnswer: Seq[Row])(implicit sqlContext: SQLContext): Unit = { - test(sqlString) { - checkAnswer(sqlContext.sql(sqlString), expectedAnswer) - } - } -} - -object QueryTest { - /** - * Runs the plan and makes sure the answer matches the expected result. - * If there was exception during the execution or the contents of the DataFrame does not - * match the expected result, an error message will be returned. Otherwise, a [[None]] will - * be returned. - * @param rdd the [[DataFrame]] to be executed - * @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s. - */ - def checkAnswer(rdd: DataFrame, expectedAnswer: Seq[Row]): Option[String] = { - val isSorted = rdd.logicalPlan.collect { case s: logical.Sort => s }.nonEmpty - def prepareAnswer(answer: Seq[Row]): Seq[Row] = { - // Converts data to types that we can do equality comparison using Scala collections. - // For BigDecimal type, the Scala type has a better definition of equality test (similar to - // Java's java.math.BigDecimal.compareTo). - val converted: Seq[Row] = answer.map { s => - Row.fromSeq(s.toSeq.map { - case d: java.math.BigDecimal => BigDecimal(d) - case o => o - }) - } - if (!isSorted) converted.sortBy(_.toString) else converted - } - val sparkAnswer = try rdd.collect().toSeq catch { - case e: Exception => - val errorMessage = - s""" - |Exception thrown while executing query: - |${rdd.queryExecution} - |== Exception == - |$e - |${org.apache.spark.sql.catalyst.util.stackTraceToString(e)} - """.stripMargin - return Some(errorMessage) - } - - if (prepareAnswer(expectedAnswer) != prepareAnswer(sparkAnswer)) { - val errorMessage = - s""" - |Results do not match for query: - |${rdd.logicalPlan} - |== Analyzed Plan == - |${rdd.queryExecution.analyzed} - |== Physical Plan == - |${rdd.queryExecution.executedPlan} - |== Results == - |${sideBySide( - s"== Correct Answer - ${expectedAnswer.size} ==" +: - prepareAnswer(expectedAnswer).map(_.toString), - s"== Spark Answer - ${sparkAnswer.size} ==" +: - prepareAnswer(sparkAnswer).map(_.toString)).mkString("\n")} - """.stripMargin - return Some(errorMessage) - } - - return None - } - - def checkAnswer(rdd: DataFrame, expectedAnswer: java.util.List[Row]): String = { - checkAnswer(rdd, expectedAnswer.toSeq) match { - case Some(errorMessage) => errorMessage - case None => null - } - } -} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala deleted file mode 100644 index 98f1c0e69e29d..0000000000000 --- a/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala +++ /dev/null @@ -1,57 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.plans - -import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference, ExprId} -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.catalyst.util._ -import org.scalatest.FunSuite - -/** - * *** DUPLICATED FROM sql/catalyst/plans. *** - * - * It is hard to have maven allow one subproject depend on another subprojects test code. - * So, we duplicate this code here. - */ -class PlanTest extends FunSuite { - - /** - * Since attribute references are given globally unique ids during analysis, - * we must normalize them to check if two different queries are identical. - */ - protected def normalizeExprIds(plan: LogicalPlan) = { - plan transformAllExpressions { - case a: AttributeReference => - AttributeReference(a.name, a.dataType, a.nullable)(exprId = ExprId(0)) - case a: Alias => - Alias(a.child, a.name)(exprId = ExprId(0)) - } - } - - /** Fails the test if the two plans do not match */ - protected def comparePlans(plan1: LogicalPlan, plan2: LogicalPlan) { - val normalized1 = normalizeExprIds(plan1) - val normalized2 = normalizeExprIds(plan2) - if (normalized1 != normalized2) - fail( - s""" - |== FAIL: Plans do not match === - |${sideBySide(normalized1.treeString, normalized2.treeString).mkString("\n")} - """.stripMargin) - } -} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala index 221a0c263d36c..c188264072a84 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala @@ -24,21 +24,6 @@ import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest} import org.apache.spark.storage.RDDBlockId class CachedTableSuite extends QueryTest { - /** - * Throws a test failed exception when the number of cached tables differs from the expected - * number. - */ - def assertCached(query: DataFrame, numCachedTables: Int = 1): Unit = { - val planWithCaching = query.queryExecution.withCachedData - val cachedData = planWithCaching collect { - case cached: InMemoryRelation => cached - } - - assert( - cachedData.size == numCachedTables, - s"Expected query to contain $numCachedTables, but it actually had ${cachedData.size}\n" + - planWithCaching) - } def rddIdOf(tableName: String): Int = { val executedPlan = table(tableName).queryExecution.executedPlan diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala index 3181cfe40016c..c482c6de8a736 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala @@ -79,9 +79,9 @@ class HiveInspectorSuite extends FunSuite with HiveInspectors { Literal(Decimal(BigDecimal(123.123))) :: Literal(new java.sql.Timestamp(123123)) :: Literal(Array[Byte](1,2,3)) :: - Literal(Seq[Int](1,2,3), ArrayType(IntegerType)) :: - Literal(Map[Int, Int](1->2, 2->1), MapType(IntegerType, IntegerType)) :: - Literal(Row(1,2.0d,3.0f), + Literal.create(Seq[Int](1,2,3), ArrayType(IntegerType)) :: + Literal.create(Map[Int, Int](1->2, 2->1), MapType(IntegerType, IntegerType)) :: + Literal.create(Row(1,2.0d,3.0f), StructType(StructField("c1", IntegerType) :: StructField("c2", DoubleType) :: StructField("c3", FloatType) :: Nil)) :: @@ -166,7 +166,7 @@ class HiveInspectorSuite extends FunSuite with HiveInspectors { val constantData = constantExprs.map(_.eval()) val constantNullData = constantData.map(_ => null) val constantWritableOIs = constantExprs.map(e => toWritableInspector(e.dataType)) - val constantNullWritableOIs = constantExprs.map(e => toInspector(Literal(null, e.dataType))) + val constantNullWritableOIs = constantExprs.map(e => toInspector(Literal.create(null, e.dataType))) checkValues(constantData, constantData.zip(constantWritableOIs).map { case (d, oi) => unwrap(wrap(d, oi), oi) @@ -212,8 +212,8 @@ class HiveInspectorSuite extends FunSuite with HiveInspectors { val d = row(0) :: row(0) :: Nil checkValue(d, unwrap(wrap(d, toInspector(dt)), toInspector(dt))) checkValue(null, unwrap(wrap(null, toInspector(dt)), toInspector(dt))) - checkValue(d, unwrap(wrap(d, toInspector(Literal(d, dt))), toInspector(Literal(d, dt)))) - checkValue(d, unwrap(wrap(null, toInspector(Literal(d, dt))), toInspector(Literal(d, dt)))) + checkValue(d, unwrap(wrap(d, toInspector(Literal.create(d, dt))), toInspector(Literal.create(d, dt)))) + checkValue(d, unwrap(wrap(null, toInspector(Literal.create(d, dt))), toInspector(Literal.create(d, dt)))) } test("wrap / unwrap Map Type") { @@ -222,7 +222,7 @@ class HiveInspectorSuite extends FunSuite with HiveInspectors { val d = Map(row(0) -> row(1)) checkValue(d, unwrap(wrap(d, toInspector(dt)), toInspector(dt))) checkValue(null, unwrap(wrap(null, toInspector(dt)), toInspector(dt))) - checkValue(d, unwrap(wrap(d, toInspector(Literal(d, dt))), toInspector(Literal(d, dt)))) - checkValue(d, unwrap(wrap(null, toInspector(Literal(d, dt))), toInspector(Literal(d, dt)))) + checkValue(d, unwrap(wrap(d, toInspector(Literal.create(d, dt))), toInspector(Literal.create(d, dt)))) + checkValue(d, unwrap(wrap(null, toInspector(Literal.create(d, dt))), toInspector(Literal.create(d, dt)))) } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala index aad48ada52642..fa8e11ffec2b4 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.hive +import org.apache.spark.sql.hive.test.TestHive import org.scalatest.FunSuite import org.apache.spark.sql.test.ExamplePointUDT @@ -36,4 +37,11 @@ class HiveMetastoreCatalogSuite extends FunSuite { assert(HiveMetastoreTypes.toMetastoreType(udt) === HiveMetastoreTypes.toMetastoreType(udt.sqlType)) } + + test("duplicated metastore relations") { + import TestHive.implicits._ + val df = TestHive.sql("SELECT * FROM src") + println(df.queryExecution) + df.as('a).join(df.as('b), $"a.key" === $"b.key") + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala index aa6fb42de7f88..8011952e0d535 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala @@ -198,7 +198,7 @@ class InsertIntoHiveTableSuite extends QueryTest with BeforeAndAfter { val testDatawithNull = TestHive.sparkContext.parallelize( (1 to 10).map(i => ThreeCloumntable(i, i.toString,null))).toDF() - val tmpDir = Files.createTempDir() + val tmpDir = Utils.createTempDir() sql(s"CREATE TABLE table_with_partition(key int,value string) PARTITIONED by (ds string) location '${tmpDir.toURI.toString}' ") sql("INSERT OVERWRITE TABLE table_with_partition partition (ds='1') SELECT key,value FROM testData") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala index e5ad0bf552073..e09c702c8969e 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala @@ -25,6 +25,8 @@ import org.scalatest.BeforeAndAfterEach import org.apache.commons.io.FileUtils import org.apache.hadoop.fs.Path +import org.apache.hadoop.hive.metastore.TableType +import org.apache.hadoop.hive.ql.metadata.Table import org.apache.hadoop.mapred.InvalidInputException import org.apache.spark.sql._ @@ -682,6 +684,27 @@ class MetastoreDataSourcesSuite extends QueryTest with BeforeAndAfterEach { assert(schema === actualSchema) } + test("SPARK-6655 still support a schema stored in spark.sql.sources.schema") { + val tableName = "spark6655" + val schema = StructType(StructField("int", IntegerType, true) :: Nil) + // Manually create the metadata in metastore. + val tbl = new Table("default", tableName) + tbl.setProperty("spark.sql.sources.provider", "json") + tbl.setProperty("spark.sql.sources.schema", schema.json) + tbl.setProperty("EXTERNAL", "FALSE") + tbl.setTableType(TableType.MANAGED_TABLE) + tbl.setSerdeParam("path", catalog.hiveDefaultTableFilePath(tableName)) + catalog.synchronized { + catalog.client.createTable(tbl) + } + + invalidateTable(tableName) + val actualSchema = table(tableName).schema + assert(schema === actualSchema) + sql(s"drop table $tableName") + } + + test("insert into a table") { def createDF(from: Int, to: Int): DataFrame = createDataFrame((from to to).map(i => Tuple2(i, s"str$i"))).toDF("c1", "c2") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala index 1e05a024b8807..ccd0e5aa51f95 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala @@ -120,7 +120,7 @@ class StatisticsSuite extends QueryTest with BeforeAndAfterAll { // Try to analyze a temp table sql("""SELECT * FROM src""").registerTempTable("tempTable") - intercept[NotImplementedError] { + intercept[UnsupportedOperationException] { analyze("tempTable") } catalog.unregisterTable(Seq("tempTable")) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala index 8f3285242091c..a5ec312ee430c 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala @@ -138,7 +138,7 @@ abstract class HiveComparisonTest case _ => plan.children.iterator.exists(isSorted) } - val orderedAnswer = hiveQuery.logical match { + val orderedAnswer = hiveQuery.analyzed match { // Clean out non-deterministic time schema info. // Hack: Hive simply prints the result of a SET command to screen, // and does not return it as a query answer. @@ -299,7 +299,7 @@ abstract class HiveComparisonTest val hiveQueries = queryList.map(new TestHive.HiveQLQueryExecution(_)) // Make sure we can at least parse everything before attempting hive execution. - hiveQueries.foreach(_.logical) + hiveQueries.foreach(_.analyzed) val computedResults = (queryList.zipWithIndex, hiveQueries, hiveCacheFiles).zipped.map { case ((queryString, i), hiveQuery, cachedAnswerFile)=> try { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HivePlanTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HivePlanTest.scala index c939e6e99d28a..bdb53ddf59c19 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HivePlanTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HivePlanTest.scala @@ -22,10 +22,12 @@ import org.apache.spark.sql.hive.test.TestHive class HivePlanTest extends QueryTest { import TestHive._ + import TestHive.implicits._ test("udf constant folding") { - val optimized = sql("SELECT cos(null) FROM src").queryExecution.optimizedPlan - val correctAnswer = sql("SELECT cast(null as double) FROM src").queryExecution.optimizedPlan + Seq.empty[Tuple1[Int]].toDF("a").registerTempTable("t") + val optimized = sql("SELECT cos(null) FROM t").queryExecution.optimizedPlan + val correctAnswer = sql("SELECT cast(null as double) FROM t").queryExecution.optimizedPlan comparePlans(optimized, correctAnswer) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 1187228f4c3db..817b9dcb8f505 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -41,8 +41,32 @@ case class NestedArray1(a: NestedArray2) */ class SQLQuerySuite extends QueryTest { + test("SPARK-5371: union with null and sum") { + val df = Seq((1, 1)).toDF("c1", "c2") + df.registerTempTable("table1") + + val query = sql( + """ + |SELECT + | MIN(c1), + | MIN(c2) + |FROM ( + | SELECT + | SUM(c1) c1, + | NULL c2 + | FROM table1 + | UNION ALL + | SELECT + | NULL c1, + | SUM(c2) c2 + | FROM table1 + |) a + """.stripMargin) + checkAnswer(query, Row(1, 1) :: Nil) + } + test("explode nested Field") { - Seq(NestedArray1(NestedArray2(Seq(1,2,3)))).toDF.registerTempTable("nestedArray") + Seq(NestedArray1(NestedArray2(Seq(1, 2, 3)))).toDF.registerTempTable("nestedArray") checkAnswer( sql("SELECT ints FROM nestedArray LATERAL VIEW explode(a.b) a AS ints"), Row(1) :: Row(2) :: Row(3) :: Nil) @@ -433,4 +457,26 @@ class SQLQuerySuite extends QueryTest { dropTempTable("data") setConf("spark.sql.hive.convertCTAS", originalConf) } + + test("sanity test for SPARK-6618") { + (1 to 100).par.map { i => + val tableName = s"SPARK_6618_table_$i" + sql(s"CREATE TABLE $tableName (col1 string)") + catalog.lookupRelation(Seq(tableName)) + table(tableName) + tables() + sql(s"DROP TABLE $tableName") + } + } + + test("SPARK-5203 union with different decimal precision") { + Seq.empty[(Decimal, Decimal)] + .toDF("d1", "d2") + .select($"d1".cast(DecimalType(10, 15)).as("d")) + .registerTempTable("dn") + + sql("select d from dn union all select d * 2 from dn") + .queryExecution.analyzed + } + } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala index 432d65a874518..5f71e1bbc2d2e 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala @@ -26,8 +26,10 @@ import org.apache.spark.sql.{QueryTest, SQLConf, SaveMode} import org.apache.spark.sql.catalyst.expressions.Row import org.apache.spark.sql.execution.{ExecutedCommand, PhysicalRDD} import org.apache.spark.sql.hive.execution.HiveTableScan +import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.hive.test.TestHive._ import org.apache.spark.sql.hive.test.TestHive.implicits._ +import org.apache.spark.sql.json.JSONRelation import org.apache.spark.sql.sources.{InsertIntoDataSource, LogicalRelation} import org.apache.spark.sql.parquet.{ParquetRelation2, ParquetTableScan} import org.apache.spark.sql.SaveMode @@ -390,6 +392,114 @@ class ParquetDataSourceOnMetastoreSuite extends ParquetMetastoreSuiteBase { sql("DROP TABLE ms_convert") } + + test("Caching converted data source Parquet Relations") { + def checkCached(tableIdentifer: catalog.QualifiedTableName): Unit = { + // Converted test_parquet should be cached. + catalog.cachedDataSourceTables.getIfPresent(tableIdentifer) match { + case null => fail("Converted test_parquet should be cached in the cache.") + case logical @ LogicalRelation(parquetRelation: ParquetRelation2) => // OK + case other => + fail( + "The cached test_parquet should be a Parquet Relation. " + + s"However, $other is returned form the cache.") + } + } + + sql("DROP TABLE IF EXISTS test_insert_parquet") + sql("DROP TABLE IF EXISTS test_parquet_partitioned_cache_test") + + sql( + """ + |create table test_insert_parquet + |( + | intField INT, + | stringField STRING + |) + |ROW FORMAT SERDE 'org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe' + |STORED AS + | INPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat' + | OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat' + """.stripMargin) + + var tableIdentifer = catalog.QualifiedTableName("default", "test_insert_parquet") + + // First, make sure the converted test_parquet is not cached. + assert(catalog.cachedDataSourceTables.getIfPresent(tableIdentifer) === null) + // Table lookup will make the table cached. + table("test_insert_parquet") + checkCached(tableIdentifer) + // For insert into non-partitioned table, we will do the conversion, + // so the converted test_insert_parquet should be cached. + invalidateTable("test_insert_parquet") + assert(catalog.cachedDataSourceTables.getIfPresent(tableIdentifer) === null) + sql( + """ + |INSERT INTO TABLE test_insert_parquet + |select a, b from jt + """.stripMargin) + checkCached(tableIdentifer) + // Make sure we can read the data. + checkAnswer( + sql("select * from test_insert_parquet"), + sql("select a, b from jt").collect()) + // Invalidate the cache. + invalidateTable("test_insert_parquet") + assert(catalog.cachedDataSourceTables.getIfPresent(tableIdentifer) === null) + + // Create a partitioned table. + sql( + """ + |create table test_parquet_partitioned_cache_test + |( + | intField INT, + | stringField STRING + |) + |PARTITIONED BY (date string) + |ROW FORMAT SERDE 'org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe' + |STORED AS + | INPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat' + | OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat' + """.stripMargin) + + tableIdentifer = catalog.QualifiedTableName("default", "test_parquet_partitioned_cache_test") + assert(catalog.cachedDataSourceTables.getIfPresent(tableIdentifer) === null) + sql( + """ + |INSERT INTO TABLE test_parquet_partitioned_cache_test + |PARTITION (date='2015-04-01') + |select a, b from jt + """.stripMargin) + // Right now, insert into a partitioned Parquet is not supported in data source Parquet. + // So, we expect it is not cached. + assert(catalog.cachedDataSourceTables.getIfPresent(tableIdentifer) === null) + sql( + """ + |INSERT INTO TABLE test_parquet_partitioned_cache_test + |PARTITION (date='2015-04-02') + |select a, b from jt + """.stripMargin) + assert(catalog.cachedDataSourceTables.getIfPresent(tableIdentifer) === null) + + // Make sure we can cache the partitioned table. + table("test_parquet_partitioned_cache_test") + checkCached(tableIdentifer) + // Make sure we can read the data. + checkAnswer( + sql("select STRINGField, date, intField from test_parquet_partitioned_cache_test"), + sql( + """ + |select b, '2015-04-01', a FROM jt + |UNION ALL + |select b, '2015-04-02', a FROM jt + """.stripMargin).collect()) + + invalidateTable("test_parquet_partitioned_cache_test") + assert(catalog.cachedDataSourceTables.getIfPresent(tableIdentifer) === null) + + sql("DROP TABLE test_insert_parquet") + sql("DROP TABLE test_parquet_partitioned_cache_test") + } } class ParquetDataSourceOffMetastoreSuite extends ParquetMetastoreSuiteBase { @@ -578,6 +688,22 @@ class ParquetDataSourceOnSourceSuite extends ParquetSourceSuiteBase { sql("DROP TABLE alwaysNullable") } + + test("Aggregation attribute names can't contain special chars \" ,;{}()\\n\\t=\"") { + val tempDir = Utils.createTempDir() + val filePath = new File(tempDir, "testParquet").getCanonicalPath + val filePath2 = new File(tempDir, "testParquet2").getCanonicalPath + + val df = Seq(1,2,3).map(i => (i, i.toString)).toDF("int", "str") + val df2 = df.as('x).join(df.as('y), $"x.str" === $"y.str").groupBy("y.str").max("y.int") + intercept[RuntimeException](df2.saveAsParquetFile(filePath)) + + val df3 = df2.toDF("str", "max_int") + df3.saveAsParquetFile(filePath2) + val df4 = parquetFile(filePath2) + checkAnswer(df4, Row("1", 1) :: Row("2", 2) :: Row("3", 3) :: Nil) + assert(df4.columns === Array("str", "max_int")) + } } class ParquetDataSourceOffSourceSuite extends ParquetSourceSuiteBase { diff --git a/streaming/pom.xml b/streaming/pom.xml index 23a8358d45c2a..5ca55a4f680bb 100644 --- a/streaming/pom.xml +++ b/streaming/pom.xml @@ -97,34 +97,6 @@ target/scala-${scala.binary.version}/classes target/scala-${scala.binary.version}/test-classes - - - org.apache.maven.plugins - maven-jar-plugin - - - - test-jar - - - - test-jar-on-test-compile - test-compile - - test-jar - - - - - org.apache.maven.plugins maven-shade-plugin diff --git a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala index f73b463d07779..28703ef8129b3 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala @@ -234,7 +234,7 @@ object CheckpointReader extends Logging { val checkpointPath = new Path(checkpointDir) // TODO(rxin): Why is this a def?! - def fs = checkpointPath.getFileSystem(hadoopConf) + def fs: FileSystem = checkpointPath.getFileSystem(hadoopConf) // Try to find the checkpoint files val checkpointFiles = Checkpoint.getCheckpointFiles(checkpointDir, fs).reverse diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala index 73030e15c5661..808dcc174cf9a 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala @@ -169,7 +169,7 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T */ def flatMap[U](f: FlatMapFunction[T, U]): JavaDStream[U] = { import scala.collection.JavaConverters._ - def fn = (x: T) => f.call(x).asScala + def fn: (T) => Iterable[U] = (x: T) => f.call(x).asScala new JavaDStream(dstream.flatMap(fn)(fakeClassTag[U]))(fakeClassTag[U]) } @@ -179,7 +179,7 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T */ def flatMapToPair[K2, V2](f: PairFlatMapFunction[T, K2, V2]): JavaPairDStream[K2, V2] = { import scala.collection.JavaConverters._ - def fn = (x: T) => f.call(x).asScala + def fn: (T) => Iterable[(K2, V2)] = (x: T) => f.call(x).asScala def cm: ClassTag[(K2, V2)] = fakeClassTag new JavaPairDStream(dstream.flatMap(fn)(cm))(fakeClassTag[K2], fakeClassTag[V2]) } @@ -190,7 +190,9 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T * of the RDD. */ def mapPartitions[U](f: FlatMapFunction[java.util.Iterator[T], U]): JavaDStream[U] = { - def fn = (x: Iterator[T]) => asScalaIterator(f.call(asJavaIterator(x)).iterator()) + def fn: (Iterator[T]) => Iterator[U] = { + (x: Iterator[T]) => asScalaIterator(f.call(asJavaIterator(x)).iterator()) + } new JavaDStream(dstream.mapPartitions(fn)(fakeClassTag[U]))(fakeClassTag[U]) } @@ -201,7 +203,9 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T */ def mapPartitionsToPair[K2, V2](f: PairFlatMapFunction[java.util.Iterator[T], K2, V2]) : JavaPairDStream[K2, V2] = { - def fn = (x: Iterator[T]) => asScalaIterator(f.call(asJavaIterator(x)).iterator()) + def fn: (Iterator[T]) => Iterator[(K2, V2)] = { + (x: Iterator[T]) => asScalaIterator(f.call(asJavaIterator(x)).iterator()) + } new JavaPairDStream(dstream.mapPartitions(fn))(fakeClassTag[K2], fakeClassTag[V2]) } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala index f94f2d0e8bd31..93baad19e3ee1 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala @@ -526,7 +526,7 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( */ def flatMapValues[U](f: JFunction[V, java.lang.Iterable[U]]): JavaPairDStream[K, U] = { import scala.collection.JavaConverters._ - def fn = (x: V) => f.apply(x).asScala + def fn: (V) => Iterable[U] = (x: V) => f.apply(x).asScala implicit val cm: ClassTag[U] = implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[U]] dstream.flatMapValues(fn) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala index e3db01c1e12c6..4095a7cc84946 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala @@ -192,7 +192,7 @@ class JavaStreamingContext(val ssc: StreamingContext) extends Closeable { converter: JFunction[InputStream, java.lang.Iterable[T]], storageLevel: StorageLevel) : JavaReceiverInputDStream[T] = { - def fn = (x: InputStream) => converter.call(x).toIterator + def fn: (InputStream) => Iterator[T] = (x: InputStream) => converter.call(x).toIterator implicit val cmt: ClassTag[T] = implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[T]] ssc.socketStream(hostname, port, fn, storageLevel) @@ -313,7 +313,7 @@ class JavaStreamingContext(val ssc: StreamingContext) extends Closeable { implicit val cmk: ClassTag[K] = ClassTag(kClass) implicit val cmv: ClassTag[V] = ClassTag(vClass) implicit val cmf: ClassTag[F] = ClassTag(fClass) - def fn = (x: Path) => filter.call(x).booleanValue() + def fn: (Path) => Boolean = (x: Path) => filter.call(x).booleanValue() ssc.fileStream[K, V, F](directory, fn, newFilesOnly) } @@ -344,7 +344,7 @@ class JavaStreamingContext(val ssc: StreamingContext) extends Closeable { implicit val cmk: ClassTag[K] = ClassTag(kClass) implicit val cmv: ClassTag[V] = ClassTag(vClass) implicit val cmf: ClassTag[F] = ClassTag(fClass) - def fn = (x: Path) => filter.call(x).booleanValue() + def fn: (Path) => Boolean = (x: Path) => filter.call(x).booleanValue() ssc.fileStream[K, V, F](directory, fn, newFilesOnly, conf) } @@ -625,7 +625,7 @@ class JavaStreamingContext(val ssc: StreamingContext) extends Closeable { * Stop the execution of the streams. * @param stopSparkContext Stop the associated SparkContext or not */ - def stop(stopSparkContext: Boolean) = ssc.stop(stopSparkContext) + def stop(stopSparkContext: Boolean): Unit = ssc.stop(stopSparkContext) /** * Stop the execution of the streams. @@ -633,7 +633,7 @@ class JavaStreamingContext(val ssc: StreamingContext) extends Closeable { * @param stopGracefully Stop gracefully by waiting for the processing of all * received data to be completed */ - def stop(stopSparkContext: Boolean, stopGracefully: Boolean) = { + def stop(stopSparkContext: Boolean, stopGracefully: Boolean): Unit = { ssc.stop(stopSparkContext, stopGracefully) } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala index 795c5aa6d585b..24f99a2b929f5 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala @@ -839,7 +839,7 @@ object DStream { /** Filtering function that excludes non-user classes for a streaming application */ def streamingExclustionFunction(className: String): Boolean = { - def doesMatch(r: Regex) = r.findFirstIn(className).isDefined + def doesMatch(r: Regex): Boolean = r.findFirstIn(className).isDefined val isSparkClass = doesMatch(SPARK_CLASS_REGEX) val isSparkExampleClass = doesMatch(SPARK_EXAMPLES_CLASS_REGEX) val isSparkStreamingTestClass = doesMatch(SPARK_STREAMING_TESTCLASS_REGEX) diff --git a/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala index cf191715d29d6..87bc20f79c3cd 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala @@ -171,7 +171,9 @@ class BasicOperationsSuite extends TestSuiteBase { test("flatMapValues") { testOperation( Seq( Seq("a", "a", "b"), Seq("", ""), Seq() ), - (s: DStream[String]) => s.map(x => (x, 1)).reduceByKey(_ + _).flatMapValues(x => Seq(x, x + 10)), + (s: DStream[String]) => { + s.map(x => (x, 1)).reduceByKey(_ + _).flatMapValues(x => Seq(x, x + 10)) + }, Seq( Seq(("a", 2), ("a", 12), ("b", 1), ("b", 11)), Seq(("", 2), ("", 12)), Seq() ), true ) @@ -474,7 +476,7 @@ class BasicOperationsSuite extends TestSuiteBase { stream.foreachRDD(_ => {}) // Dummy output stream ssc.start() Thread.sleep(2000) - def getInputFromSlice(fromMillis: Long, toMillis: Long) = { + def getInputFromSlice(fromMillis: Long, toMillis: Long): Set[Int] = { stream.slice(new Time(fromMillis), new Time(toMillis)).flatMap(_.collect()).toSet } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala index 91a2b2bba461d..54c30440a6e8d 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala @@ -43,7 +43,7 @@ class CheckpointSuite extends TestSuiteBase { var ssc: StreamingContext = null - override def batchDuration = Milliseconds(500) + override def batchDuration: Duration = Milliseconds(500) override def beforeFunction() { super.beforeFunction() @@ -72,7 +72,7 @@ class CheckpointSuite extends TestSuiteBase { val input = (1 to 10).map(_ => Seq("a")).toSeq val operation = (st: DStream[String]) => { val updateFunc = (values: Seq[Int], state: Option[Int]) => { - Some((values.sum + state.getOrElse(0))) + Some(values.sum + state.getOrElse(0)) } st.map(x => (x, 1)) .updateStateByKey(updateFunc) @@ -199,7 +199,12 @@ class CheckpointSuite extends TestSuiteBase { testCheckpointedOperation( Seq( Seq("a", "a", "b"), Seq("", ""), Seq(), Seq("a", "a", "b"), Seq("", ""), Seq() ), (s: DStream[String]) => s.map(x => (x, 1)).reduceByKey(_ + _), - Seq( Seq(("a", 2), ("b", 1)), Seq(("", 2)), Seq(), Seq(("a", 2), ("b", 1)), Seq(("", 2)), Seq() ), + Seq( + Seq(("a", 2), ("b", 1)), + Seq(("", 2)), + Seq(), + Seq(("a", 2), ("b", 1)), + Seq(("", 2)), Seq() ), 3 ) } @@ -212,7 +217,8 @@ class CheckpointSuite extends TestSuiteBase { val n = 10 val w = 4 val input = (1 to n).map(_ => Seq("a")).toSeq - val output = Seq(Seq(("a", 1)), Seq(("a", 2)), Seq(("a", 3))) ++ (1 to (n - w + 1)).map(x => Seq(("a", 4))) + val output = Seq( + Seq(("a", 1)), Seq(("a", 2)), Seq(("a", 3))) ++ (1 to (n - w + 1)).map(x => Seq(("a", 4))) val operation = (st: DStream[String]) => { st.map(x => (x, 1)) .reduceByKeyAndWindow(_ + _, _ - _, batchDuration * w, batchDuration) @@ -236,7 +242,13 @@ class CheckpointSuite extends TestSuiteBase { classOf[TextOutputFormat[Text, IntWritable]]) output }, - Seq(Seq(("a", 2), ("b", 1)), Seq(("", 2)), Seq(), Seq(("a", 2), ("b", 1)), Seq(("", 2)), Seq()), + Seq( + Seq(("a", 2), ("b", 1)), + Seq(("", 2)), + Seq(), + Seq(("a", 2), ("b", 1)), + Seq(("", 2)), + Seq()), 3 ) } finally { @@ -259,7 +271,13 @@ class CheckpointSuite extends TestSuiteBase { classOf[NewTextOutputFormat[Text, IntWritable]]) output }, - Seq(Seq(("a", 2), ("b", 1)), Seq(("", 2)), Seq(), Seq(("a", 2), ("b", 1)), Seq(("", 2)), Seq()), + Seq( + Seq(("a", 2), ("b", 1)), + Seq(("", 2)), + Seq(), + Seq(("a", 2), ("b", 1)), + Seq(("", 2)), + Seq()), 3 ) } finally { @@ -298,7 +316,13 @@ class CheckpointSuite extends TestSuiteBase { output } }, - Seq(Seq(("a", 2), ("b", 1)), Seq(("", 2)), Seq(), Seq(("a", 2), ("b", 1)), Seq(("", 2)), Seq()), + Seq( + Seq(("a", 2), ("b", 1)), + Seq(("", 2)), + Seq(), + Seq(("a", 2), ("b", 1)), + Seq(("", 2)), + Seq()), 3 ) } finally { @@ -533,7 +557,8 @@ class CheckpointSuite extends TestSuiteBase { * Advances the manual clock on the streaming scheduler by given number of batches. * It also waits for the expected amount of time for each batch. */ - def advanceTimeWithRealDelay[V: ClassTag](ssc: StreamingContext, numBatches: Long): Seq[Seq[V]] = { + def advanceTimeWithRealDelay[V: ClassTag](ssc: StreamingContext, numBatches: Long): Seq[Seq[V]] = + { val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] logInfo("Manual clock before advancing = " + clock.getTimeMillis()) for (i <- 1 to numBatches.toInt) { @@ -543,7 +568,7 @@ class CheckpointSuite extends TestSuiteBase { logInfo("Manual clock after advancing = " + clock.getTimeMillis()) Thread.sleep(batchDuration.milliseconds) - val outputStream = ssc.graph.getOutputStreams.filter { dstream => + val outputStream = ssc.graph.getOutputStreams().filter { dstream => dstream.isInstanceOf[TestOutputStreamWithPartitions[V]] }.head.asInstanceOf[TestOutputStreamWithPartitions[V]] outputStream.output.map(_.flatten) @@ -552,4 +577,4 @@ class CheckpointSuite extends TestSuiteBase { private object CheckpointSuite extends Serializable { var batchThreeShouldBlockIndefinitely: Boolean = true -} \ No newline at end of file +} diff --git a/streaming/src/test/scala/org/apache/spark/streaming/FailureSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/FailureSuite.scala index 26435d8515815..0c4c06534a693 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/FailureSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/FailureSuite.scala @@ -29,9 +29,9 @@ class FailureSuite extends TestSuiteBase with Logging { val directory = Utils.createTempDir() val numBatches = 30 - override def batchDuration = Milliseconds(1000) + override def batchDuration: Duration = Milliseconds(1000) - override def useManualClock = false + override def useManualClock: Boolean = false override def afterFunction() { Utils.deleteRecursively(directory) diff --git a/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala index 7ed6320a3d0bc..e6ac4975c5e68 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala @@ -52,7 +52,7 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { "localhost", testServer.port, StorageLevel.MEMORY_AND_DISK) val outputBuffer = new ArrayBuffer[Seq[String]] with SynchronizedBuffer[Seq[String]] val outputStream = new TestOutputStream(networkStream, outputBuffer) - def output = outputBuffer.flatMap(x => x) + def output: ArrayBuffer[String] = outputBuffer.flatMap(x => x) outputStream.register() ssc.start() @@ -164,7 +164,7 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { val countStream = networkStream.count val outputBuffer = new ArrayBuffer[Seq[Long]] with SynchronizedBuffer[Seq[Long]] val outputStream = new TestOutputStream(countStream, outputBuffer) - def output = outputBuffer.flatMap(x => x) + def output: ArrayBuffer[Long] = outputBuffer.flatMap(x => x) outputStream.register() ssc.start() @@ -196,7 +196,7 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { val queueStream = ssc.queueStream(queue, oneAtATime = true) val outputBuffer = new ArrayBuffer[Seq[String]] with SynchronizedBuffer[Seq[String]] val outputStream = new TestOutputStream(queueStream, outputBuffer) - def output = outputBuffer.filter(_.size > 0) + def output: ArrayBuffer[Seq[String]] = outputBuffer.filter(_.size > 0) outputStream.register() ssc.start() @@ -204,7 +204,7 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] val input = Seq("1", "2", "3", "4", "5") val expectedOutput = input.map(Seq(_)) - //Thread.sleep(1000) + val inputIterator = input.toIterator for (i <- 0 until input.size) { // Enqueue more than 1 item per tick but they should dequeue one at a time @@ -239,7 +239,7 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { val queueStream = ssc.queueStream(queue, oneAtATime = false) val outputBuffer = new ArrayBuffer[Seq[String]] with SynchronizedBuffer[Seq[String]] val outputStream = new TestOutputStream(queueStream, outputBuffer) - def output = outputBuffer.filter(_.size > 0) + def output: ArrayBuffer[Seq[String]] = outputBuffer.filter(_.size > 0) outputStream.register() ssc.start() @@ -352,7 +352,8 @@ class TestServer(portToBind: Int = 0) extends Logging { logInfo("New connection") try { clientSocket.setTcpNoDelay(true) - val outputStream = new BufferedWriter(new OutputStreamWriter(clientSocket.getOutputStream)) + val outputStream = new BufferedWriter( + new OutputStreamWriter(clientSocket.getOutputStream)) while(clientSocket.isConnected) { val msg = queue.poll(100, TimeUnit.MILLISECONDS) @@ -384,7 +385,7 @@ class TestServer(portToBind: Int = 0) extends Logging { def stop() { servingThread.interrupt() } - def port = serverSocket.getLocalPort + def port: Int = serverSocket.getLocalPort } /** This is a receiver to test multiple threads inserting data using block generator */ diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala index 18a477f92094d..c090eaec2928d 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala @@ -24,20 +24,20 @@ import scala.collection.mutable.ArrayBuffer import scala.concurrent.duration._ import scala.language.postfixOps -import akka.actor.{ActorSystem, Props} import org.apache.hadoop.conf.Configuration import org.scalatest.{BeforeAndAfter, FunSuite, Matchers} import org.scalatest.concurrent.Eventually._ import org.apache.spark._ import org.apache.spark.network.nio.NioBlockTransferService +import org.apache.spark.rpc.RpcEnv import org.apache.spark.scheduler.LiveListenerBus import org.apache.spark.serializer.KryoSerializer import org.apache.spark.shuffle.hash.HashShuffleManager import org.apache.spark.storage._ import org.apache.spark.streaming.receiver._ import org.apache.spark.streaming.util._ -import org.apache.spark.util.{AkkaUtils, ManualClock, Utils} +import org.apache.spark.util.{ManualClock, Utils} import WriteAheadLogBasedBlockHandler._ import WriteAheadLogSuite._ @@ -54,22 +54,19 @@ class ReceivedBlockHandlerSuite extends FunSuite with BeforeAndAfter with Matche val manualClock = new ManualClock val blockManagerSize = 10000000 - var actorSystem: ActorSystem = null + var rpcEnv: RpcEnv = null var blockManagerMaster: BlockManagerMaster = null var blockManager: BlockManager = null var tempDirectory: File = null before { - val (actorSystem, boundPort) = AkkaUtils.createActorSystem( - "test", "localhost", 0, conf = conf, securityManager = securityMgr) - this.actorSystem = actorSystem - conf.set("spark.driver.port", boundPort.toString) + rpcEnv = RpcEnv.create("test", "localhost", 0, conf, securityMgr) + conf.set("spark.driver.port", rpcEnv.address.port.toString) - blockManagerMaster = new BlockManagerMaster( - actorSystem.actorOf(Props(new BlockManagerMasterActor(true, conf, new LiveListenerBus))), - conf, true) + blockManagerMaster = new BlockManagerMaster(rpcEnv.setupEndpoint("blockmanager", + new BlockManagerMasterEndpoint(rpcEnv, true, conf, new LiveListenerBus)), conf, true) - blockManager = new BlockManager("bm", actorSystem, blockManagerMaster, serializer, + blockManager = new BlockManager("bm", rpcEnv, blockManagerMaster, serializer, blockManagerSize, conf, mapOutputTracker, shuffleManager, new NioBlockTransferService(conf, securityMgr), securityMgr, 0) blockManager.initialize("app-id") @@ -87,9 +84,9 @@ class ReceivedBlockHandlerSuite extends FunSuite with BeforeAndAfter with Matche blockManagerMaster.stop() blockManagerMaster = null } - actorSystem.shutdown() - actorSystem.awaitTermination() - actorSystem = null + rpcEnv.shutdown() + rpcEnv.awaitTermination() + rpcEnv = null Utils.deleteRecursively(tempDirectory) } @@ -99,7 +96,7 @@ class ReceivedBlockHandlerSuite extends FunSuite with BeforeAndAfter with Matche testBlockStoring(handler) { case (data, blockIds, storeResults) => // Verify the data in block manager is correct val storedData = blockIds.flatMap { blockId => - blockManager.getLocal(blockId).map { _.data.map {_.toString}.toList }.getOrElse(List.empty) + blockManager.getLocal(blockId).map(_.data.map(_.toString).toList).getOrElse(List.empty) }.toList storedData shouldEqual data @@ -123,7 +120,7 @@ class ReceivedBlockHandlerSuite extends FunSuite with BeforeAndAfter with Matche testBlockStoring(handler) { case (data, blockIds, storeResults) => // Verify the data in block manager is correct val storedData = blockIds.flatMap { blockId => - blockManager.getLocal(blockId).map { _.data.map {_.toString}.toList }.getOrElse(List.empty) + blockManager.getLocal(blockId).map(_.data.map(_.toString).toList).getOrElse(List.empty) }.toList storedData shouldEqual data diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala index 42fad769f0c1a..b63b37d9f9cef 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala @@ -228,7 +228,8 @@ class ReceivedBlockTrackerSuite * Get all the data written in the given write ahead log files. By default, it will read all * files in the test log directory. */ - def getWrittenLogData(logFiles: Seq[String] = getWriteAheadLogFiles): Seq[ReceivedBlockTrackerLogEvent] = { + def getWrittenLogData(logFiles: Seq[String] = getWriteAheadLogFiles) + : Seq[ReceivedBlockTrackerLogEvent] = { logFiles.flatMap { file => new WriteAheadLogReader(file, hadoopConf).toSeq }.map { byteBuffer => @@ -244,7 +245,8 @@ class ReceivedBlockTrackerSuite } /** Create batch allocation object from the given info */ - def createBatchAllocation(time: Long, blockInfos: Seq[ReceivedBlockInfo]): BatchAllocationEvent = { + def createBatchAllocation(time: Long, blockInfos: Seq[ReceivedBlockInfo]) + : BatchAllocationEvent = { BatchAllocationEvent(time, AllocatedBlocks(Map((streamId -> blockInfos)))) } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala index aa20ad0b5374e..10c35cba8dc53 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala @@ -308,7 +308,7 @@ class ReceiverSuite extends TestSuiteBase with Timeouts with Serializable { val errors = new ArrayBuffer[Throwable] /** Check if all data structures are clean */ - def isAllEmpty = { + def isAllEmpty: Boolean = { singles.isEmpty && byteBuffers.isEmpty && iterators.isEmpty && arrayBuffers.isEmpty && errors.isEmpty } @@ -320,24 +320,21 @@ class ReceiverSuite extends TestSuiteBase with Timeouts with Serializable { def pushBytes( bytes: ByteBuffer, optionalMetadata: Option[Any], - optionalBlockId: Option[StreamBlockId] - ) { + optionalBlockId: Option[StreamBlockId]) { byteBuffers += bytes } def pushIterator( iterator: Iterator[_], optionalMetadata: Option[Any], - optionalBlockId: Option[StreamBlockId] - ) { + optionalBlockId: Option[StreamBlockId]) { iterators += iterator } def pushArrayBuffer( arrayBuffer: ArrayBuffer[_], optionalMetadata: Option[Any], - optionalBlockId: Option[StreamBlockId] - ) { + optionalBlockId: Option[StreamBlockId]) { arrayBuffers += arrayBuffer } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala index 2e5005ef6ff14..d1bbf39dc7897 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala @@ -213,7 +213,7 @@ class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts w ssc = new StreamingContext(sc, Milliseconds(100)) var runningCount = 0 SlowTestReceiver.receivedAllRecords = false - //Create test receiver that sleeps in onStop() + // Create test receiver that sleeps in onStop() val totalNumRecords = 15 val recordsPerSecond = 1 val input = ssc.receiverStream(new SlowTestReceiver(totalNumRecords, recordsPerSecond)) @@ -370,7 +370,8 @@ object TestReceiver { } /** Custom receiver for testing whether a slow receiver can be shutdown gracefully or not */ -class SlowTestReceiver(totalRecords: Int, recordsPerSecond: Int) extends Receiver[Int](StorageLevel.MEMORY_ONLY) with Logging { +class SlowTestReceiver(totalRecords: Int, recordsPerSecond: Int) + extends Receiver[Int](StorageLevel.MEMORY_ONLY) with Logging { var receivingThreadOption: Option[Thread] = None diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala index f52562b0a0f73..852e8bb71d4f6 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala @@ -38,8 +38,8 @@ class StreamingListenerSuite extends TestSuiteBase with Matchers { // To make sure that the processing start and end times in collected // information are different for successive batches - override def batchDuration = Milliseconds(100) - override def actuallyWait = true + override def batchDuration: Duration = Milliseconds(100) + override def actuallyWait: Boolean = true test("batch info reporting") { val ssc = setupStreams(input, operation) diff --git a/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala b/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala index 3565d621e8a6c..c3cae8aeb6d15 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala @@ -53,8 +53,9 @@ class TestInputStream[T: ClassTag](ssc_ : StreamingContext, input: Seq[Seq[T]], val selectedInput = if (index < input.size) input(index) else Seq[T]() // lets us test cases where RDDs are not created - if (selectedInput == null) + if (selectedInput == null) { return None + } val rdd = ssc.sc.makeRDD(selectedInput, numPartitions) logInfo("Created RDD " + rdd.id + " with " + selectedInput) @@ -104,7 +105,9 @@ class TestOutputStreamWithPartitions[T: ClassTag](parent: DStream[T], output.clear() } - def toTestOutputStream = new TestOutputStream[T](this.parent, this.output.map(_.flatten)) + def toTestOutputStream: TestOutputStream[T] = { + new TestOutputStream[T](this.parent, this.output.map(_.flatten)) + } } /** @@ -148,34 +151,34 @@ class BatchCounter(ssc: StreamingContext) { trait TestSuiteBase extends FunSuite with BeforeAndAfter with Logging { // Name of the framework for Spark context - def framework = this.getClass.getSimpleName + def framework: String = this.getClass.getSimpleName // Master for Spark context - def master = "local[2]" + def master: String = "local[2]" // Batch duration - def batchDuration = Seconds(1) + def batchDuration: Duration = Seconds(1) // Directory where the checkpoint data will be saved - lazy val checkpointDir = { + lazy val checkpointDir: String = { val dir = Utils.createTempDir() logDebug(s"checkpointDir: $dir") dir.toString } // Number of partitions of the input parallel collections created for testing - def numInputPartitions = 2 + def numInputPartitions: Int = 2 // Maximum time to wait before the test times out - def maxWaitTimeMillis = 10000 + def maxWaitTimeMillis: Int = 10000 // Whether to use manual clock or not - def useManualClock = true + def useManualClock: Boolean = true // Whether to actually wait in real time before changing manual clock - def actuallyWait = false + def actuallyWait: Boolean = false - //// A SparkConf to use in tests. Can be modified before calling setupStreams to configure things. + // A SparkConf to use in tests. Can be modified before calling setupStreams to configure things. val conf = new SparkConf() .setMaster(master) .setAppName(framework) @@ -346,7 +349,8 @@ trait TestSuiteBase extends FunSuite with BeforeAndAfter with Logging { // Wait until expected number of output items have been generated val startTime = System.currentTimeMillis() - while (output.size < numExpectedOutput && System.currentTimeMillis() - startTime < maxWaitTimeMillis) { + while (output.size < numExpectedOutput && + System.currentTimeMillis() - startTime < maxWaitTimeMillis) { logInfo("output.size = " + output.size + ", numExpectedOutput = " + numExpectedOutput) ssc.awaitTerminationOrTimeout(50) } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/UISeleniumSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/UISeleniumSuite.scala index 87a0395efbf2a..998426ebb82e5 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/UISeleniumSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/UISeleniumSuite.scala @@ -32,7 +32,8 @@ import org.apache.spark._ /** * Selenium tests for the Spark Web UI. */ -class UISeleniumSuite extends FunSuite with WebBrowser with Matchers with BeforeAndAfterAll with TestSuiteBase { +class UISeleniumSuite + extends FunSuite with WebBrowser with Matchers with BeforeAndAfterAll with TestSuiteBase { implicit var webDriver: WebDriver = _ diff --git a/streaming/src/test/scala/org/apache/spark/streaming/WindowOperationsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/WindowOperationsSuite.scala index a5d2bb2fde16c..c39ad05f41520 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/WindowOperationsSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/WindowOperationsSuite.scala @@ -22,9 +22,9 @@ import org.apache.spark.storage.StorageLevel class WindowOperationsSuite extends TestSuiteBase { - override def maxWaitTimeMillis = 20000 // large window tests can sometimes take longer + override def maxWaitTimeMillis: Int = 20000 // large window tests can sometimes take longer - override def batchDuration = Seconds(1) // making sure its visible in this class + override def batchDuration: Duration = Seconds(1) // making sure its visible in this class val largerSlideInput = Seq( Seq(("a", 1)), diff --git a/streaming/src/test/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDDSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDDSuite.scala index 7a6a2f3e577dd..c3602a5b73732 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDDSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDDSuite.scala @@ -28,10 +28,13 @@ import org.apache.spark.storage.{BlockId, BlockManager, StorageLevel, StreamBloc import org.apache.spark.streaming.util.{WriteAheadLogFileSegment, WriteAheadLogWriter} import org.apache.spark.util.Utils -class WriteAheadLogBackedBlockRDDSuite extends FunSuite with BeforeAndAfterAll with BeforeAndAfterEach { +class WriteAheadLogBackedBlockRDDSuite + extends FunSuite with BeforeAndAfterAll with BeforeAndAfterEach { + val conf = new SparkConf() .setMaster("local[2]") .setAppName(this.getClass.getSimpleName) + val hadoopConf = new Configuration() var sparkContext: SparkContext = null @@ -86,7 +89,8 @@ class WriteAheadLogBackedBlockRDDSuite extends FunSuite with BeforeAndAfterAll w * @param numPartitionsInWAL Number of partitions to write to the Write Ahead Log * @param testStoreInBM Test whether blocks read from log are stored back into block manager */ - private def testRDD(numPartitionsInBM: Int, numPartitionsInWAL: Int, testStoreInBM: Boolean = false) { + private def testRDD( + numPartitionsInBM: Int, numPartitionsInWAL: Int, testStoreInBM: Boolean = false) { val numBlocks = numPartitionsInBM + numPartitionsInWAL val data = Seq.fill(numBlocks, 10)(scala.util.Random.nextString(50)) @@ -110,7 +114,7 @@ class WriteAheadLogBackedBlockRDDSuite extends FunSuite with BeforeAndAfterAll w "Unexpected blocks in BlockManager" ) - // Make sure that the right `numPartitionsInWAL` blocks are in write ahead logs, and other are not + // Make sure that the right `numPartitionsInWAL` blocks are in WALs, and other are not require( segments.takeRight(numPartitionsInWAL).forall(s => new File(s.path.stripPrefix("file://")).exists()), @@ -152,6 +156,6 @@ class WriteAheadLogBackedBlockRDDSuite extends FunSuite with BeforeAndAfterAll w } private def generateFakeSegments(count: Int): Seq[WriteAheadLogFileSegment] = { - Array.fill(count)(new WriteAheadLogFileSegment("random", 0l, 0)) + Array.fill(count)(new WriteAheadLogFileSegment("random", 0L, 0)) } } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/JobGeneratorSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/JobGeneratorSuite.scala index 4150b60635ed6..7865b06c2e3c2 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/JobGeneratorSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/JobGeneratorSuite.scala @@ -90,7 +90,7 @@ class JobGeneratorSuite extends TestSuiteBase { val receiverTracker = ssc.scheduler.receiverTracker // Get the blocks belonging to a batch - def getBlocksOfBatch(batchTime: Long) = { + def getBlocksOfBatch(batchTime: Long): Seq[ReceivedBlockInfo] = { receiverTracker.getBlocksOfBatchAndStream(Time(batchTime), inputStream.id) } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala index 8335659667f22..a3919c43b95b4 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala @@ -291,7 +291,7 @@ object WriteAheadLogSuite { manager } - /** Read data from a segments of a log file directly and return the list of byte buffers.*/ + /** Read data from a segments of a log file directly and return the list of byte buffers. */ def readDataManually(segments: Seq[WriteAheadLogFileSegment]): Seq[String] = { segments.map { segment => val reader = HdfsUtils.getInputStream(segment.path, hadoopConf) diff --git a/streaming/src/test/scala/org/apache/spark/streamingtest/ImplicitSuite.scala b/streaming/src/test/scala/org/apache/spark/streamingtest/ImplicitSuite.scala index d0bf328f2b74d..d66750463033a 100644 --- a/streaming/src/test/scala/org/apache/spark/streamingtest/ImplicitSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streamingtest/ImplicitSuite.scala @@ -25,7 +25,8 @@ package org.apache.spark.streamingtest */ class ImplicitSuite { - // We only want to test if `implict` works well with the compiler, so we don't need a real DStream. + // We only want to test if `implicit` works well with the compiler, + // so we don't need a real DStream. def mockDStream[T]: org.apache.spark.streaming.dstream.DStream[T] = null def testToPairDStreamFunctions(): Unit = { diff --git a/tools/src/main/scala/org/apache/spark/tools/JavaAPICompletenessChecker.scala b/tools/src/main/scala/org/apache/spark/tools/JavaAPICompletenessChecker.scala index 8d0f09933c8d3..583823c90c5c6 100644 --- a/tools/src/main/scala/org/apache/spark/tools/JavaAPICompletenessChecker.scala +++ b/tools/src/main/scala/org/apache/spark/tools/JavaAPICompletenessChecker.scala @@ -17,7 +17,7 @@ package org.apache.spark.tools -import java.lang.reflect.Method +import java.lang.reflect.{Type, Method} import scala.collection.mutable.ArrayBuffer import scala.language.existentials @@ -302,7 +302,7 @@ object JavaAPICompletenessChecker { private def isExcludedByInterface(method: Method): Boolean = { val excludedInterfaces = Set("org.apache.spark.Logging", "org.apache.hadoop.mapreduce.HadoopMapReduceUtil") - def toComparisionKey(method: Method) = + def toComparisionKey(method: Method): (Class[_], String, Type) = (method.getReturnType, method.getName, method.getGenericReturnType) val interfaces = method.getDeclaringClass.getInterfaces.filter { i => excludedInterfaces.contains(i.getName) diff --git a/tools/src/main/scala/org/apache/spark/tools/StoragePerfTester.scala b/tools/src/main/scala/org/apache/spark/tools/StoragePerfTester.scala index 15ee95070a3d3..f2d135397ce2f 100644 --- a/tools/src/main/scala/org/apache/spark/tools/StoragePerfTester.scala +++ b/tools/src/main/scala/org/apache/spark/tools/StoragePerfTester.scala @@ -32,7 +32,7 @@ import org.apache.spark.util.Utils * Writes simulated shuffle output from several threads and records the observed throughput. */ object StoragePerfTester { - def main(args: Array[String]) = { + def main(args: Array[String]): Unit = { /** Total amount of data to generate. Distributed evenly amongst maps and reduce splits. */ val dataSizeMb = Utils.memoryStringToMb(sys.env.getOrElse("OUTPUT_DATA", "1g")) @@ -58,8 +58,8 @@ object StoragePerfTester { val sc = new SparkContext("local[4]", "Write Tester", conf) val hashShuffleManager = sc.env.shuffleManager.asInstanceOf[HashShuffleManager] - def writeOutputBytes(mapId: Int, total: AtomicLong) = { - val shuffle = hashShuffleManager.shuffleBlockManager.forMapTask(1, mapId, numOutputSplits, + def writeOutputBytes(mapId: Int, total: AtomicLong): Unit = { + val shuffle = hashShuffleManager.shuffleBlockResolver.forMapTask(1, mapId, numOutputSplits, new KryoSerializer(sc.conf), new ShuffleWriteMetrics()) val writers = shuffle.writers for (i <- 1 to recordsPerMap) { @@ -78,7 +78,7 @@ object StoragePerfTester { val totalBytes = new AtomicLong() for (task <- 1 to numMaps) { executor.submit(new Runnable() { - override def run() = { + override def run(): Unit = { try { writeOutputBytes(task, totalBytes) latch.countDown() diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index 3d18690cd9cbf..24a1e02795218 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -24,22 +24,20 @@ import java.lang.reflect.InvocationTargetException import java.net.{Socket, URL} import java.util.concurrent.atomic.AtomicReference -import akka.actor._ -import akka.remote._ import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.hadoop.util.ShutdownHookManager import org.apache.hadoop.yarn.api._ import org.apache.hadoop.yarn.api.records._ import org.apache.hadoop.yarn.conf.YarnConfiguration +import org.apache.spark.rpc._ import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkContext, SparkEnv} import org.apache.spark.SparkException import org.apache.spark.deploy.{PythonRunner, SparkHadoopUtil} import org.apache.spark.deploy.history.HistoryServer import org.apache.spark.scheduler.cluster.YarnSchedulerBackend import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._ -import org.apache.spark.util.{AkkaUtils, ChildFirstURLClassLoader, MutableURLClassLoader, - SignalLogger, Utils} +import org.apache.spark.util._ /** * Common application master functionality for Spark on Yarn. @@ -72,8 +70,8 @@ private[spark] class ApplicationMaster( @volatile private var allocator: YarnAllocator = _ // Fields used in client mode. - private var actorSystem: ActorSystem = null - private var actor: ActorRef = _ + private var rpcEnv: RpcEnv = null + private var amEndpoint: RpcEndpointRef = _ // Fields used in cluster mode. private val sparkContextRef = new AtomicReference[SparkContext](null) @@ -162,7 +160,7 @@ private[spark] class ApplicationMaster( * status to SUCCEEDED in cluster mode to handle if the user calls System.exit * from the application code. */ - final def getDefaultFinalStatus() = { + final def getDefaultFinalStatus(): FinalApplicationStatus = { if (isClusterMode) { FinalApplicationStatus.SUCCEEDED } else { @@ -175,31 +173,35 @@ private[spark] class ApplicationMaster( * This means the ResourceManager will not retry the application attempt on your behalf if * a failure occurred. */ - final def unregister(status: FinalApplicationStatus, diagnostics: String = null) = synchronized { - if (!unregistered) { - logInfo(s"Unregistering ApplicationMaster with $status" + - Option(diagnostics).map(msg => s" (diag message: $msg)").getOrElse("")) - unregistered = true - client.unregister(status, Option(diagnostics).getOrElse("")) + final def unregister(status: FinalApplicationStatus, diagnostics: String = null): Unit = { + synchronized { + if (!unregistered) { + logInfo(s"Unregistering ApplicationMaster with $status" + + Option(diagnostics).map(msg => s" (diag message: $msg)").getOrElse("")) + unregistered = true + client.unregister(status, Option(diagnostics).getOrElse("")) + } } } - final def finish(status: FinalApplicationStatus, code: Int, msg: String = null) = synchronized { - if (!finished) { - val inShutdown = Utils.inShutdown() - logInfo(s"Final app status: ${status}, exitCode: ${code}" + - Option(msg).map(msg => s", (reason: $msg)").getOrElse("")) - exitCode = code - finalStatus = status - finalMsg = msg - finished = true - if (!inShutdown && Thread.currentThread() != reporterThread && reporterThread != null) { - logDebug("shutting down reporter thread") - reporterThread.interrupt() - } - if (!inShutdown && Thread.currentThread() != userClassThread && userClassThread != null) { - logDebug("shutting down user thread") - userClassThread.interrupt() + final def finish(status: FinalApplicationStatus, code: Int, msg: String = null): Unit = { + synchronized { + if (!finished) { + val inShutdown = Utils.inShutdown() + logInfo(s"Final app status: $status, exitCode: $code" + + Option(msg).map(msg => s", (reason: $msg)").getOrElse("")) + exitCode = code + finalStatus = status + finalMsg = msg + finished = true + if (!inShutdown && Thread.currentThread() != reporterThread && reporterThread != null) { + logDebug("shutting down reporter thread") + reporterThread.interrupt() + } + if (!inShutdown && Thread.currentThread() != userClassThread && userClassThread != null) { + logDebug("shutting down user thread") + userClassThread.interrupt() + } } } } @@ -236,22 +238,21 @@ private[spark] class ApplicationMaster( } /** - * Create an actor that communicates with the driver. + * Create an [[RpcEndpoint]] that communicates with the driver. * * In cluster mode, the AM and the driver belong to same process - * so the AM actor need not monitor lifecycle of the driver. + * so the AMEndpoint need not monitor lifecycle of the driver. */ - private def runAMActor( + private def runAMEndpoint( host: String, port: String, isClusterMode: Boolean): Unit = { - val driverUrl = AkkaUtils.address( - AkkaUtils.protocol(actorSystem), + val driverEndpont = rpcEnv.setupEndpointRef( SparkEnv.driverActorSystemName, - host, - port, - YarnSchedulerBackend.ACTOR_NAME) - actor = actorSystem.actorOf(Props(new AMActor(driverUrl, isClusterMode)), name = "YarnAM") + RpcAddress(host, port.toInt), + YarnSchedulerBackend.ENDPOINT_NAME) + amEndpoint = + rpcEnv.setupEndpoint("YarnAM", new AMEndpoint(rpcEnv, driverEndpont, isClusterMode)) } private def runDriver(securityMgr: SecurityManager): Unit = { @@ -268,8 +269,8 @@ private[spark] class ApplicationMaster( ApplicationMaster.EXIT_SC_NOT_INITED, "Timed out waiting for SparkContext.") } else { - actorSystem = sc.env.actorSystem - runAMActor( + rpcEnv = sc.env.rpcEnv + runAMEndpoint( sc.getConf.get("spark.driver.host"), sc.getConf.get("spark.driver.port"), isClusterMode = true) @@ -279,8 +280,7 @@ private[spark] class ApplicationMaster( } private def runExecutorLauncher(securityMgr: SecurityManager): Unit = { - actorSystem = AkkaUtils.createActorSystem("sparkYarnAM", Utils.localHostName, 0, - conf = sparkConf, securityManager = securityMgr)._1 + rpcEnv = RpcEnv.create("sparkYarnAM", Utils.localHostName, 0, sparkConf, securityMgr) waitForSparkDriver() addAmIpFilter() registerAM(sparkConf.get("spark.driver.appUIAddress", ""), securityMgr) @@ -427,7 +427,7 @@ private[spark] class ApplicationMaster( sparkConf.set("spark.driver.host", driverHost) sparkConf.set("spark.driver.port", driverPort.toString) - runAMActor(driverHost, driverPort.toString, isClusterMode = false) + runAMEndpoint(driverHost, driverPort.toString, isClusterMode = false) } /** Add the Yarn IP filter that is required for properly securing the UI. */ @@ -439,7 +439,7 @@ private[spark] class ApplicationMaster( System.setProperty("spark.ui.filters", amFilter) params.foreach { case (k, v) => System.setProperty(s"spark.$amFilter.param.$k", v) } } else { - actor ! AddWebUIFilter(amFilter, params.toMap, proxyBase) + amEndpoint.send(AddWebUIFilter(amFilter, params.toMap, proxyBase)) } } @@ -501,44 +501,29 @@ private[spark] class ApplicationMaster( } /** - * An actor that communicates with the driver's scheduler backend. + * An [[RpcEndpoint]] that communicates with the driver's scheduler backend. */ - private class AMActor(driverUrl: String, isClusterMode: Boolean) extends Actor { - var driver: ActorSelection = _ - - override def preStart() = { - logInfo("Listen to driver: " + driverUrl) - driver = context.actorSelection(driverUrl) - // Send a hello message to establish the connection, after which - // we can monitor Lifecycle Events. - driver ! "Hello" - driver ! RegisterClusterManager - // In cluster mode, the AM can directly monitor the driver status instead - // of trying to deduce it from the lifecycle of the driver's actor - if (!isClusterMode) { - context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent]) - } - } + private class AMEndpoint( + override val rpcEnv: RpcEnv, driver: RpcEndpointRef, isClusterMode: Boolean) + extends RpcEndpoint with Logging { - override def receive = { - case x: DisassociatedEvent => - logInfo(s"Driver terminated or disconnected! Shutting down. $x") - // In cluster mode, do not rely on the disassociated event to exit - // This avoids potentially reporting incorrect exit codes if the driver fails - if (!isClusterMode) { - finish(FinalApplicationStatus.SUCCEEDED, ApplicationMaster.EXIT_SUCCESS) - } + override def onStart(): Unit = { + driver.send(RegisterClusterManager(self)) + } + override def receive: PartialFunction[Any, Unit] = { case x: AddWebUIFilter => logInfo(s"Add WebUI Filter. $x") - driver ! x + driver.send(x) + } + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { case RequestExecutors(requestedTotal) => Option(allocator) match { case Some(a) => a.requestTotalExecutors(requestedTotal) case None => logWarning("Container allocator is not ready to request executors yet.") } - sender ! true + context.reply(true) case KillExecutors(executorIds) => logInfo(s"Driver requested to kill executor(s) ${executorIds.mkString(", ")}.") @@ -546,7 +531,16 @@ private[spark] class ApplicationMaster( case Some(a) => executorIds.foreach(a.killExecutor) case None => logWarning("Container allocator is not ready to kill executors yet.") } - sender ! true + context.reply(true) + } + + override def onDisconnected(remoteAddress: RpcAddress): Unit = { + logInfo(s"Driver terminated or disconnected! Shutting down. $remoteAddress") + // In cluster mode, do not rely on the disassociated event to exit + // This avoids potentially reporting incorrect exit codes if the driver fails + if (!isClusterMode) { + finish(FinalApplicationStatus.SUCCEEDED, ApplicationMaster.EXIT_SUCCESS) + } } } @@ -567,7 +561,7 @@ object ApplicationMaster extends Logging { private var master: ApplicationMaster = _ - def main(args: Array[String]) = { + def main(args: Array[String]): Unit = { SignalLogger.register(log) val amArgs = new ApplicationMasterArguments(args) SparkHadoopUtil.get.runAsSparkUser { () => @@ -576,11 +570,11 @@ object ApplicationMaster extends Logging { } } - private[spark] def sparkContextInitialized(sc: SparkContext) = { + private[spark] def sparkContextInitialized(sc: SparkContext): Unit = { master.sparkContextInitialized(sc) } - private[spark] def sparkContextStopped(sc: SparkContext) = { + private[spark] def sparkContextStopped(sc: SparkContext): Boolean = { master.sparkContextStopped(sc) } @@ -592,7 +586,7 @@ object ApplicationMaster extends Logging { */ object ExecutorLauncher { - def main(args: Array[String]) = { + def main(args: Array[String]): Unit = { ApplicationMaster.main(args) } diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index 61f8fc3f5a014..79d55a09eb671 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -66,6 +66,8 @@ private[spark] class Client( private val executorMemoryOverhead = args.executorMemoryOverhead // MB private val distCacheMgr = new ClientDistributedCacheManager() private val isClusterMode = args.isClusterMode + private val fireAndForget = isClusterMode && + !sparkConf.getBoolean("spark.yarn.submit.waitAppCompletion", true) def stop(): Unit = yarnClient.stop() @@ -564,31 +566,13 @@ private[spark] class Client( if (logApplicationReport) { logInfo(s"Application report for $appId (state: $state)") - val details = Seq[(String, String)]( - ("client token", getClientToken(report)), - ("diagnostics", report.getDiagnostics), - ("ApplicationMaster host", report.getHost), - ("ApplicationMaster RPC port", report.getRpcPort.toString), - ("queue", report.getQueue), - ("start time", report.getStartTime.toString), - ("final status", report.getFinalApplicationStatus.toString), - ("tracking URL", report.getTrackingUrl), - ("user", report.getUser) - ) - - // Use more loggable format if value is null or empty - val formattedDetails = details - .map { case (k, v) => - val newValue = Option(v).filter(_.nonEmpty).getOrElse("N/A") - s"\n\t $k: $newValue" } - .mkString("") // If DEBUG is enabled, log report details every iteration // Otherwise, log them every time the application changes state if (log.isDebugEnabled) { - logDebug(formattedDetails) + logDebug(formatReportDetails(report)) } else if (lastState != state) { - logInfo(formattedDetails) + logInfo(formatReportDetails(report)) } } @@ -609,24 +593,57 @@ private[spark] class Client( throw new SparkException("While loop is depleted! This should never happen...") } + private def formatReportDetails(report: ApplicationReport): String = { + val details = Seq[(String, String)]( + ("client token", getClientToken(report)), + ("diagnostics", report.getDiagnostics), + ("ApplicationMaster host", report.getHost), + ("ApplicationMaster RPC port", report.getRpcPort.toString), + ("queue", report.getQueue), + ("start time", report.getStartTime.toString), + ("final status", report.getFinalApplicationStatus.toString), + ("tracking URL", report.getTrackingUrl), + ("user", report.getUser) + ) + + // Use more loggable format if value is null or empty + details.map { case (k, v) => + val newValue = Option(v).filter(_.nonEmpty).getOrElse("N/A") + s"\n\t $k: $newValue" + }.mkString("") + } + /** - * Submit an application to the ResourceManager and monitor its state. - * This continues until the application has exited for any reason. + * Submit an application to the ResourceManager. + * If set spark.yarn.submit.waitAppCompletion to true, it will stay alive + * reporting the application's status until the application has exited for any reason. + * Otherwise, the client process will exit after submission. * If the application finishes with a failed, killed, or undefined status, * throw an appropriate SparkException. */ def run(): Unit = { - val (yarnApplicationState, finalApplicationStatus) = monitorApplication(submitApplication()) - if (yarnApplicationState == YarnApplicationState.FAILED || - finalApplicationStatus == FinalApplicationStatus.FAILED) { - throw new SparkException("Application finished with failed status") - } - if (yarnApplicationState == YarnApplicationState.KILLED || - finalApplicationStatus == FinalApplicationStatus.KILLED) { - throw new SparkException("Application is killed") - } - if (finalApplicationStatus == FinalApplicationStatus.UNDEFINED) { - throw new SparkException("The final status of application is undefined") + val appId = submitApplication() + if (fireAndForget) { + val report = getApplicationReport(appId) + val state = report.getYarnApplicationState + logInfo(s"Application report for $appId (state: $state)") + logInfo(formatReportDetails(report)) + if (state == YarnApplicationState.FAILED || state == YarnApplicationState.KILLED) { + throw new SparkException(s"Application $appId finished with status: $state") + } + } else { + val (yarnApplicationState, finalApplicationStatus) = monitorApplication(appId) + if (yarnApplicationState == YarnApplicationState.FAILED || + finalApplicationStatus == FinalApplicationStatus.FAILED) { + throw new SparkException(s"Application $appId finished with failed status") + } + if (yarnApplicationState == YarnApplicationState.KILLED || + finalApplicationStatus == FinalApplicationStatus.KILLED) { + throw new SparkException(s"Application $appId is killed") + } + if (finalApplicationStatus == FinalApplicationStatus.UNDEFINED) { + throw new SparkException(s"The final status of application $appId is undefined") + } } } } diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala index c1d3f7320f53c..1ce10d906ab23 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala @@ -59,15 +59,15 @@ class ExecutorRunnable( val yarnConf: YarnConfiguration = new YarnConfiguration(conf) lazy val env = prepareEnvironment(container) - def run = { + override def run(): Unit = { logInfo("Starting Executor Container") nmClient = NMClient.createNMClient() nmClient.init(yarnConf) nmClient.start() - startContainer + startContainer() } - def startContainer = { + def startContainer(): java.util.Map[String, ByteBuffer] = { logInfo("Setting up ContainerLaunchContext") val ctx = Records.newRecord(classOf[ContainerLaunchContext]) diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala index c98763e15b58f..b8f42dadcb464 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala @@ -112,7 +112,7 @@ private[yarn] class YarnAllocator( SparkEnv.driverActorSystemName, sparkConf.get("spark.driver.host"), sparkConf.get("spark.driver.port"), - CoarseGrainedSchedulerBackend.ACTOR_NAME) + CoarseGrainedSchedulerBackend.ENDPOINT_NAME) // For testing private val launchContainers = sparkConf.getBoolean("spark.yarn.launchContainers", true) diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala index 0e37276ba724b..c06c0105670c0 100644 --- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala @@ -143,7 +143,8 @@ class YarnClusterSuite extends FunSuite with BeforeAndAfterAll with Matchers wit } } - test("run Python application in yarn-cluster mode") { + // Enable this once fix SPARK-6700 + ignore("run Python application in yarn-cluster mode") { val primaryPyFile = new File(tempDir, "test.py") Files.write(TEST_PYFILE, primaryPyFile, UTF_8) val pyFile = new File(tempDir, "test2.py")