diff --git a/LICENSE b/LICENSE
index f1732fb47afc0..3c667bf45059a 100644
--- a/LICENSE
+++ b/LICENSE
@@ -754,7 +754,7 @@ SUCH DAMAGE.
========================================================================
-For Timsort (core/src/main/java/org/apache/spark/util/collection/Sorter.java):
+For Timsort (core/src/main/java/org/apache/spark/util/collection/TimSort.java):
========================================================================
Copyright (C) 2008 The Android Open Source Project
@@ -771,6 +771,25 @@ See the License for the specific language governing permissions and
limitations under the License.
+========================================================================
+For LimitedInputStream
+ (network/common/src/main/java/org/apache/spark/network/util/LimitedInputStream.java):
+========================================================================
+Copyright (C) 2007 The Guava Authors
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+
+
========================================================================
BSD-style licenses
========================================================================
diff --git a/core/src/main/resources/org/apache/spark/ui/static/additional-metrics.js b/core/src/main/resources/org/apache/spark/ui/static/additional-metrics.js
index c5936b5038ac9..badd85ed48c82 100644
--- a/core/src/main/resources/org/apache/spark/ui/static/additional-metrics.js
+++ b/core/src/main/resources/org/apache/spark/ui/static/additional-metrics.js
@@ -39,6 +39,8 @@ $(function() {
var column = "table ." + $(this).attr("name");
$(column).hide();
});
+ // Stripe table rows after rows have been hidden to ensure correct striping.
+ stripeTables();
$("input:checkbox").click(function() {
var column = "table ." + $(this).attr("name");
diff --git a/core/src/main/resources/org/apache/spark/ui/static/table.js b/core/src/main/resources/org/apache/spark/ui/static/table.js
index 32187ba6e8df0..6bb03015abb51 100644
--- a/core/src/main/resources/org/apache/spark/ui/static/table.js
+++ b/core/src/main/resources/org/apache/spark/ui/static/table.js
@@ -28,8 +28,3 @@ function stripeTables() {
});
});
}
-
-/* Stripe all tables after pages finish loading. */
-$(function() {
- stripeTables();
-});
diff --git a/core/src/main/resources/org/apache/spark/ui/static/webui.css b/core/src/main/resources/org/apache/spark/ui/static/webui.css
index a2220e761ac98..db57712c83503 100644
--- a/core/src/main/resources/org/apache/spark/ui/static/webui.css
+++ b/core/src/main/resources/org/apache/spark/ui/static/webui.css
@@ -120,6 +120,20 @@ pre {
border: none;
}
+.stacktrace-details {
+ max-height: 300px;
+ overflow-y: auto;
+ margin: 0;
+ transition: max-height 0.5s ease-out, padding 0.5s ease-out;
+}
+
+.stacktrace-details.collapsed {
+ max-height: 0;
+ padding-top: 0;
+ padding-bottom: 0;
+ border: none;
+}
+
span.expand-additional-metrics {
cursor: pointer;
}
diff --git a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala
index c11f1db0064fd..ef93009a074e7 100644
--- a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala
+++ b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala
@@ -66,7 +66,6 @@ private[spark] class ExecutorAllocationManager(sc: SparkContext) extends Logging
// Lower and upper bounds on the number of executors. These are required.
private val minNumExecutors = conf.getInt("spark.dynamicAllocation.minExecutors", -1)
private val maxNumExecutors = conf.getInt("spark.dynamicAllocation.maxExecutors", -1)
- verifyBounds()
// How long there must be backlogged tasks for before an addition is triggered
private val schedulerBacklogTimeout = conf.getLong(
@@ -77,9 +76,14 @@ private[spark] class ExecutorAllocationManager(sc: SparkContext) extends Logging
"spark.dynamicAllocation.sustainedSchedulerBacklogTimeout", schedulerBacklogTimeout)
// How long an executor must be idle for before it is removed
- private val removeThresholdSeconds = conf.getLong(
+ private val executorIdleTimeout = conf.getLong(
"spark.dynamicAllocation.executorIdleTimeout", 600)
+ // During testing, the methods to actually kill and add executors are mocked out
+ private val testing = conf.getBoolean("spark.dynamicAllocation.testing", false)
+
+ validateSettings()
+
// Number of executors to add in the next round
private var numExecutorsToAdd = 1
@@ -103,17 +107,14 @@ private[spark] class ExecutorAllocationManager(sc: SparkContext) extends Logging
// Polling loop interval (ms)
private val intervalMillis: Long = 100
- // Whether we are testing this class. This should only be used internally.
- private val testing = conf.getBoolean("spark.dynamicAllocation.testing", false)
-
// Clock used to schedule when executors should be added and removed
private var clock: Clock = new RealClock
/**
- * Verify that the lower and upper bounds on the number of executors are valid.
+ * Verify that the settings specified through the config are valid.
* If not, throw an appropriate exception.
*/
- private def verifyBounds(): Unit = {
+ private def validateSettings(): Unit = {
if (minNumExecutors < 0 || maxNumExecutors < 0) {
throw new SparkException("spark.dynamicAllocation.{min/max}Executors must be set!")
}
@@ -124,6 +125,22 @@ private[spark] class ExecutorAllocationManager(sc: SparkContext) extends Logging
throw new SparkException(s"spark.dynamicAllocation.minExecutors ($minNumExecutors) must " +
s"be less than or equal to spark.dynamicAllocation.maxExecutors ($maxNumExecutors)!")
}
+ if (schedulerBacklogTimeout <= 0) {
+ throw new SparkException("spark.dynamicAllocation.schedulerBacklogTimeout must be > 0!")
+ }
+ if (sustainedSchedulerBacklogTimeout <= 0) {
+ throw new SparkException(
+ "spark.dynamicAllocation.sustainedSchedulerBacklogTimeout must be > 0!")
+ }
+ if (executorIdleTimeout <= 0) {
+ throw new SparkException("spark.dynamicAllocation.executorIdleTimeout must be > 0!")
+ }
+ // Require external shuffle service for dynamic allocation
+ // Otherwise, we may lose shuffle files when killing executors
+ if (!conf.getBoolean("spark.shuffle.service.enabled", false) && !testing) {
+ throw new SparkException("Dynamic allocation of executors requires the external " +
+ "shuffle service. You may enable this through spark.shuffle.service.enabled.")
+ }
}
/**
@@ -254,7 +271,7 @@ private[spark] class ExecutorAllocationManager(sc: SparkContext) extends Logging
val removeRequestAcknowledged = testing || sc.killExecutor(executorId)
if (removeRequestAcknowledged) {
logInfo(s"Removing executor $executorId because it has been idle for " +
- s"$removeThresholdSeconds seconds (new desired total will be ${numExistingExecutors - 1})")
+ s"$executorIdleTimeout seconds (new desired total will be ${numExistingExecutors - 1})")
executorsPendingToRemove.add(executorId)
true
} else {
@@ -329,8 +346,8 @@ private[spark] class ExecutorAllocationManager(sc: SparkContext) extends Logging
private def onExecutorIdle(executorId: String): Unit = synchronized {
if (!removeTimes.contains(executorId) && !executorsPendingToRemove.contains(executorId)) {
logDebug(s"Starting idle timer for $executorId because there are no more tasks " +
- s"scheduled to run on the executor (to expire in $removeThresholdSeconds seconds)")
- removeTimes(executorId) = clock.getTimeMillis + removeThresholdSeconds * 1000
+ s"scheduled to run on the executor (to expire in $executorIdleTimeout seconds)")
+ removeTimes(executorId) = clock.getTimeMillis + executorIdleTimeout * 1000
}
}
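
> Note: the new `validateSettings()` above rejects non-positive timeouts and, unless `spark.dynamicAllocation.testing` is set, refuses to start without the external shuffle service, since shuffle files served directly by executors would be lost when those executors are killed. A minimal sketch of a configuration that passes these checks, assuming the standard `spark.dynamicAllocation.enabled` switch and illustrative timeout values:

```scala
import org.apache.spark.{SparkConf, SparkContext}

// Illustrative values only; every timeout below must be > 0 and
// min <= max, or validateSettings() throws a SparkException.
val conf = new SparkConf()
  .setAppName("dynamic-allocation-sketch")
  .set("spark.dynamicAllocation.enabled", "true")
  .set("spark.dynamicAllocation.minExecutors", "1")
  .set("spark.dynamicAllocation.maxExecutors", "20")
  .set("spark.dynamicAllocation.schedulerBacklogTimeout", "5")    // seconds
  .set("spark.dynamicAllocation.executorIdleTimeout", "600")      // seconds
  .set("spark.shuffle.service.enabled", "true")                   // required
val sc = new SparkContext(conf)
```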
diff --git a/core/src/main/scala/org/apache/spark/SecurityManager.scala b/core/src/main/scala/org/apache/spark/SecurityManager.scala
index 0e0f1a7b2377e..dbff9d12b5ad7 100644
--- a/core/src/main/scala/org/apache/spark/SecurityManager.scala
+++ b/core/src/main/scala/org/apache/spark/SecurityManager.scala
@@ -22,6 +22,7 @@ import java.net.{Authenticator, PasswordAuthentication}
import org.apache.hadoop.io.Text
import org.apache.spark.deploy.SparkHadoopUtil
+import org.apache.spark.network.sasl.SecretKeyHolder
/**
* Spark class responsible for security.
@@ -84,7 +85,7 @@ import org.apache.spark.deploy.SparkHadoopUtil
 * Authenticator installed in the SecurityManager that handles the authentication
* and in this case gets the user name and password from the request.
*
- * - ConnectionManager -> The Spark ConnectionManager uses java nio to asynchronously
+ * - BlockTransferService -> The Spark BlockTransferServices use Java NIO to asynchronously
* exchange messages. For this we use the Java SASL
* (Simple Authentication and Security Layer) API and again use DIGEST-MD5
* as the authentication mechanism. This means the shared secret is not passed
@@ -98,7 +99,7 @@ import org.apache.spark.deploy.SparkHadoopUtil
* of protection they want. If we support those, the messages will also have to
* be wrapped and unwrapped via the SaslServer/SaslClient.wrap/unwrap API's.
*
- * Since the connectionManager does asynchronous messages passing, the SASL
+ * Since the NioBlockTransferService does asynchronous message passing, the SASL
* authentication is a bit more complex. A ConnectionManager can be both a client
 * and a Server, so for a particular connection it has to determine what to do.
* A ConnectionId was added to be able to track connections and is used to
@@ -107,6 +108,10 @@ import org.apache.spark.deploy.SparkHadoopUtil
* and waits for the response from the server and does the handshake before sending
* the real message.
*
+ * The NettyBlockTransferService ensures that SASL authentication is performed
+ * synchronously prior to any other communication on a connection. This is done in
+ * SaslClientBootstrap on the client side and SaslRpcHandler on the server side.
+ *
* - HTTP for the Spark UI -> the UI was changed to use servlets so that javax servlet filters
* can be used. Yarn requires a specific AmIpFilter be installed for security to work
* properly. For non-Yarn deployments, users can write a filter to go through a
@@ -139,7 +144,7 @@ import org.apache.spark.deploy.SparkHadoopUtil
* can take place.
*/
-private[spark] class SecurityManager(sparkConf: SparkConf) extends Logging {
+private[spark] class SecurityManager(sparkConf: SparkConf) extends Logging with SecretKeyHolder {
// key used to store the spark secret in the Hadoop UGI
private val sparkSecretLookupKey = "sparkCookie"
@@ -337,4 +342,8 @@ private[spark] class SecurityManager(sparkConf: SparkConf) extends Logging {
* @return the secret key as a String if authentication is enabled, otherwise returns null
*/
def getSecretKey(): String = secretKey
+
+ // Default SecurityManager only has a single secret key, so ignore appId.
+ override def getSaslUser(appId: String): String = getSaslUser()
+ override def getSecretKey(appId: String): String = getSecretKey()
}
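
> Note: `SecurityManager` now doubles as the network module's `SecretKeyHolder`, which is how the Netty SASL bootstraps look up credentials; since a driver or executor only ever holds one secret, the `appId` argument is ignored. A long-running server such as the external shuffle service needs one secret per application instead. A hedged sketch of what a per-application holder could look like — the class name and registry methods are hypothetical, only the `SecretKeyHolder` interface is from this patch:

```scala
import java.util.concurrent.ConcurrentHashMap
import org.apache.spark.network.sasl.SecretKeyHolder

// Hypothetical holder keyed by application id, e.g. for a service
// that authenticates many applications concurrently.
class PerAppSecretKeyHolder extends SecretKeyHolder {
  private val secrets = new ConcurrentHashMap[String, String]()

  def register(appId: String, secret: String): Unit = secrets.put(appId, secret)
  def unregister(appId: String): Unit = secrets.remove(appId)

  // Reuse the app id as the SASL user name for simplicity.
  override def getSaslUser(appId: String): String = appId

  // Returns null for unknown apps, mirroring SecurityManager's
  // "null means no key" convention.
  override def getSecretKey(appId: String): String = secrets.get(appId)
}
```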
diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala
index ad0a9017afead..4c6c86c7bad78 100644
--- a/core/src/main/scala/org/apache/spark/SparkConf.scala
+++ b/core/src/main/scala/org/apache/spark/SparkConf.scala
@@ -217,6 +217,12 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging {
*/
getAll.filter { case (k, _) => isAkkaConf(k) }
+ /**
+ * Returns the Spark application id, valid in the Driver after TaskScheduler registration and
+ * from the start in the Executor.
+ */
+ def getAppId: String = get("spark.app.id")
+
/** Does the configuration contain a given parameter? */
def contains(key: String): Boolean = settings.contains(key)
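
> Note: `getAppId` is a thin accessor over `spark.app.id` and throws if the key is unset, so on the driver it is only safe to call once the `SparkContext` is fully constructed. A short usage sketch; the printed id format is standalone-master style and purely illustrative:

```scala
import org.apache.spark.{SparkConf, SparkContext}

val sc = new SparkContext(new SparkConf().setAppName("app-id-sketch").setMaster("local"))
// Safe here: the TaskScheduler has registered and spark.app.id is set.
val appId = sc.getConf.getAppId   // e.g. "app-20141103123456-0001"
println(s"Running as $appId")
```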
diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index 8b4db783979ec..03ea672c813d1 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -21,9 +21,8 @@ import scala.language.implicitConversions
import java.io._
import java.net.URI
-import java.util.Arrays
+import java.util.{Arrays, Properties, UUID}
import java.util.concurrent.atomic.AtomicInteger
-import java.util.{Properties, UUID}
import java.util.UUID.randomUUID
import scala.collection.{Map, Set}
import scala.collection.generic.Growable
@@ -41,6 +40,7 @@ import akka.actor.Props
import org.apache.spark.annotation.{DeveloperApi, Experimental}
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.deploy.{LocalSparkCluster, SparkHadoopUtil}
+import org.apache.spark.executor.TriggerThreadDump
import org.apache.spark.input.{StreamInputFormat, PortableDataStream, WholeTextFileInputFormat, FixedLengthBinaryInputFormat}
import org.apache.spark.partial.{ApproximateEvaluator, PartialResult}
import org.apache.spark.rdd._
@@ -51,7 +51,7 @@ import org.apache.spark.scheduler.local.LocalBackend
import org.apache.spark.storage._
import org.apache.spark.ui.SparkUI
import org.apache.spark.ui.jobs.JobProgressListener
-import org.apache.spark.util.{CallSite, ClosureCleaner, MetadataCleaner, MetadataCleanerType, TimeStampedWeakValueHashMap, Utils}
+import org.apache.spark.util._
/**
* Main entry point for Spark functionality. A SparkContext represents the connection to a Spark
@@ -313,6 +313,8 @@ class SparkContext(config: SparkConf) extends SparkStatusAPI with Logging {
val applicationId: String = taskScheduler.applicationId()
conf.set("spark.app.id", applicationId)
+ env.blockManager.initialize(applicationId)
+
val metricsSystem = env.metricsSystem
// The metrics system for Driver need to be set spark.app.id to app ID.
@@ -361,6 +363,29 @@ class SparkContext(config: SparkConf) extends SparkStatusAPI with Logging {
override protected def childValue(parent: Properties): Properties = new Properties(parent)
}
+ /**
+ * Called by the web UI to obtain executor thread dumps. This method may be expensive.
+ * Logs an error and returns None if we failed to obtain a thread dump, which could occur due
+ * to an executor being dead or unresponsive or due to network issues while sending the thread
+ * dump message back to the driver.
+ */
+ private[spark] def getExecutorThreadDump(executorId: String): Option[Array[ThreadStackTrace]] = {
+ try {
+ if (executorId == SparkContext.DRIVER_IDENTIFIER) {
+ Some(Utils.getThreadDump())
+ } else {
+ val (host, port) = env.blockManager.master.getActorSystemHostPortForExecutor(executorId).get
+ val actorRef = AkkaUtils.makeExecutorRef("ExecutorActor", conf, host, port, env.actorSystem)
+ Some(AkkaUtils.askWithReply[Array[ThreadStackTrace]](TriggerThreadDump, actorRef,
+ AkkaUtils.numRetries(conf), AkkaUtils.retryWaitMs(conf), AkkaUtils.askTimeout(conf)))
+ }
+ } catch {
+ case e: Exception =>
+ logError(s"Exception getting thread dump from executor $executorId", e)
+ None
+ }
+ }
+
private[spark] def getLocalProperties: Properties = localProperties.get()
private[spark] def setLocalProperties(props: Properties) {
@@ -535,6 +560,8 @@ class SparkContext(config: SparkConf) extends SparkStatusAPI with Logging {
/**
+ * :: Experimental ::
+ *
* Get an RDD for a Hadoop-readable dataset as PortableDataStream for each file
* (useful for binary data)
*
@@ -577,6 +604,8 @@ class SparkContext(config: SparkConf) extends SparkStatusAPI with Logging {
}
/**
+ * :: Experimental ::
+ *
* Load data from a flat binary file, assuming the length of each record is constant.
*
* @param path Directory to the input data files
diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala
index e2f13accdfab5..e7454beddbfd0 100644
--- a/core/src/main/scala/org/apache/spark/SparkEnv.scala
+++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala
@@ -276,7 +276,7 @@ object SparkEnv extends Logging {
val blockTransferService =
conf.get("spark.shuffle.blockTransferService", "netty").toLowerCase match {
case "netty" =>
- new NettyBlockTransferService(conf)
+ new NettyBlockTransferService(conf, securityManager)
case "nio" =>
new NioBlockTransferService(conf, securityManager)
}
@@ -285,8 +285,9 @@ object SparkEnv extends Logging {
"BlockManagerMaster",
new BlockManagerMasterActor(isLocal, conf, listenerBus)), conf, isDriver)
+ // NB: blockManager is not valid until initialize() is called later.
val blockManager = new BlockManager(executorId, actorSystem, blockManagerMaster,
- serializer, conf, mapOutputTracker, shuffleManager, blockTransferService)
+ serializer, conf, mapOutputTracker, shuffleManager, blockTransferService, securityManager)
val broadcastManager = new BroadcastManager(isDriver, conf, securityManager)
diff --git a/core/src/main/scala/org/apache/spark/SparkSaslClient.scala b/core/src/main/scala/org/apache/spark/SparkSaslClient.scala
deleted file mode 100644
index a954fcc0c31fa..0000000000000
--- a/core/src/main/scala/org/apache/spark/SparkSaslClient.scala
+++ /dev/null
@@ -1,147 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark
-
-import javax.security.auth.callback.Callback
-import javax.security.auth.callback.CallbackHandler
-import javax.security.auth.callback.NameCallback
-import javax.security.auth.callback.PasswordCallback
-import javax.security.auth.callback.UnsupportedCallbackException
-import javax.security.sasl.RealmCallback
-import javax.security.sasl.RealmChoiceCallback
-import javax.security.sasl.Sasl
-import javax.security.sasl.SaslClient
-import javax.security.sasl.SaslException
-
-import scala.collection.JavaConversions.mapAsJavaMap
-
-import com.google.common.base.Charsets.UTF_8
-
-/**
- * Implements SASL Client logic for Spark
- */
-private[spark] class SparkSaslClient(securityMgr: SecurityManager) extends Logging {
-
- /**
- * Used to respond to server's counterpart, SaslServer with SASL tokens
- * represented as byte arrays.
- *
- * The authentication mechanism used here is DIGEST-MD5. This could be changed to be
- * configurable in the future.
- */
- private var saslClient: SaslClient = Sasl.createSaslClient(Array[String](SparkSaslServer.DIGEST),
- null, null, SparkSaslServer.SASL_DEFAULT_REALM, SparkSaslServer.SASL_PROPS,
- new SparkSaslClientCallbackHandler(securityMgr))
-
- /**
- * Used to initiate SASL handshake with server.
- * @return response to challenge if needed
- */
- def firstToken(): Array[Byte] = {
- synchronized {
- val saslToken: Array[Byte] =
- if (saslClient != null && saslClient.hasInitialResponse()) {
- logDebug("has initial response")
- saslClient.evaluateChallenge(new Array[Byte](0))
- } else {
- new Array[Byte](0)
- }
- saslToken
- }
- }
-
- /**
- * Determines whether the authentication exchange has completed.
- * @return true is complete, otherwise false
- */
- def isComplete(): Boolean = {
- synchronized {
- if (saslClient != null) saslClient.isComplete() else false
- }
- }
-
- /**
- * Respond to server's SASL token.
- * @param saslTokenMessage contains server's SASL token
- * @return client's response SASL token
- */
- def saslResponse(saslTokenMessage: Array[Byte]): Array[Byte] = {
- synchronized {
- if (saslClient != null) saslClient.evaluateChallenge(saslTokenMessage) else new Array[Byte](0)
- }
- }
-
- /**
- * Disposes of any system resources or security-sensitive information the
- * SaslClient might be using.
- */
- def dispose() {
- synchronized {
- if (saslClient != null) {
- try {
- saslClient.dispose()
- } catch {
- case e: SaslException => // ignored
- } finally {
- saslClient = null
- }
- }
- }
- }
-
- /**
- * Implementation of javax.security.auth.callback.CallbackHandler
- * that works with share secrets.
- */
- private class SparkSaslClientCallbackHandler(securityMgr: SecurityManager) extends
- CallbackHandler {
-
- private val userName: String =
- SparkSaslServer.encodeIdentifier(securityMgr.getSaslUser().getBytes(UTF_8))
- private val secretKey = securityMgr.getSecretKey()
- private val userPassword: Array[Char] = SparkSaslServer.encodePassword(
- if (secretKey != null) secretKey.getBytes(UTF_8) else "".getBytes(UTF_8))
-
- /**
- * Implementation used to respond to SASL request from the server.
- *
- * @param callbacks objects that indicate what credential information the
- * server's SaslServer requires from the client.
- */
- override def handle(callbacks: Array[Callback]) {
- logDebug("in the sasl client callback handler")
- callbacks foreach {
- case nc: NameCallback => {
- logDebug("handle: SASL client callback: setting username: " + userName)
- nc.setName(userName)
- }
- case pc: PasswordCallback => {
- logDebug("handle: SASL client callback: setting userPassword")
- pc.setPassword(userPassword)
- }
- case rc: RealmCallback => {
- logDebug("handle: SASL client callback: setting realm: " + rc.getDefaultText())
- rc.setText(rc.getDefaultText())
- }
- case cb: RealmChoiceCallback => {}
- case cb: Callback => throw
- new UnsupportedCallbackException(cb, "handle: Unrecognized SASL client callback")
- }
- }
- }
-}
diff --git a/core/src/main/scala/org/apache/spark/SparkSaslServer.scala b/core/src/main/scala/org/apache/spark/SparkSaslServer.scala
deleted file mode 100644
index 7c2afb364661f..0000000000000
--- a/core/src/main/scala/org/apache/spark/SparkSaslServer.scala
+++ /dev/null
@@ -1,176 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark
-
-import javax.security.auth.callback.Callback
-import javax.security.auth.callback.CallbackHandler
-import javax.security.auth.callback.NameCallback
-import javax.security.auth.callback.PasswordCallback
-import javax.security.auth.callback.UnsupportedCallbackException
-import javax.security.sasl.AuthorizeCallback
-import javax.security.sasl.RealmCallback
-import javax.security.sasl.Sasl
-import javax.security.sasl.SaslException
-import javax.security.sasl.SaslServer
-import scala.collection.JavaConversions.mapAsJavaMap
-
-import com.google.common.base.Charsets.UTF_8
-import org.apache.commons.net.util.Base64
-
-/**
- * Encapsulates SASL server logic
- */
-private[spark] class SparkSaslServer(securityMgr: SecurityManager) extends Logging {
-
- /**
- * Actual SASL work done by this object from javax.security.sasl.
- */
- private var saslServer: SaslServer = Sasl.createSaslServer(SparkSaslServer.DIGEST, null,
- SparkSaslServer.SASL_DEFAULT_REALM, SparkSaslServer.SASL_PROPS,
- new SparkSaslDigestCallbackHandler(securityMgr))
-
- /**
- * Determines whether the authentication exchange has completed.
- * @return true is complete, otherwise false
- */
- def isComplete(): Boolean = {
- synchronized {
- if (saslServer != null) saslServer.isComplete() else false
- }
- }
-
- /**
- * Used to respond to server SASL tokens.
- * @param token Server's SASL token
- * @return response to send back to the server.
- */
- def response(token: Array[Byte]): Array[Byte] = {
- synchronized {
- if (saslServer != null) saslServer.evaluateResponse(token) else new Array[Byte](0)
- }
- }
-
- /**
- * Disposes of any system resources or security-sensitive information the
- * SaslServer might be using.
- */
- def dispose() {
- synchronized {
- if (saslServer != null) {
- try {
- saslServer.dispose()
- } catch {
- case e: SaslException => // ignore
- } finally {
- saslServer = null
- }
- }
- }
- }
-
- /**
- * Implementation of javax.security.auth.callback.CallbackHandler
- * for SASL DIGEST-MD5 mechanism
- */
- private class SparkSaslDigestCallbackHandler(securityMgr: SecurityManager)
- extends CallbackHandler {
-
- private val userName: String =
- SparkSaslServer.encodeIdentifier(securityMgr.getSaslUser().getBytes(UTF_8))
-
- override def handle(callbacks: Array[Callback]) {
- logDebug("In the sasl server callback handler")
- callbacks foreach {
- case nc: NameCallback => {
- logDebug("handle: SASL server callback: setting username")
- nc.setName(userName)
- }
- case pc: PasswordCallback => {
- logDebug("handle: SASL server callback: setting userPassword")
- val password: Array[Char] =
- SparkSaslServer.encodePassword(securityMgr.getSecretKey().getBytes(UTF_8))
- pc.setPassword(password)
- }
- case rc: RealmCallback => {
- logDebug("handle: SASL server callback: setting realm: " + rc.getDefaultText())
- rc.setText(rc.getDefaultText())
- }
- case ac: AuthorizeCallback => {
- val authid = ac.getAuthenticationID()
- val authzid = ac.getAuthorizationID()
- if (authid.equals(authzid)) {
- logDebug("set auth to true")
- ac.setAuthorized(true)
- } else {
- logDebug("set auth to false")
- ac.setAuthorized(false)
- }
- if (ac.isAuthorized()) {
- logDebug("sasl server is authorized")
- ac.setAuthorizedID(authzid)
- }
- }
- case cb: Callback => throw
- new UnsupportedCallbackException(cb, "handle: Unrecognized SASL DIGEST-MD5 Callback")
- }
- }
- }
-}
-
-private[spark] object SparkSaslServer {
-
- /**
- * This is passed as the server name when creating the sasl client/server.
- * This could be changed to be configurable in the future.
- */
- val SASL_DEFAULT_REALM = "default"
-
- /**
- * The authentication mechanism used here is DIGEST-MD5. This could be changed to be
- * configurable in the future.
- */
- val DIGEST = "DIGEST-MD5"
-
- /**
- * The quality of protection is just "auth". This means that we are doing
- * authentication only, we are not supporting integrity or privacy protection of the
- * communication channel after authentication. This could be changed to be configurable
- * in the future.
- */
- val SASL_PROPS = Map(Sasl.QOP -> "auth", Sasl.SERVER_AUTH ->"true")
-
- /**
- * Encode a byte[] identifier as a Base64-encoded string.
- *
- * @param identifier identifier to encode
- * @return Base64-encoded string
- */
- def encodeIdentifier(identifier: Array[Byte]): String = {
- new String(Base64.encodeBase64(identifier), UTF_8)
- }
-
- /**
- * Encode a password as a base64-encoded char[] array.
- * @param password as a byte array.
- * @return password as a char array.
- */
- def encodePassword(password: Array[Byte]): Array[Char] = {
- new String(Base64.encodeBase64(password), UTF_8).toCharArray()
- }
-}
-
diff --git a/core/src/main/scala/org/apache/spark/TaskEndReason.scala b/core/src/main/scala/org/apache/spark/TaskEndReason.scala
index 202fba699ab26..af5fd8e0ac00c 100644
--- a/core/src/main/scala/org/apache/spark/TaskEndReason.scala
+++ b/core/src/main/scala/org/apache/spark/TaskEndReason.scala
@@ -69,11 +69,13 @@ case class FetchFailed(
bmAddress: BlockManagerId, // Note that bmAddress can be null
shuffleId: Int,
mapId: Int,
- reduceId: Int)
+ reduceId: Int,
+ message: String)
extends TaskFailedReason {
override def toErrorString: String = {
val bmAddressString = if (bmAddress == null) "null" else bmAddress.toString
- s"FetchFailed($bmAddressString, shuffleId=$shuffleId, mapId=$mapId, reduceId=$reduceId)"
+ s"FetchFailed($bmAddressString, shuffleId=$shuffleId, mapId=$mapId, reduceId=$reduceId, " +
+ s"message=\n$message\n)"
}
}
@@ -81,15 +83,48 @@ case class FetchFailed(
* :: DeveloperApi ::
* Task failed due to a runtime exception. This is the most common failure case and also captures
* user program exceptions.
+ *
+ * `stackTrace` contains the stack trace of the exception itself. It still exists for backward
+ * compatibility. It's better to use `this(e: Throwable, metrics: Option[TaskMetrics])` to
+ * create `ExceptionFailure` as it will handle the backward compatibility properly.
+ *
+ * `fullStackTrace` is a better representation of the stack trace because it contains the whole
+ * stack trace including the exception and its causes.
*/
@DeveloperApi
case class ExceptionFailure(
className: String,
description: String,
stackTrace: Array[StackTraceElement],
+ fullStackTrace: String,
metrics: Option[TaskMetrics])
extends TaskFailedReason {
- override def toErrorString: String = Utils.exceptionString(className, description, stackTrace)
+
+ private[spark] def this(e: Throwable, metrics: Option[TaskMetrics]) {
+ this(e.getClass.getName, e.getMessage, e.getStackTrace, Utils.exceptionString(e), metrics)
+ }
+
+ override def toErrorString: String =
+ if (fullStackTrace == null) {
+ // fullStackTrace is added in 1.2.0
+ // If fullStackTrace is null, use the old error string for backward compatibility
+ exceptionString(className, description, stackTrace)
+ } else {
+ fullStackTrace
+ }
+
+ /**
+ * Return a nice string representation of the exception, including the stack trace.
+ * Note: It does not include the exception's causes, and is only used for backward compatibility.
+ */
+ private def exceptionString(
+ className: String,
+ description: String,
+ stackTrace: Array[StackTraceElement]): String = {
+ val desc = if (description == null) "" else description
+ val st = if (stackTrace == null) "" else stackTrace.map(" " + _).mkString("\n")
+ s"$className: $desc\n$st"
+ }
}
/**
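
> Note: the new auxiliary constructor is what call sites inside Spark now use (see the `Executor` change below), so `fullStackTrace` is populated from `Utils.exceptionString(e)` and captures the full cause chain, while events logged by pre-1.2 versions deserialize with `fullStackTrace == null` and fall back to the legacy format. A sketch of the difference; the constructor is `private[spark]` and is shown here only to illustrate the data flow:

```scala
// Inside Spark (the auxiliary constructor is private[spark]):
val t = new IllegalStateException("boom", new RuntimeException("root cause"))
val reason = new ExceptionFailure(t, metrics = None)
// toErrorString now prints the whole chain, including the
// "Caused by: java.lang.RuntimeException: root cause" section that the
// legacy stackTrace array could not represent.
println(reason.toErrorString)
```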
diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala
index e3aeba7e6c39d..5c6e8d32c5c8a 100644
--- a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala
@@ -21,11 +21,6 @@ import java.io.Closeable
import java.util
import java.util.{Map => JMap}
-import java.io.DataInputStream
-
-import org.apache.hadoop.io.{BytesWritable, LongWritable}
-import org.apache.spark.input.{PortableDataStream, FixedLengthBinaryInputFormat}
-
import scala.collection.JavaConversions
import scala.collection.JavaConversions._
import scala.language.implicitConversions
@@ -33,6 +28,7 @@ import scala.reflect.ClassTag
import com.google.common.base.Optional
import org.apache.hadoop.conf.Configuration
+import org.apache.spark.input.PortableDataStream
import org.apache.hadoop.mapred.{InputFormat, JobConf}
import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat}
@@ -286,6 +282,8 @@ class JavaSparkContext(val sc: SparkContext)
new JavaPairRDD(sc.binaryFiles(path, minPartitions))
/**
+ * :: Experimental ::
+ *
* Read a directory of binary files from HDFS, a local file system (available on all nodes),
* or any Hadoop-supported file system URI as a byte array. Each file is read as a single
* record and returned in a key-value pair, where the key is the path of each file,
@@ -312,15 +310,19 @@ class JavaSparkContext(val sc: SparkContext)
*
 * @note Small files are preferred; very large files may cause bad performance.
*/
+ @Experimental
def binaryFiles(path: String): JavaPairRDD[String, PortableDataStream] =
new JavaPairRDD(sc.binaryFiles(path, defaultMinPartitions))
/**
+ * :: Experimental ::
+ *
* Load data from a flat binary file, assuming the length of each record is constant.
*
* @param path Directory to the input data files
* @return An RDD of data with values, represented as byte arrays
*/
+ @Experimental
def binaryRecords(path: String, recordLength: Int): JavaRDD[Array[Byte]] = {
new JavaRDD(sc.binaryRecords(path, recordLength))
}
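
> Note: both binary APIs are now marked `@Experimental` in the Java wrapper to match the Scala side. For reference, a short sketch of the fixed-length record API on the Scala `SparkContext`; the path and record length are illustrative:

```scala
// Every element is one complete 512-byte record; this relies on
// FixedLengthBinaryRecordReader filling its buffer fully (see the
// readFully fix further down in this diff).
val records = sc.binaryRecords("hdfs:///data/fixed-width/", 512)
records.take(1).foreach(r => assert(r.length == 512))
```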
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonHadoopUtil.scala b/core/src/main/scala/org/apache/spark/api/python/PythonHadoopUtil.scala
index 49dc95f349eac..5ba66178e2b78 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonHadoopUtil.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonHadoopUtil.scala
@@ -61,8 +61,7 @@ private[python] object Converter extends Logging {
* Other objects are passed through without conversion.
*/
private[python] class WritableToJavaConverter(
- conf: Broadcast[SerializableWritable[Configuration]],
- batchSize: Int) extends Converter[Any, Any] {
+ conf: Broadcast[SerializableWritable[Configuration]]) extends Converter[Any, Any] {
/**
* Converts a [[org.apache.hadoop.io.Writable]] to the underlying primitive, String or
@@ -94,8 +93,7 @@ private[python] class WritableToJavaConverter(
map.put(convertWritable(k), convertWritable(v))
}
map
- case w: Writable =>
- if (batchSize > 1) WritableUtils.clone(w, conf.value.value) else w
+ case w: Writable => WritableUtils.clone(w, conf.value.value)
case other => other
}
}
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index 61b125ef7c6c1..45beb8fc8c925 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -21,13 +21,13 @@ import java.io._
import java.net._
import java.util.{List => JList, ArrayList => JArrayList, Map => JMap, Collections}
+import org.apache.spark.input.PortableDataStream
+
import scala.collection.JavaConversions._
-import scala.collection.JavaConverters._
import scala.collection.mutable
import scala.language.existentials
import com.google.common.base.Charsets.UTF_8
-import net.razorvine.pickle.{Pickler, Unpickler}
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.io.compress.CompressionCodec
@@ -397,22 +397,33 @@ private[spark] object PythonRDD extends Logging {
newIter.asInstanceOf[Iterator[String]].foreach { str =>
writeUTF(str, dataOut)
}
- case pair: Tuple2[_, _] =>
- pair._1 match {
- case bytePair: Array[Byte] =>
- newIter.asInstanceOf[Iterator[Tuple2[Array[Byte], Array[Byte]]]].foreach { pair =>
- dataOut.writeInt(pair._1.length)
- dataOut.write(pair._1)
- dataOut.writeInt(pair._2.length)
- dataOut.write(pair._2)
- }
- case stringPair: String =>
- newIter.asInstanceOf[Iterator[Tuple2[String, String]]].foreach { pair =>
- writeUTF(pair._1, dataOut)
- writeUTF(pair._2, dataOut)
- }
- case other =>
- throw new SparkException("Unexpected Tuple2 element type " + pair._1.getClass)
+ case stream: PortableDataStream =>
+ newIter.asInstanceOf[Iterator[PortableDataStream]].foreach { stream =>
+ val bytes = stream.toArray()
+ dataOut.writeInt(bytes.length)
+ dataOut.write(bytes)
+ }
+ case (key: String, stream: PortableDataStream) =>
+ newIter.asInstanceOf[Iterator[(String, PortableDataStream)]].foreach {
+ case (key, stream) =>
+ writeUTF(key, dataOut)
+ val bytes = stream.toArray()
+ dataOut.writeInt(bytes.length)
+ dataOut.write(bytes)
+ }
+ case (key: String, value: String) =>
+ newIter.asInstanceOf[Iterator[(String, String)]].foreach {
+ case (key, value) =>
+ writeUTF(key, dataOut)
+ writeUTF(value, dataOut)
+ }
+ case (key: Array[Byte], value: Array[Byte]) =>
+ newIter.asInstanceOf[Iterator[(Array[Byte], Array[Byte])]].foreach {
+ case (key, value) =>
+ dataOut.writeInt(key.length)
+ dataOut.write(key)
+ dataOut.writeInt(value.length)
+ dataOut.write(value)
}
case other =>
throw new SparkException("Unexpected element type " + first.getClass)
@@ -442,7 +453,7 @@ private[spark] object PythonRDD extends Logging {
val rdd = sc.sc.sequenceFile[K, V](path, kc, vc, minSplits)
val confBroadcasted = sc.sc.broadcast(new SerializableWritable(sc.hadoopConfiguration()))
val converted = convertRDD(rdd, keyConverterClass, valueConverterClass,
- new WritableToJavaConverter(confBroadcasted, batchSize))
+ new WritableToJavaConverter(confBroadcasted))
JavaRDD.fromRDD(SerDeUtil.pairRDDToPython(converted, batchSize))
}
@@ -468,7 +479,7 @@ private[spark] object PythonRDD extends Logging {
Some(path), inputFormatClass, keyClass, valueClass, mergedConf)
val confBroadcasted = sc.sc.broadcast(new SerializableWritable(mergedConf))
val converted = convertRDD(rdd, keyConverterClass, valueConverterClass,
- new WritableToJavaConverter(confBroadcasted, batchSize))
+ new WritableToJavaConverter(confBroadcasted))
JavaRDD.fromRDD(SerDeUtil.pairRDDToPython(converted, batchSize))
}
@@ -494,7 +505,7 @@ private[spark] object PythonRDD extends Logging {
None, inputFormatClass, keyClass, valueClass, conf)
val confBroadcasted = sc.sc.broadcast(new SerializableWritable(conf))
val converted = convertRDD(rdd, keyConverterClass, valueConverterClass,
- new WritableToJavaConverter(confBroadcasted, batchSize))
+ new WritableToJavaConverter(confBroadcasted))
JavaRDD.fromRDD(SerDeUtil.pairRDDToPython(converted, batchSize))
}
@@ -537,7 +548,7 @@ private[spark] object PythonRDD extends Logging {
Some(path), inputFormatClass, keyClass, valueClass, mergedConf)
val confBroadcasted = sc.sc.broadcast(new SerializableWritable(mergedConf))
val converted = convertRDD(rdd, keyConverterClass, valueConverterClass,
- new WritableToJavaConverter(confBroadcasted, batchSize))
+ new WritableToJavaConverter(confBroadcasted))
JavaRDD.fromRDD(SerDeUtil.pairRDDToPython(converted, batchSize))
}
@@ -563,7 +574,7 @@ private[spark] object PythonRDD extends Logging {
None, inputFormatClass, keyClass, valueClass, conf)
val confBroadcasted = sc.sc.broadcast(new SerializableWritable(conf))
val converted = convertRDD(rdd, keyConverterClass, valueConverterClass,
- new WritableToJavaConverter(confBroadcasted, batchSize))
+ new WritableToJavaConverter(confBroadcasted))
JavaRDD.fromRDD(SerDeUtil.pairRDDToPython(converted, batchSize))
}
@@ -746,104 +757,6 @@ private[spark] object PythonRDD extends Logging {
converted.saveAsHadoopDataset(new JobConf(conf))
}
}
-
-
- /**
- * Convert an RDD of serialized Python dictionaries to Scala Maps (no recursive conversions).
- */
- @deprecated("PySpark does not use it anymore", "1.1")
- def pythonToJavaMap(pyRDD: JavaRDD[Array[Byte]]): JavaRDD[Map[String, _]] = {
- pyRDD.rdd.mapPartitions { iter =>
- val unpickle = new Unpickler
- SerDeUtil.initialize()
- iter.flatMap { row =>
- unpickle.loads(row) match {
- // in case of objects are pickled in batch mode
- case objs: JArrayList[JMap[String, _] @unchecked] => objs.map(_.toMap)
- // not in batch mode
- case obj: JMap[String @unchecked, _] => Seq(obj.toMap)
- }
- }
- }
- }
-
- /**
- * Convert an RDD of serialized Python tuple to Array (no recursive conversions).
- * It is only used by pyspark.sql.
- */
- def pythonToJavaArray(pyRDD: JavaRDD[Array[Byte]], batched: Boolean): JavaRDD[Array[_]] = {
-
- def toArray(obj: Any): Array[_] = {
- obj match {
- case objs: JArrayList[_] =>
- objs.toArray
- case obj if obj.getClass.isArray =>
- obj.asInstanceOf[Array[_]].toArray
- }
- }
-
- pyRDD.rdd.mapPartitions { iter =>
- val unpickle = new Unpickler
- iter.flatMap { row =>
- val obj = unpickle.loads(row)
- if (batched) {
- obj.asInstanceOf[JArrayList[_]].map(toArray)
- } else {
- Seq(toArray(obj))
- }
- }
- }.toJavaRDD()
- }
-
- private[spark] class AutoBatchedPickler(iter: Iterator[Any]) extends Iterator[Array[Byte]] {
- private val pickle = new Pickler()
- private var batch = 1
- private val buffer = new mutable.ArrayBuffer[Any]
-
- override def hasNext(): Boolean = iter.hasNext
-
- override def next(): Array[Byte] = {
- while (iter.hasNext && buffer.length < batch) {
- buffer += iter.next()
- }
- val bytes = pickle.dumps(buffer.toArray)
- val size = bytes.length
- // let 1M < size < 10M
- if (size < 1024 * 1024) {
- batch *= 2
- } else if (size > 1024 * 1024 * 10 && batch > 1) {
- batch /= 2
- }
- buffer.clear()
- bytes
- }
- }
-
- /**
- * Convert an RDD of Java objects to an RDD of serialized Python objects, that is usable by
- * PySpark.
- */
- def javaToPython(jRDD: JavaRDD[Any]): JavaRDD[Array[Byte]] = {
- jRDD.rdd.mapPartitions { iter => new AutoBatchedPickler(iter) }
- }
-
- /**
- * Convert an RDD of serialized Python objects to RDD of objects, that is usable by PySpark.
- */
- def pythonToJava(pyRDD: JavaRDD[Array[Byte]], batched: Boolean): JavaRDD[Any] = {
- pyRDD.rdd.mapPartitions { iter =>
- SerDeUtil.initialize()
- val unpickle = new Unpickler
- iter.flatMap { row =>
- val obj = unpickle.loads(row)
- if (batched) {
- obj.asInstanceOf[JArrayList[_]].asScala
- } else {
- Seq(obj)
- }
- }
- }.toJavaRDD()
- }
}
private
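
> Note: the rewritten `writeIteratorToStream` match above replaces the nested `Tuple2` inspection with direct patterns on the first element, but the wire format is unchanged: strings go through `writeUTF`, and every binary payload is length-prefixed so the Python side can frame it. A minimal sketch of that framing; the helper name is hypothetical:

```scala
import java.io.{ByteArrayOutputStream, DataOutputStream}

// 4-byte big-endian length, then the raw bytes -- the same layout
// DataOutputStream.writeInt/write produce in the code above.
def writeFramed(out: DataOutputStream, payload: Array[Byte]): Unit = {
  out.writeInt(payload.length)
  out.write(payload)
}

val buf = new ByteArrayOutputStream()
val out = new DataOutputStream(buf)
writeFramed(out, "hello".getBytes("UTF-8"))
out.flush()
// buf now holds: 00 00 00 05 'h' 'e' 'l' 'l' 'o'
```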
diff --git a/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala b/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala
index ebdc3533e0992..a4153aaa926f8 100644
--- a/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala
@@ -18,8 +18,13 @@
package org.apache.spark.api.python
import java.nio.ByteOrder
+import java.util.{ArrayList => JArrayList}
+
+import org.apache.spark.api.java.JavaRDD
import scala.collection.JavaConversions._
+import scala.collection.JavaConverters._
+import scala.collection.mutable
import scala.util.Failure
import scala.util.Try
@@ -89,6 +94,73 @@ private[spark] object SerDeUtil extends Logging {
}
initialize()
+
+ /**
+ * Convert an RDD of Java objects to Array (no recursive conversions).
+ * It is only used by pyspark.sql.
+ */
+ def toJavaArray(jrdd: JavaRDD[Any]): JavaRDD[Array[_]] = {
+ jrdd.rdd.map {
+ case objs: JArrayList[_] =>
+ objs.toArray
+ case obj if obj.getClass.isArray =>
+ obj.asInstanceOf[Array[_]].toArray
+ }.toJavaRDD()
+ }
+
+ /**
+ * Choose batch size based on size of objects
+ */
+ private[spark] class AutoBatchedPickler(iter: Iterator[Any]) extends Iterator[Array[Byte]] {
+ private val pickle = new Pickler()
+ private var batch = 1
+ private val buffer = new mutable.ArrayBuffer[Any]
+
+ override def hasNext: Boolean = iter.hasNext
+
+ override def next(): Array[Byte] = {
+ while (iter.hasNext && buffer.length < batch) {
+ buffer += iter.next()
+ }
+ val bytes = pickle.dumps(buffer.toArray)
+ val size = bytes.length
+ // let 1M < size < 10M
+ if (size < 1024 * 1024) {
+ batch *= 2
+ } else if (size > 1024 * 1024 * 10 && batch > 1) {
+ batch /= 2
+ }
+ buffer.clear()
+ bytes
+ }
+ }
+
+ /**
+ * Convert an RDD of Java objects to an RDD of serialized Python objects, that is usable by
+ * PySpark.
+ */
+ private[spark] def javaToPython(jRDD: JavaRDD[_]): JavaRDD[Array[Byte]] = {
+ jRDD.rdd.mapPartitions { iter => new AutoBatchedPickler(iter) }
+ }
+
+ /**
+ * Convert an RDD of serialized Python objects to RDD of objects, that is usable by PySpark.
+ */
+ def pythonToJava(pyRDD: JavaRDD[Array[Byte]], batched: Boolean): JavaRDD[Any] = {
+ pyRDD.rdd.mapPartitions { iter =>
+ initialize()
+ val unpickle = new Unpickler
+ iter.flatMap { row =>
+ val obj = unpickle.loads(row)
+ if (batched) {
+ obj.asInstanceOf[JArrayList[_]].asScala
+ } else {
+ Seq(obj)
+ }
+ }
+ }.toJavaRDD()
+ }
+
private def checkPickle(t: (Any, Any)): (Boolean, Boolean) = {
val pickle = new Pickler
val kt = Try {
@@ -128,17 +200,18 @@ private[spark] object SerDeUtil extends Logging {
*/
def pairRDDToPython(rdd: RDD[(Any, Any)], batchSize: Int): RDD[Array[Byte]] = {
val (keyFailed, valueFailed) = checkPickle(rdd.first())
+
rdd.mapPartitions { iter =>
- val pickle = new Pickler
val cleaned = iter.map { case (k, v) =>
val key = if (keyFailed) k.toString else k
val value = if (valueFailed) v.toString else v
Array[Any](key, value)
}
- if (batchSize > 1) {
- cleaned.grouped(batchSize).map(batched => pickle.dumps(seqAsJavaList(batched)))
+ if (batchSize == 0) {
+ new AutoBatchedPickler(cleaned)
} else {
- cleaned.map(pickle.dumps(_))
+ val pickle = new Pickler
+ cleaned.grouped(batchSize).map(batched => pickle.dumps(seqAsJavaList(batched)))
}
}
}
@@ -146,36 +219,22 @@ private[spark] object SerDeUtil extends Logging {
/**
* Convert an RDD of serialized Python tuple (K, V) to RDD[(K, V)].
*/
- def pythonToPairRDD[K, V](pyRDD: RDD[Array[Byte]], batchSerialized: Boolean): RDD[(K, V)] = {
+ def pythonToPairRDD[K, V](pyRDD: RDD[Array[Byte]], batched: Boolean): RDD[(K, V)] = {
def isPair(obj: Any): Boolean = {
- Option(obj.getClass.getComponentType).map(!_.isPrimitive).getOrElse(false) &&
+ Option(obj.getClass.getComponentType).exists(!_.isPrimitive) &&
obj.asInstanceOf[Array[_]].length == 2
}
- pyRDD.mapPartitions { iter =>
- initialize()
- val unpickle = new Unpickler
- val unpickled =
- if (batchSerialized) {
- iter.flatMap { batch =>
- unpickle.loads(batch) match {
- case objs: java.util.List[_] => collectionAsScalaIterable(objs)
- case other => throw new SparkException(
- s"Unexpected type ${other.getClass.getName} for batch serialized Python RDD")
- }
- }
- } else {
- iter.map(unpickle.loads(_))
- }
- unpickled.map {
- case obj if isPair(obj) =>
- // we only accept (K, V)
- val arr = obj.asInstanceOf[Array[_]]
- (arr.head.asInstanceOf[K], arr.last.asInstanceOf[V])
- case other => throw new SparkException(
- s"RDD element of type ${other.getClass.getName} cannot be used")
- }
+
+ val rdd = pythonToJava(pyRDD, batched).rdd
+ rdd.first match {
+ case obj if isPair(obj) =>
+ // we only accept (K, V)
+ case other => throw new SparkException(
+ s"RDD element of type ${other.getClass.getName} cannot be used")
+ }
+ rdd.map { obj =>
+ val arr = obj.asInstanceOf[Array[_]]
+ (arr.head.asInstanceOf[K], arr.last.asInstanceOf[V])
}
}
-
}
-
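
> Note: `AutoBatchedPickler` doubles the batch size while a pickled chunk stays under 1 MB and halves it above 10 MB, steering each serialized chunk into roughly the 1-10 MB band without the caller choosing a batch size. The adaptation rule in isolation; the function name is illustrative:

```scala
// Same thresholds as AutoBatchedPickler.next() above.
def nextBatchSize(current: Int, lastChunkBytes: Int): Int = {
  val MB = 1024 * 1024
  if (lastChunkBytes < MB) current * 2                            // grow
  else if (lastChunkBytes > 10 * MB && current > 1) current / 2   // shrink
  else current                                                    // keep
}

assert(nextBatchSize(1, 100) == 2)
assert(nextBatchSize(64, 20 * 1024 * 1024) == 32)
```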
diff --git a/core/src/main/scala/org/apache/spark/api/python/WriteInputFormatTestDataGenerator.scala b/core/src/main/scala/org/apache/spark/api/python/WriteInputFormatTestDataGenerator.scala
index e9ca9166eb4d6..c0cbd28a845be 100644
--- a/core/src/main/scala/org/apache/spark/api/python/WriteInputFormatTestDataGenerator.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/WriteInputFormatTestDataGenerator.scala
@@ -176,11 +176,11 @@ object WriteInputFormatTestDataGenerator {
// Create test data for arbitrary custom writable TestWritable
val testClass = Seq(
- ("1", TestWritable("test1", 123, 54.0)),
- ("2", TestWritable("test2", 456, 8762.3)),
- ("1", TestWritable("test3", 123, 423.1)),
- ("3", TestWritable("test56", 456, 423.5)),
- ("2", TestWritable("test2", 123, 5435.2))
+ ("1", TestWritable("test1", 1, 1.0)),
+ ("2", TestWritable("test2", 2, 2.3)),
+ ("3", TestWritable("test3", 3, 3.1)),
+ ("5", TestWritable("test56", 5, 5.5)),
+ ("4", TestWritable("test4", 4, 4.2))
)
val rdd = sc.parallelize(testClass, numSlices = 2).map{ case (k, v) => (new Text(k), v) }
rdd.saveAsNewAPIHadoopFile(classPath,
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/StandaloneWorkerShuffleService.scala b/core/src/main/scala/org/apache/spark/deploy/worker/StandaloneWorkerShuffleService.scala
new file mode 100644
index 0000000000000..88118e2837741
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/StandaloneWorkerShuffleService.scala
@@ -0,0 +1,66 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy.worker
+
+import org.apache.spark.{Logging, SparkConf, SecurityManager}
+import org.apache.spark.network.TransportContext
+import org.apache.spark.network.netty.SparkTransportConf
+import org.apache.spark.network.sasl.SaslRpcHandler
+import org.apache.spark.network.server.TransportServer
+import org.apache.spark.network.shuffle.ExternalShuffleBlockHandler
+
+/**
+ * Provides a server from which Executors can read shuffle files (rather than reading directly from
+ * each other), to provide uninterrupted access to the files in the face of executors being turned
+ * off or killed.
+ *
+ * Optionally requires SASL authentication in order to read. See [[SecurityManager]].
+ */
+private[worker]
+class StandaloneWorkerShuffleService(sparkConf: SparkConf, securityManager: SecurityManager)
+ extends Logging {
+
+ private val enabled = sparkConf.getBoolean("spark.shuffle.service.enabled", false)
+ private val port = sparkConf.getInt("spark.shuffle.service.port", 7337)
+ private val useSasl: Boolean = securityManager.isAuthenticationEnabled()
+
+ private val transportConf = SparkTransportConf.fromSparkConf(sparkConf)
+ private val blockHandler = new ExternalShuffleBlockHandler()
+ private val transportContext: TransportContext = {
+ val handler = if (useSasl) new SaslRpcHandler(blockHandler, securityManager) else blockHandler
+ new TransportContext(transportConf, handler)
+ }
+
+ private var server: TransportServer = _
+
+ /** Starts the external shuffle service if the user has configured us to. */
+ def startIfEnabled() {
+ if (enabled) {
+ require(server == null, "Shuffle server already started")
+ logInfo(s"Starting shuffle service on port $port with useSasl = $useSasl")
+ server = transportContext.createServer(port)
+ }
+ }
+
+ def stop() {
+ if (enabled && server != null) {
+ server.close()
+ server = null
+ }
+ }
+}
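
> Note: this service is what the `ExecutorAllocationManager` check earlier in the diff requires: once shuffle blocks are served by a worker-local server, killing an executor no longer loses its map outputs. A hedged sketch of standing it up by hand; the `Worker` normally owns this lifecycle, and the `package` line is needed because the class is `private[worker]`:

```scala
package org.apache.spark.deploy.worker  // class is private[worker]

import org.apache.spark.{SecurityManager, SparkConf}

object ShuffleServiceSketch {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf()
      .set("spark.shuffle.service.enabled", "true")
      .set("spark.shuffle.service.port", "7337")   // default port
      .set("spark.authenticate", "false")          // no SASL in this sketch
    val service = new StandaloneWorkerShuffleService(conf, new SecurityManager(conf))
    service.startIfEnabled()                        // binds the TransportServer
    sys.addShutdownHook { service.stop() }
  }
}
```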
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala
index f1f66d0903f1c..ca262de832e25 100755
--- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala
@@ -111,6 +111,9 @@ private[spark] class Worker(
val drivers = new HashMap[String, DriverRunner]
val finishedDrivers = new HashMap[String, DriverRunner]
+ // The shuffle service is not actually started unless configured.
+ val shuffleService = new StandaloneWorkerShuffleService(conf, securityMgr)
+
val publicAddress = {
val envVar = System.getenv("SPARK_PUBLIC_DNS")
if (envVar != null) envVar else host
@@ -154,6 +157,7 @@ private[spark] class Worker(
logInfo("Spark home: " + sparkHome)
createWorkDir()
context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent])
+ shuffleService.startIfEnabled()
webUi = new WorkerWebUI(this, workDir, webUiPort)
webUi.bind()
registerWithMaster()
@@ -419,6 +423,7 @@ private[spark] class Worker(
registrationRetryTimer.foreach(_.cancel())
executors.values.foreach(_.kill())
drivers.values.foreach(_.kill())
+ shuffleService.stop()
webUi.stop()
metricsSystem.stop()
}
@@ -441,7 +446,8 @@ private[spark] object Worker extends Logging {
cores: Int,
memory: Int,
masterUrls: Array[String],
- workDir: String, workerNumber: Option[Int] = None): (ActorSystem, Int) = {
+ workDir: String,
+ workerNumber: Option[Int] = None): (ActorSystem, Int) = {
// The LocalSparkCluster runs multiple local sparkWorkerX actor systems
val conf = new SparkConf
diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
index 697154d762d41..3711824a40cfc 100644
--- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
+++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
@@ -131,7 +131,8 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging {
// Create a new ActorSystem using driver's Spark properties to run the backend.
val driverConf = new SparkConf().setAll(props)
val (actorSystem, boundPort) = AkkaUtils.createActorSystem(
- "sparkExecutor", hostname, port, driverConf, new SecurityManager(driverConf))
+ SparkEnv.executorActorSystemName,
+ hostname, port, driverConf, new SecurityManager(driverConf))
// set it
val sparkHostPort = hostname + ":" + boundPort
actorSystem.actorOf(
diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala
index e24a15f015e1c..caf4d76713d49 100644
--- a/core/src/main/scala/org/apache/spark/executor/Executor.scala
+++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -26,7 +26,7 @@ import scala.collection.JavaConversions._
import scala.collection.mutable.{ArrayBuffer, HashMap}
import scala.util.control.NonFatal
-import akka.actor.ActorSystem
+import akka.actor.{Props, ActorSystem}
import org.apache.spark._
import org.apache.spark.deploy.SparkHadoopUtil
@@ -86,12 +86,17 @@ private[spark] class Executor(
conf, executorId, slaveHostname, port, isLocal, actorSystem)
SparkEnv.set(_env)
_env.metricsSystem.registerSource(executorSource)
+ _env.blockManager.initialize(conf.getAppId)
_env
} else {
SparkEnv.get
}
}
+ // Create an actor for receiving RPCs from the driver
+ private val executorActor = env.actorSystem.actorOf(
+ Props(new ExecutorActor(executorId)), "ExecutorActor")
+
// Create our ClassLoader
// do this after SparkEnv creation so can access the SecurityManager
private val urlClassLoader = createClassLoader()
@@ -131,6 +136,7 @@ private[spark] class Executor(
def stop() {
env.metricsSystem.report()
+ env.actorSystem.stop(executorActor)
isStopped = true
threadPool.shutdown()
if (!isLocal) {
@@ -155,7 +161,7 @@ private[spark] class Executor(
}
override def run() {
- val startTime = System.currentTimeMillis()
+ val deserializeStartTime = System.currentTimeMillis()
Thread.currentThread.setContextClassLoader(replClassLoader)
val ser = SparkEnv.get.closureSerializer.newInstance()
logInfo(s"Running $taskName (TID $taskId)")
@@ -200,7 +206,7 @@ private[spark] class Executor(
val afterSerialization = System.currentTimeMillis()
for (m <- task.metrics) {
- m.executorDeserializeTime = taskStart - startTime
+ m.executorDeserializeTime = taskStart - deserializeStartTime
m.executorRunTime = taskFinish - taskStart
m.jvmGCTime = gcTime - startGCTime
m.resultSerializationTime = afterSerialization - beforeSerialization
@@ -257,7 +263,7 @@ private[spark] class Executor(
m.executorRunTime = serviceTime
m.jvmGCTime = gcTime - startGCTime
}
- val reason = ExceptionFailure(t.getClass.getName, t.getMessage, t.getStackTrace, metrics)
+ val reason = new ExceptionFailure(t, metrics)
execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason))
// Don't forcibly exit unless the exception was inherently fatal, to avoid
diff --git a/core/src/main/scala/org/apache/spark/executor/ExecutorActor.scala b/core/src/main/scala/org/apache/spark/executor/ExecutorActor.scala
new file mode 100644
index 0000000000000..41925f7e97e84
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/executor/ExecutorActor.scala
@@ -0,0 +1,41 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.executor
+
+import akka.actor.Actor
+import org.apache.spark.Logging
+
+import org.apache.spark.util.{Utils, ActorLogReceive}
+
+/**
+ * Driver -> Executor message to trigger a thread dump.
+ */
+private[spark] case object TriggerThreadDump
+
+/**
+ * Actor that runs inside of executors to enable driver -> executor RPC.
+ */
+private[spark]
+class ExecutorActor(executorId: String) extends Actor with ActorLogReceive with Logging {
+
+ override def receiveWithLogging = {
+ case TriggerThreadDump =>
+ sender ! Utils.getThreadDump()
+ }
+
+}
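
The actor above answers TriggerThreadDump with the result of Utils.getThreadDump(), which
this patch adds to Utils further below. A minimal driver-side sketch of that round trip,
assuming an ActorRef already resolved to an "ExecutorActor" (for example via the
AkkaUtils.makeExecutorRef helper introduced later in this patch):

    import scala.concurrent.Await
    import scala.concurrent.duration._

    import akka.actor.ActorRef
    import akka.pattern.ask
    import akka.util.Timeout

    import org.apache.spark.executor.TriggerThreadDump
    import org.apache.spark.util.ThreadStackTrace

    // Hypothetical helper for illustration: send TriggerThreadDump and block on the
    // reply, which ExecutorActor produces via Utils.getThreadDump().
    def requestThreadDump(executorActor: ActorRef): Array[ThreadStackTrace] = {
      implicit val timeout = Timeout(10.seconds)
      val reply = executorActor ? TriggerThreadDump
      Await.result(reply, timeout.duration).asInstanceOf[Array[ThreadStackTrace]]
    }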
diff --git a/core/src/main/scala/org/apache/spark/input/FixedLengthBinaryRecordReader.scala b/core/src/main/scala/org/apache/spark/input/FixedLengthBinaryRecordReader.scala
index 5164a74bec4e9..36a1e5d475f46 100644
--- a/core/src/main/scala/org/apache/spark/input/FixedLengthBinaryRecordReader.scala
+++ b/core/src/main/scala/org/apache/spark/input/FixedLengthBinaryRecordReader.scala
@@ -115,7 +115,7 @@ private[spark] class FixedLengthBinaryRecordReader
if (currentPosition < splitEnd) {
// setup a buffer to store the record
val buffer = recordValue.getBytes
- fileInputStream.read(buffer, 0, recordLength)
+ fileInputStream.readFully(buffer)
// update our current position
currentPosition = currentPosition + recordLength
// return true
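
The readFully change matters because InputStream.read(buf, off, len) may return after
reading fewer than len bytes, silently truncating a fixed-length record;
DataInputStream.readFully instead loops until the buffer is full or the stream ends.
A sketch of the guarantee readFully provides, written against a plain InputStream:

    import java.io.{EOFException, InputStream}

    // Equivalent of readFully as an explicit loop: keep calling read() until the
    // buffer is completely filled, or fail loudly if the stream ends early.
    def readFully(in: InputStream, buffer: Array[Byte]): Unit = {
      var pos = 0
      while (pos < buffer.length) {
        val n = in.read(buffer, pos, buffer.length - pos)
        if (n < 0) throw new EOFException(s"Stream ended after $pos bytes")
        pos += n
      }
    }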
diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala
index 1c4327cf13b51..b937ea825f49e 100644
--- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala
+++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala
@@ -17,15 +17,17 @@
package org.apache.spark.network.netty
+import scala.collection.JavaConversions._
import scala.concurrent.{Future, Promise}
-import org.apache.spark.SparkConf
+import org.apache.spark.{SecurityManager, SparkConf}
import org.apache.spark.network._
import org.apache.spark.network.buffer.ManagedBuffer
-import org.apache.spark.network.client.{RpcResponseCallback, TransportClientFactory}
+import org.apache.spark.network.client.{TransportClientBootstrap, RpcResponseCallback, TransportClientFactory}
import org.apache.spark.network.netty.NettyMessages.{OpenBlocks, UploadBlock}
+import org.apache.spark.network.sasl.{SaslRpcHandler, SaslClientBootstrap}
import org.apache.spark.network.server._
-import org.apache.spark.network.shuffle.{BlockFetchingListener, OneForOneBlockFetcher}
+import org.apache.spark.network.shuffle.{RetryingBlockFetcher, BlockFetchingListener, OneForOneBlockFetcher}
import org.apache.spark.serializer.JavaSerializer
import org.apache.spark.storage.{BlockId, StorageLevel}
import org.apache.spark.util.Utils
@@ -33,18 +35,30 @@ import org.apache.spark.util.Utils
/**
 * A BlockTransferService that uses Netty to fetch a set of blocks at a time.
*/
-class NettyBlockTransferService(conf: SparkConf) extends BlockTransferService {
+class NettyBlockTransferService(conf: SparkConf, securityManager: SecurityManager)
+ extends BlockTransferService {
+
// TODO: Don't use Java serialization, use a more cross-version compatible serialization format.
- val serializer = new JavaSerializer(conf)
+ private val serializer = new JavaSerializer(conf)
+ private val authEnabled = securityManager.isAuthenticationEnabled()
+ private val transportConf = SparkTransportConf.fromSparkConf(conf)
private[this] var transportContext: TransportContext = _
private[this] var server: TransportServer = _
private[this] var clientFactory: TransportClientFactory = _
override def init(blockDataManager: BlockDataManager): Unit = {
- val rpcHandler = new NettyBlockRpcServer(serializer, blockDataManager)
- transportContext = new TransportContext(SparkTransportConf.fromSparkConf(conf), rpcHandler)
- clientFactory = transportContext.createClientFactory()
+ val (rpcHandler: RpcHandler, bootstrap: Option[TransportClientBootstrap]) = {
+ val nettyRpcHandler = new NettyBlockRpcServer(serializer, blockDataManager)
+ if (!authEnabled) {
+ (nettyRpcHandler, None)
+ } else {
+ (new SaslRpcHandler(nettyRpcHandler, securityManager),
+ Some(new SaslClientBootstrap(transportConf, conf.getAppId, securityManager)))
+ }
+ }
+ transportContext = new TransportContext(transportConf, rpcHandler)
+ clientFactory = transportContext.createClientFactory(bootstrap.toList)
server = transportContext.createServer()
logInfo("Server created on " + server.getPort)
}
@@ -57,9 +71,22 @@ class NettyBlockTransferService(conf: SparkConf) extends BlockTransferService {
listener: BlockFetchingListener): Unit = {
logTrace(s"Fetch blocks from $host:$port (executor id $execId)")
try {
- val client = clientFactory.createClient(host, port)
- new OneForOneBlockFetcher(client, blockIds.toArray, listener)
- .start(OpenBlocks(blockIds.map(BlockId.apply)))
+ val blockFetchStarter = new RetryingBlockFetcher.BlockFetchStarter {
+ override def createAndStart(blockIds: Array[String], listener: BlockFetchingListener) {
+ val client = clientFactory.createClient(host, port)
+ new OneForOneBlockFetcher(client, blockIds.toArray, listener)
+ .start(OpenBlocks(blockIds.map(BlockId.apply)))
+ }
+ }
+
+ val maxRetries = transportConf.maxIORetries()
+ if (maxRetries > 0) {
+ // Note that RetryingBlockFetcher handles maxRetries == 0 correctly; we bypass it here
+ // only as a safeguard against bugs in the retry path. The if statement can be removed
+ // once that path has proven stable.
+ new RetryingBlockFetcher(transportConf, blockFetchStarter, blockIds, listener).start()
+ } else {
+ blockFetchStarter.createAndStart(blockIds, listener)
+ }
} catch {
case e: Exception =>
logError("Exception while beginning fetchBlocks", e)
diff --git a/core/src/main/scala/org/apache/spark/network/nio/Connection.scala b/core/src/main/scala/org/apache/spark/network/nio/Connection.scala
index 4f6f5e235811d..c2d9578be7ebb 100644
--- a/core/src/main/scala/org/apache/spark/network/nio/Connection.scala
+++ b/core/src/main/scala/org/apache/spark/network/nio/Connection.scala
@@ -23,12 +23,13 @@ import java.nio.channels._
import java.util.concurrent.ConcurrentLinkedQueue
import java.util.LinkedList
-import org.apache.spark._
-
import scala.collection.JavaConversions._
import scala.collection.mutable.{ArrayBuffer, HashMap}
import scala.util.control.NonFatal
+import org.apache.spark._
+import org.apache.spark.network.sasl.{SparkSaslClient, SparkSaslServer}
+
private[nio]
abstract class Connection(val channel: SocketChannel, val selector: Selector,
val socketRemoteConnectionManagerId: ConnectionManagerId, val connectionId: ConnectionId,
diff --git a/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala b/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala
index 8408b75bb4d65..f198aa8564a54 100644
--- a/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala
+++ b/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala
@@ -34,6 +34,7 @@ import scala.language.postfixOps
import com.google.common.base.Charsets.UTF_8
import org.apache.spark._
+import org.apache.spark.network.sasl.{SparkSaslClient, SparkSaslServer}
import org.apache.spark.util.Utils
import scala.util.Try
@@ -600,7 +601,7 @@ private[nio] class ConnectionManager(
} else {
var replyToken : Array[Byte] = null
try {
- replyToken = waitingConn.sparkSaslClient.saslResponse(securityMsg.getToken)
+ replyToken = waitingConn.sparkSaslClient.response(securityMsg.getToken)
if (waitingConn.isSaslComplete()) {
logDebug("Client sasl completed after evaluate for id: " + waitingConn.connectionId)
connectionsAwaitingSasl -= waitingConn.connectionId
@@ -634,7 +635,7 @@ private[nio] class ConnectionManager(
connection.synchronized {
if (connection.sparkSaslServer == null) {
logDebug("Creating sasl Server")
- connection.sparkSaslServer = new SparkSaslServer(securityManager)
+ connection.sparkSaslServer = new SparkSaslServer(conf.getAppId, securityManager)
}
}
replyToken = connection.sparkSaslServer.response(securityMsg.getToken)
@@ -778,7 +779,7 @@ private[nio] class ConnectionManager(
if (!conn.isSaslComplete()) {
conn.synchronized {
if (conn.sparkSaslClient == null) {
- conn.sparkSaslClient = new SparkSaslClient(securityManager)
+ conn.sparkSaslClient = new SparkSaslClient(conf.getAppId, securityManager)
var firstResponse: Array[Byte] = null
try {
firstResponse = conn.sparkSaslClient.firstToken()
diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
index 946fb5616d3ec..a157e36e2286e 100644
--- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
@@ -211,20 +211,11 @@ class HadoopRDD[K, V](
val split = theSplit.asInstanceOf[HadoopPartition]
logInfo("Input split: " + split.inputSplit)
- var reader: RecordReader[K, V] = null
val jobConf = getJobConf()
- val inputFormat = getInputFormat(jobConf)
- HadoopRDD.addLocalConfiguration(new SimpleDateFormat("yyyyMMddHHmm").format(createTime),
- context.stageId, theSplit.index, context.attemptId.toInt, jobConf)
- reader = inputFormat.getRecordReader(split.inputSplit.value, jobConf, Reporter.NULL)
-
- // Register an on-task-completion callback to close the input stream.
- context.addTaskCompletionListener{ context => closeIfNeeded() }
- val key: K = reader.createKey()
- val value: V = reader.createValue()
val inputMetrics = new InputMetrics(DataReadMethod.Hadoop)
- // Find a function that will return the FileSystem bytes read by this thread.
+ // Find a function that will return the FileSystem bytes read by this thread. Do this before
+ // creating the RecordReader, because the RecordReader's constructor might read some bytes.
val bytesReadCallback = if (split.inputSplit.value.isInstanceOf[FileSplit]) {
SparkHadoopUtil.get.getFSBytesReadOnThreadCallback(
split.inputSplit.value.asInstanceOf[FileSplit].getPath, jobConf)
@@ -234,6 +225,18 @@ class HadoopRDD[K, V](
if (bytesReadCallback.isDefined) {
context.taskMetrics.inputMetrics = Some(inputMetrics)
}
+
+ var reader: RecordReader[K, V] = null
+ val inputFormat = getInputFormat(jobConf)
+ HadoopRDD.addLocalConfiguration(new SimpleDateFormat("yyyyMMddHHmm").format(createTime),
+ context.stageId, theSplit.index, context.attemptId.toInt, jobConf)
+ reader = inputFormat.getRecordReader(split.inputSplit.value, jobConf, Reporter.NULL)
+
+ // Register an on-task-completion callback to close the input stream.
+ context.addTaskCompletionListener{ context => closeIfNeeded() }
+ val key: K = reader.createKey()
+ val value: V = reader.createValue()
+
var recordsSinceMetricsUpdate = 0
override def getNext() = {
diff --git a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala
index 6d6b86721ca74..351e145f96f9a 100644
--- a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala
@@ -107,20 +107,10 @@ class NewHadoopRDD[K, V](
val split = theSplit.asInstanceOf[NewHadoopPartition]
logInfo("Input split: " + split.serializableHadoopSplit)
val conf = confBroadcast.value.value
- val attemptId = newTaskAttemptID(jobTrackerId, id, isMap = true, split.index, 0)
- val hadoopAttemptContext = newTaskAttemptContext(conf, attemptId)
- val format = inputFormatClass.newInstance
- format match {
- case configurable: Configurable =>
- configurable.setConf(conf)
- case _ =>
- }
- val reader = format.createRecordReader(
- split.serializableHadoopSplit.value, hadoopAttemptContext)
- reader.initialize(split.serializableHadoopSplit.value, hadoopAttemptContext)
val inputMetrics = new InputMetrics(DataReadMethod.Hadoop)
- // Find a function that will return the FileSystem bytes read by this thread.
+ // Find a function that will return the FileSystem bytes read by this thread. Do this before
+ // creating the RecordReader, because the RecordReader's constructor might read some bytes.
val bytesReadCallback = if (split.serializableHadoopSplit.value.isInstanceOf[FileSplit]) {
SparkHadoopUtil.get.getFSBytesReadOnThreadCallback(
split.serializableHadoopSplit.value.asInstanceOf[FileSplit].getPath, conf)
@@ -131,6 +121,18 @@ class NewHadoopRDD[K, V](
context.taskMetrics.inputMetrics = Some(inputMetrics)
}
+ val attemptId = newTaskAttemptID(jobTrackerId, id, isMap = true, split.index, 0)
+ val hadoopAttemptContext = newTaskAttemptContext(conf, attemptId)
+ val format = inputFormatClass.newInstance
+ format match {
+ case configurable: Configurable =>
+ configurable.setConf(conf)
+ case _ =>
+ }
+ val reader = format.createRecordReader(
+ split.serializableHadoopSplit.value, hadoopAttemptContext)
+ reader.initialize(split.serializableHadoopSplit.value, hadoopAttemptContext)
+
// Register an on-task-completion callback to close the input stream.
context.addTaskCompletionListener(context => close())
var havePair = false
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index af17b5d5d2571..22449517d100f 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -1053,7 +1053,7 @@ class DAGScheduler(
logInfo("Resubmitted " + task + ", so marking it as still running")
stage.pendingTasks += task
- case FetchFailed(bmAddress, shuffleId, mapId, reduceId) =>
+ case FetchFailed(bmAddress, shuffleId, mapId, reduceId, failureMessage) =>
val failedStage = stageIdToStage(task.stageId)
val mapStage = shuffleToMapStage(shuffleId)
@@ -1063,7 +1063,7 @@ class DAGScheduler(
if (runningStages.contains(failedStage)) {
logInfo(s"Marking $failedStage (${failedStage.name}) as failed " +
s"due to a fetch failure from $mapStage (${mapStage.name})")
- markStageAsFinished(failedStage, Some("Fetch failure"))
+ markStageAsFinished(failedStage, Some(failureMessage))
runningStages -= failedStage
}
@@ -1094,7 +1094,7 @@ class DAGScheduler(
handleExecutorLost(bmAddress.executorId, fetchFailed = true, Some(task.epoch))
}
- case ExceptionFailure(className, description, stackTrace, metrics) =>
+ case ExceptionFailure(className, description, stackTrace, fullStackTrace, metrics) =>
// Do nothing here, left up to the TaskScheduler to decide how to handle user failures
case TaskResultLost =>
diff --git a/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala b/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala
index 54904bffdf10b..4e3d9de540783 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala
@@ -215,7 +215,7 @@ class JobLogger(val user: String, val logDirName: String) extends SparkListener
taskStatus += " STATUS=RESUBMITTED TID=" + taskInfo.taskId +
" STAGE_ID=" + taskEnd.stageId
stageLogInfo(taskEnd.stageId, taskStatus)
- case FetchFailed(bmAddress, shuffleId, mapId, reduceId) =>
+ case FetchFailed(bmAddress, shuffleId, mapId, reduceId, message) =>
taskStatus += " STATUS=FETCHFAILED TID=" + taskInfo.taskId + " STAGE_ID=" +
taskEnd.stageId + " SHUFFLE_ID=" + shuffleId + " MAP_ID=" +
mapId + " REDUCE_ID=" + reduceId
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala
index d8c0e2f66df01..5289661eb896b 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala
@@ -93,7 +93,7 @@ private[spark] class CoarseMesosSchedulerBackend(
setDaemon(true)
override def run() {
val scheduler = CoarseMesosSchedulerBackend.this
- val fwInfo = FrameworkInfo.newBuilder().setUser("").setName(sc.appName).build()
+ val fwInfo = FrameworkInfo.newBuilder().setUser(sc.sparkUser).setName(sc.appName).build()
driver = new MesosSchedulerDriver(scheduler, fwInfo, master)
try {
val ret = driver.run()
@@ -242,8 +242,7 @@ private[spark] class CoarseMesosSchedulerBackend(
for (r <- res if r.getName == name) {
return r.getScalar.getValue
}
- // If we reached here, no resource with the required name was present
- throw new IllegalArgumentException("No resource called " + name + " in " + res)
+ 0
}
/** Build a Mesos resource protobuf object */
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala
index 8e2faff90f9b2..c5f3493477bc5 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala
@@ -72,7 +72,7 @@ private[spark] class MesosSchedulerBackend(
setDaemon(true)
override def run() {
val scheduler = MesosSchedulerBackend.this
- val fwInfo = FrameworkInfo.newBuilder().setUser("").setName(sc.appName).build()
+ val fwInfo = FrameworkInfo.newBuilder().setUser(sc.sparkUser).setName(sc.appName).build()
driver = new MesosSchedulerDriver(scheduler, fwInfo, master)
try {
val ret = driver.run()
@@ -278,8 +278,7 @@ private[spark] class MesosSchedulerBackend(
for (r <- res if r.getName == name) {
return r.getScalar.getValue
}
- // If we reached here, no resource with the required name was present
- throw new IllegalArgumentException("No resource called " + name + " in " + res)
+ 0
}
/** Turn a Spark TaskDescription into a Mesos task */
diff --git a/core/src/main/scala/org/apache/spark/shuffle/FetchFailedException.scala b/core/src/main/scala/org/apache/spark/shuffle/FetchFailedException.scala
index 71c08e9d5a8c3..be184464e0ae9 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/FetchFailedException.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/FetchFailedException.scala
@@ -19,6 +19,7 @@ package org.apache.spark.shuffle
import org.apache.spark.storage.BlockManagerId
import org.apache.spark.{FetchFailed, TaskEndReason}
+import org.apache.spark.util.Utils
/**
* Failed to fetch a shuffle block. The executor catches this exception and propagates it
@@ -30,13 +31,22 @@ private[spark] class FetchFailedException(
bmAddress: BlockManagerId,
shuffleId: Int,
mapId: Int,
- reduceId: Int)
- extends Exception {
+ reduceId: Int,
+ message: String,
+ cause: Throwable = null)
+ extends Exception(message, cause) {
- override def getMessage: String =
- "Fetch failed: %s %d %d %d".format(bmAddress, shuffleId, mapId, reduceId)
+ def this(
+ bmAddress: BlockManagerId,
+ shuffleId: Int,
+ mapId: Int,
+ reduceId: Int,
+ cause: Throwable) {
+ this(bmAddress, shuffleId, mapId, reduceId, cause.getMessage, cause)
+ }
- def toTaskEndReason: TaskEndReason = FetchFailed(bmAddress, shuffleId, mapId, reduceId)
+ def toTaskEndReason: TaskEndReason = FetchFailed(bmAddress, shuffleId, mapId, reduceId,
+ Utils.exceptionString(this))
}
/**
@@ -46,7 +56,4 @@ private[spark] class MetadataFetchFailedException(
shuffleId: Int,
reduceId: Int,
message: String)
- extends FetchFailedException(null, shuffleId, -1, reduceId) {
-
- override def getMessage: String = message
-}
+ extends FetchFailedException(null, shuffleId, -1, reduceId, message)
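
With a message and cause threaded through, a low-level network error can be rethrown as a
fetch failure without losing its stack trace, and toTaskEndReason ships the full exception
string to the driver. A small sketch using hypothetical ids:

    import java.io.IOException

    import org.apache.spark.shuffle.FetchFailedException
    import org.apache.spark.storage.BlockManagerId

    // Hypothetical address and ids, for illustration only.
    val address = BlockManagerId("exec-1", "host-a", 7337)
    val cause = new IOException("connection reset by peer")

    // The auxiliary constructor reuses the cause's message and preserves the cause.
    val fetchFailed = new FetchFailedException(address, 0, 3, 7, cause)
    // Becomes FetchFailed(address, 0, 3, 7, <full stack trace including the IOException>).
    val reason = fetchFailed.toTaskEndReason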
diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala
index f49917b7fe833..e3e7434df45b0 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala
@@ -19,6 +19,7 @@ package org.apache.spark.shuffle.hash
import scala.collection.mutable.ArrayBuffer
import scala.collection.mutable.HashMap
+import scala.util.{Failure, Success, Try}
import org.apache.spark._
import org.apache.spark.serializer.Serializer
@@ -52,21 +53,21 @@ private[hash] object BlockStoreShuffleFetcher extends Logging {
(address, splits.map(s => (ShuffleBlockId(shuffleId, s._1, reduceId), s._2)))
}
- def unpackBlock(blockPair: (BlockId, Option[Iterator[Any]])) : Iterator[T] = {
+ def unpackBlock(blockPair: (BlockId, Try[Iterator[Any]])) : Iterator[T] = {
val blockId = blockPair._1
val blockOption = blockPair._2
blockOption match {
- case Some(block) => {
+ case Success(block) => {
block.asInstanceOf[Iterator[T]]
}
- case None => {
+ case Failure(e) => {
blockId match {
case ShuffleBlockId(shufId, mapId, _) =>
val address = statuses(mapId.toInt)._1
- throw new FetchFailedException(address, shufId.toInt, mapId.toInt, reduceId)
+ throw new FetchFailedException(address, shufId.toInt, mapId.toInt, reduceId, e)
case _ =>
throw new SparkException(
- "Failed to get block " + blockId + ", which is not a shuffle block")
+ "Failed to get block " + blockId + ", which is not a shuffle block", e)
}
}
}
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
index 5f5dd0dc1c63f..e48d7772d6ee9 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
@@ -40,7 +40,6 @@ import org.apache.spark.network.util.{ConfigProvider, TransportConf}
import org.apache.spark.serializer.Serializer
import org.apache.spark.shuffle.ShuffleManager
import org.apache.spark.shuffle.hash.HashShuffleManager
-import org.apache.spark.shuffle.sort.SortShuffleManager
import org.apache.spark.util._
private[spark] sealed trait BlockValues
@@ -57,6 +56,12 @@ private[spark] class BlockResult(
inputMetrics.bytesRead = bytes
}
+/**
+ * Manager running on every node (driver and executors) which provides interfaces for putting and
+ * retrieving blocks both locally and remotely into various stores (memory, disk, and off-heap).
+ *
+ * Note that #initialize() must be called before the BlockManager is usable.
+ */
private[spark] class BlockManager(
executorId: String,
actorSystem: ActorSystem,
@@ -66,11 +71,10 @@ private[spark] class BlockManager(
val conf: SparkConf,
mapOutputTracker: MapOutputTracker,
shuffleManager: ShuffleManager,
- blockTransferService: BlockTransferService)
+ blockTransferService: BlockTransferService,
+ securityManager: SecurityManager)
extends BlockDataManager with Logging {
- blockTransferService.init(this)
-
val diskBlockManager = new DiskBlockManager(this, conf)
private val blockInfo = new TimeStampedHashMap[BlockId, BlockInfo]
@@ -92,7 +96,12 @@ private[spark] class BlockManager(
private[spark]
val externalShuffleServiceEnabled = conf.getBoolean("spark.shuffle.service.enabled", false)
- private val externalShuffleServicePort = conf.getInt("spark.shuffle.service.port", 7337)
+
+ // Port used by the external shuffle service. In Yarn mode, this may already be
+ // set through the Hadoop configuration as the server is launched in the Yarn NM.
+ private val externalShuffleServicePort =
+ Utils.getSparkOrYarnConfig(conf, "spark.shuffle.service.port", "7337").toInt
+
// Check that we're not using external shuffle service with consolidated shuffle files.
if (externalShuffleServiceEnabled
&& conf.getBoolean("spark.shuffle.consolidateFiles", false)
@@ -102,22 +111,17 @@ private[spark] class BlockManager(
+ " switch to sort-based shuffle.")
}
- val blockManagerId = BlockManagerId(
- executorId, blockTransferService.hostName, blockTransferService.port)
+ var blockManagerId: BlockManagerId = _
// Address of the server that serves this executor's shuffle files. This is either an external
// service, or just our own Executor's BlockManager.
- private[spark] val shuffleServerId = if (externalShuffleServiceEnabled) {
- BlockManagerId(executorId, blockTransferService.hostName, externalShuffleServicePort)
- } else {
- blockManagerId
- }
+ private[spark] var shuffleServerId: BlockManagerId = _
// Client to read other executors' shuffle files. This is either an external service, or just the
// standard BlockTransferService to connect directly to other executors.
private[spark] val shuffleClient = if (externalShuffleServiceEnabled) {
- val appId = conf.get("spark.app.id", "unknown-app-id")
- new ExternalShuffleClient(SparkTransportConf.fromSparkConf(conf), appId)
+ new ExternalShuffleClient(SparkTransportConf.fromSparkConf(conf), securityManager,
+ securityManager.isAuthenticationEnabled())
} else {
blockTransferService
}
@@ -150,8 +154,6 @@ private[spark] class BlockManager(
private val peerFetchLock = new Object
private var lastPeerFetchTime = 0L
- initialize()
-
/* The compression codec to use. Note that the "lazy" val is necessary because we want to delay
* the initialization of the compression codec until it is first used. The reason is that a Spark
* program could be using a user-defined codec in a third party jar, which is loaded in
@@ -170,16 +172,34 @@ private[spark] class BlockManager(
conf: SparkConf,
mapOutputTracker: MapOutputTracker,
shuffleManager: ShuffleManager,
- blockTransferService: BlockTransferService) = {
+ blockTransferService: BlockTransferService,
+ securityManager: SecurityManager) = {
this(execId, actorSystem, master, serializer, BlockManager.getMaxMemory(conf),
- conf, mapOutputTracker, shuffleManager, blockTransferService)
+ conf, mapOutputTracker, shuffleManager, blockTransferService, securityManager)
}
/**
- * Initialize the BlockManager. Register to the BlockManagerMaster, and start the
- * BlockManagerWorker actor. Additionally registers with a local shuffle service if configured.
+ * Initializes the BlockManager with the given appId. This is not performed in the constructor as
+ * the appId may not be known at BlockManager instantiation time (in particular for the driver,
+ * where it is only learned after registration with the TaskScheduler).
+ *
+ * This method initializes the BlockTransferService and ShuffleClient, registers with the
+ * BlockManagerMaster, starts the BlockManagerWorker actor, and registers with a local shuffle
+ * service if configured.
*/
- private def initialize(): Unit = {
+ def initialize(appId: String): Unit = {
+ blockTransferService.init(this)
+ shuffleClient.init(appId)
+
+ blockManagerId = BlockManagerId(
+ executorId, blockTransferService.hostName, blockTransferService.port)
+
+ shuffleServerId = if (externalShuffleServiceEnabled) {
+ BlockManagerId(executorId, blockTransferService.hostName, externalShuffleServicePort)
+ } else {
+ blockManagerId
+ }
+
master.registerBlockManager(blockManagerId, maxMemory, slaveActor)
// Register Executors' configuration with the local shuffle service, if one should exist.
@@ -206,7 +226,6 @@ private[spark] class BlockManager(
return
} catch {
case e: Exception if i < MAX_ATTEMPTS =>
- val attemptsRemaining =
logError(s"Failed to connect to external shuffle server, will retry ${MAX_ATTEMPTS - i}}"
+ s" more times after waiting $SLEEP_TIME_SECS seconds...", e)
Thread.sleep(SLEEP_TIME_SECS * 1000)
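
Because registration now happens in initialize(appId) instead of the constructor, call
sites must follow a strict order: construct the BlockManager, call initialize, then use
it. A sketch of that contract, mirroring the Executor change earlier in this patch:

    import org.apache.spark.{SparkConf, SparkEnv}

    // Sketch: no block may be put, fetched, or registered before initialize(appId);
    // this mirrors `_env.blockManager.initialize(conf.getAppId)` in Executor above.
    def startBlockManager(env: SparkEnv, conf: SparkConf): Unit = {
      env.blockManager.initialize(conf.getAppId)
    }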
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala
index d08e1419e3e41..b63c7f191155c 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala
@@ -88,6 +88,10 @@ class BlockManagerMaster(
askDriverWithReply[Seq[BlockManagerId]](GetPeers(blockManagerId))
}
+ def getActorSystemHostPortForExecutor(executorId: String): Option[(String, Int)] = {
+ askDriverWithReply[Option[(String, Int)]](GetActorSystemHostPortForExecutor(executorId))
+ }
+
/**
* Remove a block from the slaves that have it. This can only be used to remove
* blocks that the driver knows about.
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala
index 5e375a2553979..685b2e11440fb 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala
@@ -86,6 +86,9 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus
case GetPeers(blockManagerId) =>
sender ! getPeers(blockManagerId)
+ case GetActorSystemHostPortForExecutor(executorId) =>
+ sender ! getActorSystemHostPortForExecutor(executorId)
+
case GetMemoryStatus =>
sender ! memoryStatus
@@ -412,6 +415,21 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus
Seq.empty
}
}
+
+ /**
+ * Returns the hostname and port of an executor's actor system, based on the Akka address of its
+ * BlockManagerSlaveActor.
+ */
+ private def getActorSystemHostPortForExecutor(executorId: String): Option[(String, Int)] = {
+ for (
+ blockManagerId <- blockManagerIdByExecutor.get(executorId);
+ info <- blockManagerInfo.get(blockManagerId);
+ host <- info.slaveActor.path.address.host;
+ port <- info.slaveActor.path.address.port
+ ) yield {
+ (host, port)
+ }
+ }
}
@DeveloperApi
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala
index 291ddfcc113ac..3f32099d08cc9 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala
@@ -92,6 +92,8 @@ private[spark] object BlockManagerMessages {
case class GetPeers(blockManagerId: BlockManagerId) extends ToBlockManagerMaster
+ case class GetActorSystemHostPortForExecutor(executorId: String) extends ToBlockManagerMaster
+
case class RemoveExecutor(execId: String) extends ToBlockManagerMaster
case object StopBlockManagerMaster extends ToBlockManagerMaster
diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
index ee89c7e521f4e..6b1f57a069431 100644
--- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
+++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
@@ -20,6 +20,7 @@ package org.apache.spark.storage
import java.util.concurrent.LinkedBlockingQueue
import scala.collection.mutable.{ArrayBuffer, HashSet, Queue}
+import scala.util.{Failure, Success, Try}
import org.apache.spark.{Logging, TaskContext}
import org.apache.spark.network.BlockTransferService
@@ -55,7 +56,7 @@ final class ShuffleBlockFetcherIterator(
blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])],
serializer: Serializer,
maxBytesInFlight: Long)
- extends Iterator[(BlockId, Option[Iterator[Any]])] with Logging {
+ extends Iterator[(BlockId, Try[Iterator[Any]])] with Logging {
import ShuffleBlockFetcherIterator._
@@ -91,7 +92,7 @@ final class ShuffleBlockFetcherIterator(
* Current [[FetchResult]] being processed. We track this so we can release the current buffer
* in case of a runtime exception when processing the current buffer.
*/
- private[this] var currentResult: FetchResult = null
+ @volatile private[this] var currentResult: FetchResult = null
/**
* Queue of fetch requests to issue; we'll pull requests off this gradually to make sure that
@@ -118,16 +119,18 @@ final class ShuffleBlockFetcherIterator(
private[this] def cleanup() {
isZombie = true
// Release the current buffer if necessary
- if (currentResult != null && !currentResult.failed) {
- currentResult.buf.release()
+ currentResult match {
+ case SuccessFetchResult(_, _, buf) => buf.release()
+ case _ =>
}
// Release buffers in the results queue
val iter = results.iterator()
while (iter.hasNext) {
val result = iter.next()
- if (!result.failed) {
- result.buf.release()
+ result match {
+ case SuccessFetchResult(_, _, buf) => buf.release()
+ case _ =>
}
}
}
@@ -151,7 +154,7 @@ final class ShuffleBlockFetcherIterator(
// Increment the ref count because we need to pass this to a different thread.
// This needs to be released after use.
buf.retain()
- results.put(new FetchResult(BlockId(blockId), sizeMap(blockId), buf))
+ results.put(new SuccessFetchResult(BlockId(blockId), sizeMap(blockId), buf))
shuffleMetrics.remoteBytesRead += buf.size
shuffleMetrics.remoteBlocksFetched += 1
}
@@ -160,7 +163,7 @@ final class ShuffleBlockFetcherIterator(
override def onBlockFetchFailure(blockId: String, e: Throwable): Unit = {
logError(s"Failed to get block(s) from ${req.address.host}:${req.address.port}", e)
- results.put(new FetchResult(BlockId(blockId), -1, null))
+ results.put(new FailureFetchResult(BlockId(blockId), e))
}
}
)
@@ -231,12 +234,12 @@ final class ShuffleBlockFetcherIterator(
val buf = blockManager.getBlockData(blockId)
shuffleMetrics.localBlocksFetched += 1
buf.retain()
- results.put(new FetchResult(blockId, 0, buf))
+ results.put(new SuccessFetchResult(blockId, 0, buf))
} catch {
case e: Exception =>
// If we see an exception, stop immediately.
logError(s"Error occurred while fetching local blocks", e)
- results.put(new FetchResult(blockId, -1, null))
+ results.put(new FailureFetchResult(blockId, e))
return
}
}
@@ -267,15 +270,17 @@ final class ShuffleBlockFetcherIterator(
override def hasNext: Boolean = numBlocksProcessed < numBlocksToFetch
- override def next(): (BlockId, Option[Iterator[Any]]) = {
+ override def next(): (BlockId, Try[Iterator[Any]]) = {
numBlocksProcessed += 1
val startFetchWait = System.currentTimeMillis()
currentResult = results.take()
val result = currentResult
val stopFetchWait = System.currentTimeMillis()
shuffleMetrics.fetchWaitTime += (stopFetchWait - startFetchWait)
- if (!result.failed) {
- bytesInFlight -= result.size
+
+ result match {
+ case SuccessFetchResult(_, size, _) => bytesInFlight -= size
+ case _ =>
}
// Send fetch requests up to maxBytesInFlight
while (fetchRequests.nonEmpty &&
@@ -283,20 +288,21 @@ final class ShuffleBlockFetcherIterator(
sendRequest(fetchRequests.dequeue())
}
- val iteratorOpt: Option[Iterator[Any]] = if (result.failed) {
- None
- } else {
- val is = blockManager.wrapForCompression(result.blockId, result.buf.createInputStream())
- val iter = serializer.newInstance().deserializeStream(is).asIterator
- Some(CompletionIterator[Any, Iterator[Any]](iter, {
- // Once the iterator is exhausted, release the buffer and set currentResult to null
- // so we don't release it again in cleanup.
- currentResult = null
- result.buf.release()
- }))
+ val iteratorTry: Try[Iterator[Any]] = result match {
+ case FailureFetchResult(_, e) => Failure(e)
+ case SuccessFetchResult(blockId, _, buf) => {
+ val is = blockManager.wrapForCompression(blockId, buf.createInputStream())
+ val iter = serializer.newInstance().deserializeStream(is).asIterator
+ Success(CompletionIterator[Any, Iterator[Any]](iter, {
+ // Once the iterator is exhausted, release the buffer and set currentResult to null
+ // so we don't release it again in cleanup.
+ currentResult = null
+ buf.release()
+ }))
+ }
}
- (result.blockId, iteratorOpt)
+ (result.blockId, iteratorTry)
}
}
@@ -315,14 +321,30 @@ object ShuffleBlockFetcherIterator {
}
/**
- * Result of a fetch from a remote block. A failure is represented as size == -1.
+ * Result of a fetch from a remote block.
+ */
+ private[storage] sealed trait FetchResult {
+ val blockId: BlockId
+ }
+
+ /**
+ * Result of a successful fetch from a remote block.
* @param blockId block id
* @param size estimated size of the block, used to calculate bytesInFlight.
- * Note that this is NOT the exact bytes. -1 if failure is present.
- * @param buf [[ManagedBuffer]] for the content. null is error.
+ * Note that this is NOT the exact bytes.
+ * @param buf [[ManagedBuffer]] for the content.
*/
- case class FetchResult(blockId: BlockId, size: Long, buf: ManagedBuffer) {
- def failed: Boolean = size == -1
- if (failed) assert(buf == null) else assert(buf != null)
+ private[storage] case class SuccessFetchResult(blockId: BlockId, size: Long, buf: ManagedBuffer)
+ extends FetchResult {
+ require(buf != null)
+ require(size >= 0)
}
+
+ /**
+ * Result of a failed fetch from a remote block.
+ * @param blockId block id
+ * @param e the failure exception
+ */
+ private[storage] case class FailureFetchResult(blockId: BlockId, e: Throwable)
+ extends FetchResult
}
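
Replacing the size == -1 / buf == null encoding with a sealed trait makes invalid field
combinations unrepresentable and lets the compiler check match exhaustiveness. The same
pattern in miniature, independent of Spark:

    // A sealed trait per outcome: each case carries only the fields valid for it.
    sealed trait Result { def id: String }
    case class Ok(id: String, size: Long) extends Result { require(size >= 0) }
    case class Err(id: String, error: Throwable) extends Result

    // No sentinel values: the compiler warns if a case is missed.
    def bytesInFlightDelta(r: Result): Long = r match {
      case Ok(_, size) => size
      case Err(_, _) => 0L
    }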
diff --git a/core/src/main/scala/org/apache/spark/ui/ToolTips.scala b/core/src/main/scala/org/apache/spark/ui/ToolTips.scala
index f02904df31fcf..51dc08f668a43 100644
--- a/core/src/main/scala/org/apache/spark/ui/ToolTips.scala
+++ b/core/src/main/scala/org/apache/spark/ui/ToolTips.scala
@@ -24,6 +24,9 @@ private[spark] object ToolTips {
scheduler delay is large, consider decreasing the size of tasks or decreasing the size
of task results."""
+ val TASK_DESERIALIZATION_TIME =
+ """Time spent deserializating the task closure on the executor."""
+
val INPUT = "Bytes read from Hadoop or from Spark storage."
val SHUFFLE_WRITE = "Bytes written to disk in order to be read by a shuffle in a future stage."
diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala
new file mode 100644
index 0000000000000..e9c755e36f716
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala
@@ -0,0 +1,73 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ui.exec
+
+import javax.servlet.http.HttpServletRequest
+
+import scala.util.Try
+import scala.xml.{Text, Node}
+
+import org.apache.spark.ui.{UIUtils, WebUIPage}
+
+private[ui] class ExecutorThreadDumpPage(parent: ExecutorsTab) extends WebUIPage("threadDump") {
+
+ private val sc = parent.sc
+
+ def render(request: HttpServletRequest): Seq[Node] = {
+ val executorId = Option(request.getParameter("executorId")).getOrElse {
+ return Text(s"Missing executorId parameter")
+ }
+ val time = System.currentTimeMillis()
+ val maybeThreadDump = sc.get.getExecutorThreadDump(executorId)
+
+    val content = maybeThreadDump.map { threadDump =>
+      val dumpRows = threadDump.map { thread =>
+        <div class="accordion-group">
+          <div class="accordion-heading">
+            Thread {thread.threadId}: {thread.threadName} ({thread.threadState})
+          </div>
+          <div class="accordion-body"><pre>{thread.stackTrace}</pre></div>
+        </div>
+      }
+      <div class="row-fluid"><div class="accordion">{dumpRows}</div></div>
+    }.getOrElse(Text("Error fetching thread dump"))
+    UIUtils.headerSparkPage(s"Thread dump for executor $executorId", content, parent)
+  }
+}
diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala
index 9e0e71a51a408..ba97630f025c1 100644
--- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala
+++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala
@@ -27,8 +27,14 @@ import org.apache.spark.ui.{SparkUI, SparkUITab}
private[ui] class ExecutorsTab(parent: SparkUI) extends SparkUITab(parent, "executors") {
val listener = parent.executorsListener
+ val sc = parent.sc
+ val threadDumpEnabled =
+ sc.isDefined && parent.conf.getBoolean("spark.ui.threadDumpsEnabled", true)
- attachPage(new ExecutorsPage(this))
+ attachPage(new ExecutorsPage(this, threadDumpEnabled))
+ if (threadDumpEnabled) {
+ attachPage(new ExecutorThreadDumpPage(this))
+ }
}
/**
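
The thread-dump pages are attached only when the tab runs inside the driver (sc.isDefined)
and the feature is left enabled. A one-line sketch of opting out:

    import org.apache.spark.SparkConf

    // Disable the executor thread-dump UI (on by default when a SparkContext is available).
    val conf = new SparkConf().set("spark.ui.threadDumpsEnabled", "false")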
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala
index b5207360510dd..e3223403c17f4 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala
@@ -59,6 +59,13 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging {
val failedStages = ListBuffer[StageInfo]()
val stageIdToData = new HashMap[(StageId, StageAttemptId), StageUIData]
val stageIdToInfo = new HashMap[StageId, StageInfo]
+
+  // Total number of completed and failed stages. These may exceed completedStages.size
+  // and failedStages.size, because those buffers retain only the most recent stages;
+  // older entries are trimmed to bound memory usage.
+ var numCompletedStages = 0
+ var numFailedStages = 0
// Map from pool name to a hash map (map from stage id to StageInfo).
val poolToActiveStages = HashMap[String, HashMap[Int, StageInfo]]()
@@ -110,9 +117,11 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging {
activeStages.remove(stage.stageId)
if (stage.failureReason.isEmpty) {
completedStages += stage
+ numCompletedStages += 1
trimIfNecessary(completedStages)
} else {
failedStages += stage
+ numFailedStages += 1
trimIfNecessary(failedStages)
}
}
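
Since trimIfNecessary drops the oldest entries to bound memory, the buffers alone would
undercount; the new counters give exact totals. The pattern in isolation, with a
hypothetical retention limit standing in for the UI's configured one:

    import scala.collection.mutable.ListBuffer

    val retainedStages = 1000  // hypothetical stand-in for the retention limit
    val completedStages = ListBuffer[Int]()
    var numCompletedStages = 0

    def onStageCompleted(stageId: Int): Unit = {
      completedStages += stageId
      numCompletedStages += 1  // the counter survives trimming; the buffer does not
      if (completedStages.size > retainedStages) {
        completedStages.trimStart(completedStages.size - retainedStages)
      }
    }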
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressPage.scala
index 6e718eecdd52a..83a7898071c9b 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressPage.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressPage.scala
@@ -34,7 +34,9 @@ private[ui] class JobProgressPage(parent: JobProgressTab) extends WebUIPage("")
listener.synchronized {
val activeStages = listener.activeStages.values.toSeq
val completedStages = listener.completedStages.reverse.toSeq
+ val numCompletedStages = listener.numCompletedStages
val failedStages = listener.failedStages.reverse.toSeq
+ val numFailedStages = listener.numFailedStages
val now = System.currentTimeMillis
val activeStagesTable =
@@ -69,11 +71,11 @@ private[ui] class JobProgressPage(parent: JobProgressTab) extends WebUIPage("")
             <li>
               <a href="#completed"><strong>Completed Stages:</strong></a>
-              {completedStages.size}
+              {numCompletedStages}
             </li>
             <li>
               <a href="#failed"><strong>Failed Stages:</strong></a>
-              {failedStages.size}
+              {numFailedStages}
             </li>
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/TaskDetailsClassNames.scala b/core/src/main/scala/org/apache/spark/ui/jobs/TaskDetailsClassNames.scala
index 23d672cabda07..eb371bd0ea7ed 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/TaskDetailsClassNames.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/TaskDetailsClassNames.scala
@@ -24,6 +24,7 @@ package org.apache.spark.ui.jobs
private object TaskDetailsClassNames {
val SCHEDULER_DELAY = "scheduler_delay"
val GC_TIME = "gc_time"
+ val TASK_DESERIALIZATION_TIME = "deserialization_time"
val RESULT_SERIALIZATION_TIME = "serialization_time"
val GETTING_RESULT_TIME = "getting_result_time"
}
diff --git a/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala b/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala
index 79e398eb8c104..10010bdfa1a51 100644
--- a/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala
+++ b/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala
@@ -212,4 +212,18 @@ private[spark] object AkkaUtils extends Logging {
logInfo(s"Connecting to $name: $url")
Await.result(actorSystem.actorSelection(url).resolveOne(timeout), timeout)
}
+
+ def makeExecutorRef(
+ name: String,
+ conf: SparkConf,
+ host: String,
+ port: Int,
+ actorSystem: ActorSystem): ActorRef = {
+ val executorActorSystemName = SparkEnv.executorActorSystemName
+ Utils.checkHost(host, "Expected hostname")
+ val url = s"akka.tcp://$executorActorSystemName@$host:$port/user/$name"
+ val timeout = AkkaUtils.lookupTimeout(conf)
+ logInfo(s"Connecting to $name: $url")
+ Await.result(actorSystem.actorSelection(url).resolveOne(timeout), timeout)
+ }
}
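
Combined with getActorSystemHostPortForExecutor above, this gives the driver everything it
needs to reach an executor's ExecutorActor. A hedged sketch of the lookup-then-resolve
flow, assuming BlockManager exposes its BlockManagerMaster as `master`:

    import akka.actor.ActorRef

    import org.apache.spark.{SparkConf, SparkEnv}
    import org.apache.spark.util.AkkaUtils

    // Resolve an executor's ExecutorActor from the driver: ask the BlockManagerMaster
    // for the executor's actor-system host and port, then build the Akka ref.
    def executorActorRef(env: SparkEnv, conf: SparkConf, executorId: String): Option[ActorRef] = {
      env.blockManager.master.getActorSystemHostPortForExecutor(executorId).map {
        case (host, port) =>
          AkkaUtils.makeExecutorRef("ExecutorActor", conf, host, port, env.actorSystem)
      }
    }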
diff --git a/core/src/main/scala/org/apache/spark/util/CompletionIterator.scala b/core/src/main/scala/org/apache/spark/util/CompletionIterator.scala
index b6a099825f01b..390310243ee0a 100644
--- a/core/src/main/scala/org/apache/spark/util/CompletionIterator.scala
+++ b/core/src/main/scala/org/apache/spark/util/CompletionIterator.scala
@@ -25,10 +25,13 @@ private[spark]
// scalastyle:off
abstract class CompletionIterator[ +A, +I <: Iterator[A]](sub: I) extends Iterator[A] {
// scalastyle:on
+
+ private[this] var completed = false
def next() = sub.next()
def hasNext = {
val r = sub.hasNext
- if (!r) {
+ if (!r && !completed) {
+ completed = true
completion()
}
r
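
Without the completed flag, every hasNext call on an exhausted iterator re-ran
completion(); with the buffer-releasing callbacks added in ShuffleBlockFetcherIterator
above, that meant double release() calls. A sketch of the behavior the flag guarantees:

    import org.apache.spark.util.CompletionIterator

    var completions = 0
    val it = CompletionIterator[Int, Iterator[Int]](Iterator(1), { completions += 1 })
    it.next()
    it.hasNext  // exhausted: completion() runs once, completions == 1
    it.hasNext  // previously ran completion() a second time; now a no-op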
diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala
index 43c7fba06694a..f15d0c856663f 100644
--- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala
+++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala
@@ -279,13 +279,15 @@ private[spark] object JsonProtocol {
("Block Manager Address" -> blockManagerAddress) ~
("Shuffle ID" -> fetchFailed.shuffleId) ~
("Map ID" -> fetchFailed.mapId) ~
- ("Reduce ID" -> fetchFailed.reduceId)
+ ("Reduce ID" -> fetchFailed.reduceId) ~
+ ("Message" -> fetchFailed.message)
case exceptionFailure: ExceptionFailure =>
val stackTrace = stackTraceToJson(exceptionFailure.stackTrace)
val metrics = exceptionFailure.metrics.map(taskMetricsToJson).getOrElse(JNothing)
("Class Name" -> exceptionFailure.className) ~
("Description" -> exceptionFailure.description) ~
("Stack Trace" -> stackTrace) ~
+ ("Full Stack Trace" -> exceptionFailure.fullStackTrace) ~
("Metrics" -> metrics)
case ExecutorLostFailure(executorId) =>
("Executor ID" -> executorId)
@@ -629,13 +631,17 @@ private[spark] object JsonProtocol {
val shuffleId = (json \ "Shuffle ID").extract[Int]
val mapId = (json \ "Map ID").extract[Int]
val reduceId = (json \ "Reduce ID").extract[Int]
- new FetchFailed(blockManagerAddress, shuffleId, mapId, reduceId)
+ val message = Utils.jsonOption(json \ "Message").map(_.extract[String])
+ new FetchFailed(blockManagerAddress, shuffleId, mapId, reduceId,
+ message.getOrElse("Unknown reason"))
case `exceptionFailure` =>
val className = (json \ "Class Name").extract[String]
val description = (json \ "Description").extract[String]
val stackTrace = stackTraceFromJson(json \ "Stack Trace")
+ val fullStackTrace = Utils.jsonOption(json \ "Full Stack Trace").
+ map(_.extract[String]).orNull
val metrics = Utils.jsonOption(json \ "Metrics").map(taskMetricsFromJson)
- new ExceptionFailure(className, description, stackTrace, metrics)
+ ExceptionFailure(className, description, stackTrace, fullStackTrace, metrics)
case `taskResultLost` => TaskResultLost
case `taskKilled` => TaskKilled
case `executorLostFailure` =>
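
Reading the new fields through Utils.jsonOption keeps old event logs parseable: a missing
"Message" degrades to "Unknown reason" and a missing "Full Stack Trace" to null instead of
raising an extraction error. A sketch of the fallback, assuming jsonOption simply maps
JNothing to None:

    import org.json4s._
    import org.json4s.jackson.JsonMethods.parse

    implicit val formats = DefaultFormats

    // Assumed shape of Utils.jsonOption: fields absent from the JSON become None.
    def jsonOption(json: JValue): Option[JValue] = json match {
      case JNothing => None
      case value => Some(value)
    }

    val oldEvent = parse("""{"Reduce ID": 2}""")  // written before this patch
    val message = jsonOption(oldEvent \ "Message").map(_.extract[String])
      .getOrElse("Unknown reason")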
diff --git a/core/src/main/scala/org/apache/spark/util/ThreadStackTrace.scala b/core/src/main/scala/org/apache/spark/util/ThreadStackTrace.scala
new file mode 100644
index 0000000000000..d4e0ad93b966a
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/util/ThreadStackTrace.scala
@@ -0,0 +1,27 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util
+
+/**
+ * Used for shipping per-thread stacktraces from the executors to the driver.
+ */
+private[spark] case class ThreadStackTrace(
+ threadId: Long,
+ threadName: String,
+ threadState: Thread.State,
+ stackTrace: String)
diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala
index b402c5f334bb0..a14d6125484fe 100644
--- a/core/src/main/scala/org/apache/spark/util/Utils.scala
+++ b/core/src/main/scala/org/apache/spark/util/Utils.scala
@@ -18,6 +18,7 @@
package org.apache.spark.util
import java.io._
+import java.lang.management.ManagementFactory
import java.net._
import java.nio.ByteBuffer
import java.util.jar.Attributes.Name
@@ -44,6 +45,7 @@ import org.json4s._
import tachyon.client.{TachyonFile,TachyonFS}
import org.apache.spark._
+import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.serializer.{DeserializationStream, SerializationStream, SerializerInstance}
/** CallSite represents a place in user code. It can have a short and a long form. */
@@ -753,6 +755,7 @@ private[spark] object Utils extends Logging {
/**
* Delete a file or directory and its contents recursively.
* Don't follow directories if they are symlinks.
+ * Throws an exception if deletion is unsuccessful.
*/
def deleteRecursively(file: File) {
if (file != null) {
@@ -1596,19 +1599,31 @@ private[spark] object Utils extends Logging {
.orNull
}
- /** Return a nice string representation of the exception, including the stack trace. */
- def exceptionString(e: Exception): String = {
- if (e == null) "" else exceptionString(getFormattedClassName(e), e.getMessage, e.getStackTrace)
+ /**
+ * Return a nice string representation of the exception. It will call "printStackTrace" to
+ * recursively generate the stack trace including the exception and its causes.
+ */
+ def exceptionString(e: Throwable): String = {
+ if (e == null) {
+ ""
+ } else {
+ // Use e.printStackTrace here because e.getStackTrace doesn't include the cause
+ val stringWriter = new StringWriter()
+ e.printStackTrace(new PrintWriter(stringWriter))
+ stringWriter.toString
+ }
}
- /** Return a nice string representation of the exception, including the stack trace. */
- def exceptionString(
- className: String,
- description: String,
- stackTrace: Array[StackTraceElement]): String = {
- val desc = if (description == null) "" else description
- val st = if (stackTrace == null) "" else stackTrace.map(" " + _).mkString("\n")
- s"$className: $desc\n$st"
+ /** Return a thread dump of all threads' stacktraces. Used to capture dumps for the web UI */
+ def getThreadDump(): Array[ThreadStackTrace] = {
+ // We need to filter out null values here because dumpAllThreads() may return null array
+ // elements for threads that are dead / don't exist.
+ val threadInfos = ManagementFactory.getThreadMXBean.dumpAllThreads(true, true).filter(_ != null)
+ threadInfos.sortBy(_.getThreadId).map { case threadInfo =>
+ val stackTrace = threadInfo.getStackTrace.map(_.toString).mkString("\n")
+ ThreadStackTrace(threadInfo.getThreadId, threadInfo.getThreadName,
+ threadInfo.getThreadState, stackTrace)
+ }
}
/**
@@ -1767,6 +1782,21 @@ private[spark] object Utils extends Logging {
val manifest = new JarManifest(manifestUrl.openStream())
manifest.getMainAttributes.getValue(Name.IMPLEMENTATION_VERSION)
}.getOrElse("Unknown")
+
+ /**
+ * Return the value of a config either through the SparkConf or the Hadoop configuration
+ * if this is Yarn mode. In the latter case, this defaults to the value set through SparkConf
+ * if the key is not set in the Hadoop configuration.
+ */
+ def getSparkOrYarnConfig(conf: SparkConf, key: String, default: String): String = {
+ val sparkValue = conf.get(key, default)
+ if (SparkHadoopUtil.get.isYarnMode) {
+ SparkHadoopUtil.get.newConfiguration(conf).get(key, sparkValue)
+ } else {
+ sparkValue
+ }
+ }
+
}
/**
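
The switch to printStackTrace matters because Throwable.getStackTrace returns only the
top-level frames; the "Caused by" chain is reachable only through getCause, which
printStackTrace already walks. A sketch of the difference:

    import java.io.{PrintWriter, StringWriter}

    val cause = new IllegalStateException("root cause")
    val wrapper = new RuntimeException("wrapper", cause)

    // Top-level frames only; "root cause" never appears in this string.
    val flat = wrapper.getStackTrace.map("  " + _).mkString("\n")

    // Full chain, including "Caused by: java.lang.IllegalStateException: root cause".
    val writer = new StringWriter()
    wrapper.printStackTrace(new PrintWriter(writer))
    val full = writer.toString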
diff --git a/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala b/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala
index 792b9cd8b6ff2..6608ed1e57b38 100644
--- a/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala
+++ b/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala
@@ -63,8 +63,9 @@ class ExternalShuffleServiceSuite extends ShuffleSuite with BeforeAndAfterAll {
rdd.count()
rdd.count()
- // Invalidate the registered executors, disallowing access to their shuffle blocks.
- rpcHandler.clearRegisteredExecutors()
+ // Invalidate the registered executors, disallowing access to their shuffle blocks (without
+ // deleting the actual shuffle files, so they can still be accessed without the shuffle service).
+ rpcHandler.applicationRemoved(sc.conf.getAppId, false /* cleanupLocalDirs */)
// Now Spark will receive FetchFailed, and not retry the stage due to "spark.test.noStageRetry"
// being set.
diff --git a/core/src/test/scala/org/apache/spark/metrics/InputMetricsSuite.scala b/core/src/test/scala/org/apache/spark/metrics/InputMetricsSuite.scala
index 33bd1afea2470..48c386ba04311 100644
--- a/core/src/test/scala/org/apache/spark/metrics/InputMetricsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/metrics/InputMetricsSuite.scala
@@ -27,7 +27,7 @@ import scala.collection.mutable.ArrayBuffer
import java.io.{FileWriter, PrintWriter, File}
class InputMetricsSuite extends FunSuite with SharedSparkContext {
- test("input metrics when reading text file") {
+ test("input metrics when reading text file with single split") {
val file = new File(getClass.getSimpleName + ".txt")
val pw = new PrintWriter(new FileWriter(file))
pw.println("some stuff")
@@ -48,6 +48,29 @@ class InputMetricsSuite extends FunSuite with SharedSparkContext {
// Wait for task end events to come in
sc.listenerBus.waitUntilEmpty(500)
assert(taskBytesRead.length == 2)
- assert(taskBytesRead.sum == file.length())
+ assert(taskBytesRead.sum >= file.length())
+ }
+
+ test("input metrics when reading text file with multiple splits") {
+ val file = new File(getClass.getSimpleName + ".txt")
+ val pw = new PrintWriter(new FileWriter(file))
+ for (i <- 0 until 10000) {
+ pw.println("some stuff")
+ }
+ pw.close()
+ file.deleteOnExit()
+
+ val taskBytesRead = new ArrayBuffer[Long]()
+ sc.addSparkListener(new SparkListener() {
+ override def onTaskEnd(taskEnd: SparkListenerTaskEnd) {
+ taskBytesRead += taskEnd.taskMetrics.inputMetrics.get.bytesRead
+ }
+ })
+ sc.textFile("file://" + file.getAbsolutePath, 2).count()
+
+ // Wait for task end events to come in
+ sc.listenerBus.waitUntilEmpty(500)
+ assert(taskBytesRead.length == 2)
+ assert(taskBytesRead.sum >= file.length())
}
}
diff --git a/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala b/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala
new file mode 100644
index 0000000000000..9162ec9801663
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala
@@ -0,0 +1,149 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.netty
+
+import java.nio._
+import java.util.concurrent.TimeUnit
+
+import scala.concurrent.duration._
+import scala.concurrent.{Await, Promise}
+import scala.util.{Failure, Success, Try}
+
+import org.apache.commons.io.IOUtils
+import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer}
+import org.apache.spark.network.shuffle.BlockFetchingListener
+import org.apache.spark.network.{BlockDataManager, BlockTransferService}
+import org.apache.spark.storage.{BlockId, ShuffleBlockId}
+import org.apache.spark.{SecurityManager, SparkConf}
+import org.mockito.Mockito._
+import org.scalatest.mock.MockitoSugar
+import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach, FunSuite, ShouldMatchers}
+
+class NettyBlockTransferSecuritySuite extends FunSuite with MockitoSugar with ShouldMatchers {
+ test("security default off") {
+ testConnection(new SparkConf, new SparkConf) match {
+ case Success(_) => // expected
+ case Failure(t) => fail(t)
+ }
+ }
+
+ test("security on same password") {
+ val conf = new SparkConf()
+ .set("spark.authenticate", "true")
+ .set("spark.authenticate.secret", "good")
+ .set("spark.app.id", "app-id")
+ testConnection(conf, conf) match {
+ case Success(_) => // expected
+ case Failure(t) => fail(t)
+ }
+ }
+
+ test("security on mismatch password") {
+ val conf0 = new SparkConf()
+ .set("spark.authenticate", "true")
+ .set("spark.authenticate.secret", "good")
+ .set("spark.app.id", "app-id")
+ val conf1 = conf0.clone.set("spark.authenticate.secret", "bad")
+ testConnection(conf0, conf1) match {
+ case Success(_) => fail("Should have failed")
+ case Failure(t) => t.getMessage should include ("Mismatched response")
+ }
+ }
+
+ test("security mismatch auth off on server") {
+ val conf0 = new SparkConf()
+ .set("spark.authenticate", "true")
+ .set("spark.authenticate.secret", "good")
+ .set("spark.app.id", "app-id")
+ val conf1 = conf0.clone.set("spark.authenticate", "false")
+ testConnection(conf0, conf1) match {
+ case Success(_) => fail("Should have failed")
+ case Failure(t) => // any error may occur here, since the server will interpret the SASL token as an RPC
+ }
+ }
+
+ test("security mismatch auth off on client") {
+ val conf0 = new SparkConf()
+ .set("spark.authenticate", "false")
+ .set("spark.authenticate.secret", "good")
+ .set("spark.app.id", "app-id")
+ val conf1 = conf0.clone.set("spark.authenticate", "true")
+ testConnection(conf0, conf1) match {
+ case Success(_) => fail("Should have failed")
+ case Failure(t) => t.getMessage should include ("Expected SaslMessage")
+ }
+ }
+
+ /**
+ * Creates two servers with different configurations and checks whether they can talk.
+ * Returns Success() if they can transfer a block, and Failure() if the block transfer fails
+ * gracefully. Anything else that goes wrong is thrown as an out-of-band exception.
+ */
+ private def testConnection(conf0: SparkConf, conf1: SparkConf): Try[Unit] = {
+ val blockManager = mock[BlockDataManager]
+ val blockId = ShuffleBlockId(0, 1, 2)
+ val blockString = "Hello, world!"
+ val blockBuffer = new NioManagedBuffer(ByteBuffer.wrap(blockString.getBytes))
+ when(blockManager.getBlockData(blockId)).thenReturn(blockBuffer)
+
+ val securityManager0 = new SecurityManager(conf0)
+ val exec0 = new NettyBlockTransferService(conf0, securityManager0)
+ exec0.init(blockManager)
+
+ val securityManager1 = new SecurityManager(conf1)
+ val exec1 = new NettyBlockTransferService(conf1, securityManager1)
+ exec1.init(blockManager)
+
+ val result = fetchBlock(exec0, exec1, "1", blockId) match {
+ case Success(buf) =>
+ IOUtils.toString(buf.createInputStream()) should equal(blockString)
+ buf.release()
+ Success()
+ case Failure(t) =>
+ Failure(t)
+ }
+ exec0.close()
+ exec1.close()
+ result
+ }
+
+ /** Synchronously fetches a single block, acting as the given executor fetching from another. */
+ private def fetchBlock(
+ self: BlockTransferService,
+ from: BlockTransferService,
+ execId: String,
+ blockId: BlockId): Try[ManagedBuffer] = {
+
+ val promise = Promise[ManagedBuffer]()
+
+ self.fetchBlocks(from.hostName, from.port, execId, Array(blockId.toString),
+ new BlockFetchingListener {
+ override def onBlockFetchFailure(blockId: String, exception: Throwable): Unit = {
+ promise.failure(exception)
+ }
+
+ override def onBlockFetchSuccess(blockId: String, data: ManagedBuffer): Unit = {
+ promise.success(data.retain())
+ }
+ })
+
+ Await.ready(promise.future, FiniteDuration(1000, TimeUnit.MILLISECONDS))
+ promise.future.value.get
+ }
+}
+
diff --git a/core/src/test/scala/org/apache/spark/network/nio/ConnectionManagerSuite.scala b/core/src/test/scala/org/apache/spark/network/nio/ConnectionManagerSuite.scala
index b70734dfe37cf..716f875d30b8a 100644
--- a/core/src/test/scala/org/apache/spark/network/nio/ConnectionManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/network/nio/ConnectionManagerSuite.scala
@@ -60,6 +60,7 @@ class ConnectionManagerSuite extends FunSuite {
val conf = new SparkConf
conf.set("spark.authenticate", "true")
conf.set("spark.authenticate.secret", "good")
+ conf.set("spark.app.id", "app-id")
val securityManager = new SecurityManager(conf)
val manager = new ConnectionManager(0, conf, securityManager)
var numReceivedMessages = 0
@@ -95,6 +96,7 @@ class ConnectionManagerSuite extends FunSuite {
test("security mismatch password") {
val conf = new SparkConf
conf.set("spark.authenticate", "true")
+ conf.set("spark.app.id", "app-id")
conf.set("spark.authenticate.secret", "good")
val securityManager = new SecurityManager(conf)
val manager = new ConnectionManager(0, conf, securityManager)
@@ -105,9 +107,7 @@ class ConnectionManagerSuite extends FunSuite {
None
})
- val badconf = new SparkConf
- badconf.set("spark.authenticate", "true")
- badconf.set("spark.authenticate.secret", "bad")
+ val badconf = conf.clone.set("spark.authenticate.secret", "bad")
val badsecurityManager = new SecurityManager(badconf)
val managerServer = new ConnectionManager(0, badconf, badsecurityManager)
var numReceivedServerMessages = 0
diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
index a2e4f712db55b..819f95634bcdc 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
@@ -431,7 +431,7 @@ class DAGSchedulerSuite extends TestKit(ActorSystem("DAGSchedulerSuite")) with F
// the 2nd ResultTask failed
complete(taskSets(1), Seq(
(Success, 42),
- (FetchFailed(makeBlockManagerId("hostA"), shuffleId, 0, 0), null)))
+ (FetchFailed(makeBlockManagerId("hostA"), shuffleId, 0, 0, "ignored"), null)))
// this will get called
// blockManagerMaster.removeExecutor("exec-hostA")
// ask the scheduler to try it again
@@ -461,7 +461,7 @@ class DAGSchedulerSuite extends TestKit(ActorSystem("DAGSchedulerSuite")) with F
// The first result task fails, with a fetch failure for the output from the first mapper.
runEvent(CompletionEvent(
taskSets(1).tasks(0),
- FetchFailed(makeBlockManagerId("hostA"), shuffleId, 0, 0),
+ FetchFailed(makeBlockManagerId("hostA"), shuffleId, 0, 0, "ignored"),
null,
Map[Long, Any](),
null,
@@ -472,7 +472,7 @@ class DAGSchedulerSuite extends TestKit(ActorSystem("DAGSchedulerSuite")) with F
// The second ResultTask fails, with a fetch failure for the output from the second mapper.
runEvent(CompletionEvent(
taskSets(1).tasks(0),
- FetchFailed(makeBlockManagerId("hostA"), shuffleId, 1, 1),
+ FetchFailed(makeBlockManagerId("hostA"), shuffleId, 1, 1, "ignored"),
null,
Map[Long, Any](),
null,
@@ -624,7 +624,7 @@ class DAGSchedulerSuite extends TestKit(ActorSystem("DAGSchedulerSuite")) with F
(Success, makeMapStatus("hostC", 1))))
// fail the third stage because hostA went down
complete(taskSets(2), Seq(
- (FetchFailed(makeBlockManagerId("hostA"), shuffleDepTwo.shuffleId, 0, 0), null)))
+ (FetchFailed(makeBlockManagerId("hostA"), shuffleDepTwo.shuffleId, 0, 0, "ignored"), null)))
// TODO assert this:
// blockManagerMaster.removeExecutor("exec-hostA")
// have DAGScheduler try again
@@ -655,7 +655,7 @@ class DAGSchedulerSuite extends TestKit(ActorSystem("DAGSchedulerSuite")) with F
(Success, makeMapStatus("hostB", 1))))
// pretend stage 0 failed because hostA went down
complete(taskSets(2), Seq(
- (FetchFailed(makeBlockManagerId("hostA"), shuffleDepTwo.shuffleId, 0, 0), null)))
+ (FetchFailed(makeBlockManagerId("hostA"), shuffleDepTwo.shuffleId, 0, 0, "ignored"), null)))
// TODO assert this:
// blockManagerMaster.removeExecutor("exec-hostA")
// DAGScheduler should notice the cached copy of the second shuffle and try to get it rerun.
diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala
index c6d7105592096..f63e772bf1e59 100644
--- a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala
@@ -62,7 +62,8 @@ class BlockManagerReplicationSuite extends FunSuite with Matchers with BeforeAnd
name: String = SparkContext.DRIVER_IDENTIFIER): BlockManager = {
val transfer = new NioBlockTransferService(conf, securityMgr)
val store = new BlockManager(name, actorSystem, master, serializer, maxMem, conf,
- mapOutputTracker, shuffleManager, transfer)
+ mapOutputTracker, shuffleManager, transfer, securityMgr)
+ store.initialize("app-id")
allStores += store
store
}
@@ -262,7 +263,8 @@ class BlockManagerReplicationSuite extends FunSuite with Matchers with BeforeAnd
when(failableTransfer.hostName).thenReturn("some-hostname")
when(failableTransfer.port).thenReturn(1000)
val failableStore = new BlockManager("failable-store", actorSystem, master, serializer,
- 10000, conf, mapOutputTracker, shuffleManager, failableTransfer)
+ 10000, conf, mapOutputTracker, shuffleManager, failableTransfer, securityMgr)
+ failableStore.initialize("app-id")
allStores += failableStore // so that this gets stopped after test
assert(master.getPeers(store.blockManagerId).toSet === Set(failableStore.blockManagerId))
diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
index 715b740b857b2..9529502bc8e10 100644
--- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
@@ -73,8 +73,10 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter
maxMem: Long,
name: String = SparkContext.DRIVER_IDENTIFIER): BlockManager = {
val transfer = new NioBlockTransferService(conf, securityMgr)
- new BlockManager(name, actorSystem, master, serializer, maxMem, conf,
- mapOutputTracker, shuffleManager, transfer)
+ val manager = new BlockManager(name, actorSystem, master, serializer, maxMem, conf,
+ mapOutputTracker, shuffleManager, transfer, securityMgr)
+ manager.initialize("app-id")
+ manager
}
before {
@@ -793,7 +795,7 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter
// Use Java serializer so we can create an unserializable error.
val transfer = new NioBlockTransferService(conf, securityMgr)
store = new BlockManager(SparkContext.DRIVER_IDENTIFIER, actorSystem, master,
- new JavaSerializer(conf), 1200, conf, mapOutputTracker, shuffleManager, transfer)
+ new JavaSerializer(conf), 1200, conf, mapOutputTracker, shuffleManager, transfer, securityMgr)
// The put should fail since a1 is not serializable.
class UnserializableClass
diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
index 28f766570e96f..1eaabb93adbed 100644
--- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
@@ -102,7 +102,7 @@ class ShuffleBlockFetcherIteratorSuite extends FunSuite {
for (i <- 0 until 5) {
assert(iterator.hasNext, s"iterator should have 5 elements but actually has $i elements")
val (blockId, subIterator) = iterator.next()
- assert(subIterator.isDefined,
+ assert(subIterator.isSuccess,
s"iterator should have 5 elements defined but actually has $i elements")
// Make sure we release the buffer once the iterator is exhausted.
@@ -230,8 +230,8 @@ class ShuffleBlockFetcherIteratorSuite extends FunSuite {
sem.acquire()
// The first block should be defined, and the last two are not defined (due to failure)
- assert(iterator.next()._2.isDefined === true)
- assert(iterator.next()._2.isDefined === false)
- assert(iterator.next()._2.isDefined === false)
+ assert(iterator.next()._2.isSuccess)
+ assert(iterator.next()._2.isFailure)
+ assert(iterator.next()._2.isFailure)
}
}
diff --git a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala
index 6567c5ab836e7..2608ad4b32e1e 100644
--- a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala
@@ -115,8 +115,8 @@ class JobProgressListenerSuite extends FunSuite with LocalSparkContext with Matc
// Go through all the failure cases to make sure we are counting them as failures.
val taskFailedReasons = Seq(
Resubmitted,
- new FetchFailed(null, 0, 0, 0),
- new ExceptionFailure("Exception", "description", null, None),
+ new FetchFailed(null, 0, 0, 0, "ignored"),
+ ExceptionFailure("Exception", "description", null, null, None),
TaskResultLost,
TaskKilled,
ExecutorLostFailure("0"),
diff --git a/core/src/test/scala/org/apache/spark/util/CompletionIteratorSuite.scala b/core/src/test/scala/org/apache/spark/util/CompletionIteratorSuite.scala
new file mode 100644
index 0000000000000..3755d43e25ea8
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/util/CompletionIteratorSuite.scala
@@ -0,0 +1,47 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util
+
+import org.scalatest.FunSuite
+
+class CompletionIteratorSuite extends FunSuite {
+ test("basic test") {
+ var numTimesCompleted = 0
+ val iter = List(1, 2, 3).iterator
+ val completionIter = CompletionIterator[Int, Iterator[Int]](iter, { numTimesCompleted += 1 })
+
+ assert(completionIter.hasNext)
+ assert(completionIter.next() === 1)
+ assert(numTimesCompleted === 0)
+
+ assert(completionIter.hasNext)
+ assert(completionIter.next() === 2)
+ assert(numTimesCompleted === 0)
+
+ assert(completionIter.hasNext)
+ assert(completionIter.next() === 3)
+ assert(numTimesCompleted === 0)
+
+ assert(!completionIter.hasNext)
+ assert(numTimesCompleted === 1)
+
+ // SPARK-4264: Calling hasNext should not trigger the completion callback again.
+ assert(!completionIter.hasNext)
+ assert(numTimesCompleted === 1)
+ }
+}
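
For readers unfamiliar with the pattern this suite exercises, here is a minimal sketch of a completion-callback iterator. The class name is hypothetical and the body deliberately simplified; Spark's actual `CompletionIterator` differs in detail, but the SPARK-4264 guarantee tested above (the callback fires exactly once, even across repeated `hasNext` calls) looks like this:

```scala
// Hypothetical sketch, not Spark's implementation: wraps an iterator and
// fires a completion callback exactly once when the wrapped iterator is done.
class OnceCompletingIterator[A](sub: Iterator[A], completion: () => Unit) extends Iterator[A] {
  private var completed = false
  override def next(): A = sub.next()
  override def hasNext: Boolean = {
    val more = sub.hasNext
    if (!more && !completed) { // repeated hasNext calls must not re-fire (SPARK-4264)
      completed = true
      completion()
    }
    more
  }
}

val it = new OnceCompletingIterator(List(1, 2, 3).iterator, () => println("done"))
it.foreach(println) // prints 1, 2, 3, then "done" exactly once
it.hasNext          // false, and the callback stays silent
```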
diff --git a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala
index d235d7a0ed839..39e69851e7e3c 100644
--- a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala
@@ -107,8 +107,9 @@ class JsonProtocolSuite extends FunSuite {
testJobResult(jobFailed)
// TaskEndReason
- val fetchFailed = FetchFailed(BlockManagerId("With or", "without you", 15), 17, 18, 19)
- val exceptionFailure = ExceptionFailure("To be", "or not to be", stackTrace, None)
+ val fetchFailed = FetchFailed(BlockManagerId("With or", "without you", 15), 17, 18, 19,
+ "Some exception")
+ val exceptionFailure = new ExceptionFailure(exception, None)
testTaskEndReason(Success)
testTaskEndReason(Resubmitted)
testTaskEndReason(fetchFailed)
@@ -126,6 +127,13 @@ class JsonProtocolSuite extends FunSuite {
testBlockId(StreamBlockId(1, 2L))
}
+ test("ExceptionFailure backward compatibility") {
+ val exceptionFailure = ExceptionFailure("To be", "or not to be", stackTrace, null, None)
+ val oldEvent = JsonProtocol.taskEndReasonToJson(exceptionFailure)
+ .removeField({ _._1 == "Full Stack Trace" })
+ assertEquals(exceptionFailure, JsonProtocol.taskEndReasonFromJson(oldEvent))
+ }
+
test("StageInfo backward compatibility") {
val info = makeStageInfo(1, 2, 3, 4L, 5L)
val newJson = JsonProtocol.stageInfoToJson(info)
@@ -176,6 +184,17 @@ class JsonProtocolSuite extends FunSuite {
deserializedBmRemoved)
}
+ test("FetchFailed backwards compatibility") {
+ // FetchFailed in Spark 1.1.0 does not have a "Message" property.
+ val fetchFailed = FetchFailed(BlockManagerId("With or", "without you", 15), 17, 18, 19,
+ "ignored")
+ val oldEvent = JsonProtocol.taskEndReasonToJson(fetchFailed)
+ .removeField({ _._1 == "Message" })
+ val expectedFetchFailed = FetchFailed(BlockManagerId("With or", "without you", 15), 17, 18, 19,
+ "Unknown reason")
+ assert(expectedFetchFailed === JsonProtocol.taskEndReasonFromJson(oldEvent))
+ }
+
test("SparkListenerApplicationStart backwards compatibility") {
// SparkListenerApplicationStart in Spark 1.0.0 do not have an "appId" property.
val applicationStart = SparkListenerApplicationStart("test", None, 1L, "user")
@@ -184,6 +203,15 @@ class JsonProtocolSuite extends FunSuite {
assert(applicationStart === JsonProtocol.applicationStartFromJson(oldEvent))
}
+ test("ExecutorLostFailure backward compatibility") {
+ // ExecutorLostFailure in Spark 1.1.0 does not have an "Executor ID" property.
+ val executorLostFailure = ExecutorLostFailure("100")
+ val oldEvent = JsonProtocol.taskEndReasonToJson(executorLostFailure)
+ .removeField({ _._1 == "Executor ID" })
+ val expectedExecutorLostFailure = ExecutorLostFailure("Unknown")
+ assert(expectedExecutorLostFailure === JsonProtocol.taskEndReasonFromJson(oldEvent))
+ }
+
/** -------------------------- *
| Helper test running methods |
* --------------------------- */
@@ -396,10 +424,12 @@ class JsonProtocolSuite extends FunSuite {
assert(r1.mapId === r2.mapId)
assert(r1.reduceId === r2.reduceId)
assertEquals(r1.bmAddress, r2.bmAddress)
+ assert(r1.message === r2.message)
case (r1: ExceptionFailure, r2: ExceptionFailure) =>
assert(r1.className === r2.className)
assert(r1.description === r2.description)
assertSeqEquals(r1.stackTrace, r2.stackTrace, assertStackTraceElementEquals)
+ assert(r1.fullStackTrace === r2.fullStackTrace)
assertOptionEquals(r1.metrics, r2.metrics, assertTaskMetricsEquals)
case (TaskResultLost, TaskResultLost) =>
case (TaskKilled, TaskKilled) =>
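
All of the backward-compatibility tests above follow the same recipe: serialize a new-style event, strip the newly added JSON field to simulate output written by an older Spark, and assert that deserialization falls back to a default. A minimal sketch of that recipe with json4s (the library these tests use); the event shape here is illustrative, with the field and default mirroring the FetchFailed test:

```scala
import org.json4s._
import org.json4s.JsonDSL._

implicit val formats: Formats = DefaultFormats

// Simulate an "old" event by dropping a field that newer writers emit.
val newEvent: JValue = ("Reason" -> "FetchFailed") ~ ("Message" -> "boom")
val oldEvent: JValue = newEvent.removeField { case (name, _) => name == "Message" }

// A tolerant reader substitutes a default when the field is absent.
val message = (oldEvent \ "Message").extractOpt[String].getOrElse("Unknown reason")
assert(message == "Unknown reason")
```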
diff --git a/dev/run-tests b/dev/run-tests
index 0e9eefa76a18b..de607e4344453 100755
--- a/dev/run-tests
+++ b/dev/run-tests
@@ -180,7 +180,7 @@ CURRENT_BLOCK=$BLOCK_SPARK_UNIT_TESTS
if [ -n "$_SQL_TESTS_ONLY" ]; then
# This must be an array of individual arguments. Otherwise, having one long string
#+ will be interpreted as a single test, which doesn't work.
- SBT_MAVEN_TEST_ARGS=("catalyst/test" "sql/test" "hive/test")
+ SBT_MAVEN_TEST_ARGS=("catalyst/test" "sql/test" "hive/test" "mllib/test")
else
SBT_MAVEN_TEST_ARGS=("test")
fi
diff --git a/docs/building-spark.md b/docs/building-spark.md
index 4cc0b1f2e5116..238ddae15545e 100644
--- a/docs/building-spark.md
+++ b/docs/building-spark.md
@@ -99,14 +99,11 @@ mvn -Pyarn -Phadoop-2.4 -Dhadoop.version=2.4.0 -DskipTests clean package
mvn -Pyarn-alpha -Phadoop-2.3 -Dhadoop.version=2.3.0 -Dyarn.version=0.23.7 -DskipTests clean package
{% endhighlight %}
-
-
# Building With Hive and JDBC Support
To enable Hive integration for Spark SQL along with its JDBC server and CLI,
add the `-Phive` profile to your existing build options. By default Spark
will build with Hive 0.13.1 bindings. You can also build for Hive 0.12.0 using
-the `-Phive-0.12.0` profile. NOTE: currently the JDBC server is only
-supported for Hive 0.12.0.
+the `-Phive-0.12.0` profile.
{% highlight bash %}
# Apache Hadoop 2.4.X with Hive 13 support
mvn -Pyarn -Phadoop-2.4 -Dhadoop.version=2.4.0 -Phive -DskipTests clean package
@@ -121,8 +118,8 @@ Tests are run by default via the [ScalaTest Maven plugin](http://www.scalatest.o
Some of the tests require Spark to be packaged first, so always run `mvn package` with `-DskipTests` the first time. The following is an example of a correct (build, test) sequence:
- mvn -Pyarn -Phadoop-2.3 -DskipTests -Phive -Phive-0.12.0 clean package
- mvn -Pyarn -Phadoop-2.3 -Phive -Phive-0.12.0 test
+ mvn -Pyarn -Phadoop-2.3 -DskipTests -Phive clean package
+ mvn -Pyarn -Phadoop-2.3 -Phive test
The ScalaTest plugin also supports running only a specific test suite as follows:
@@ -185,16 +182,16 @@ can be set to control the SBT build. For example:
Some of the tests require Spark to be packaged first, so always run `sbt/sbt assembly` the first time. The following is an example of a correct (build, test) sequence:
- sbt/sbt -Pyarn -Phadoop-2.3 -Phive -Phive-0.12.0 assembly
- sbt/sbt -Pyarn -Phadoop-2.3 -Phive -Phive-0.12.0 test
+ sbt/sbt -Pyarn -Phadoop-2.3 -Phive assembly
+ sbt/sbt -Pyarn -Phadoop-2.3 -Phive test
To run only a specific test suite as follows:
- sbt/sbt -Pyarn -Phadoop-2.3 -Phive -Phive-0.12.0 "test-only org.apache.spark.repl.ReplSuite"
+ sbt/sbt -Pyarn -Phadoop-2.3 -Phive "test-only org.apache.spark.repl.ReplSuite"
To run test suites of a specific sub project as follows:
- sbt/sbt -Pyarn -Phadoop-2.3 -Phive -Phive-0.12.0 core/test
+ sbt/sbt -Pyarn -Phadoop-2.3 -Phive core/test
# Speeding up Compilation with Zinc
diff --git a/docs/configuration.md b/docs/configuration.md
index 685101ea5c9c9..0f9eb81f6e993 100644
--- a/docs/configuration.md
+++ b/docs/configuration.md
@@ -21,16 +21,22 @@ application. These properties can be set directly on a
[SparkConf](api/scala/index.html#org.apache.spark.SparkConf) passed to your
`SparkContext`. `SparkConf` allows you to configure some of the common properties
(e.g. master URL and application name), as well as arbitrary key-value pairs through the
-`set()` method. For example, we could initialize an application as follows:
+`set()` method. For example, we could initialize an application with two threads as follows:
+
+Note that we run with local[2], meaning two threads, which represents "minimal" parallelism
+and can help detect bugs that only exist when we run in a distributed context.
{% highlight scala %}
val conf = new SparkConf()
- .setMaster("local")
+ .setMaster("local[2]")
.setAppName("CountingSheep")
.set("spark.executor.memory", "1g")
val sc = new SparkContext(conf)
{% endhighlight %}
+Note that we can have more than 1 thread in local mode, and in cases like Spark Streaming we may
+actually require more than one thread to prevent starvation issues.
+
## Dynamically Loading Spark Properties
In some cases, you may want to avoid hard-coding certain configurations in a `SparkConf`. For
instance, if you'd like to run the same application with different masters or different
diff --git a/docs/mllib-naive-bayes.md b/docs/mllib-naive-bayes.md
index 7f9d4c6563944..d5b044d94fdd7 100644
--- a/docs/mllib-naive-bayes.md
+++ b/docs/mllib-naive-bayes.md
@@ -88,11 +88,11 @@ JavaPairRDD<Double, Double> predictionAndLabel =
return new Tuple2<Double, Double>(model.predict(p.features()), p.label());
}
});
-double accuracy = 1.0 * predictionAndLabel.filter(new Function<Tuple2<Double, Double>, Boolean>() {
+double accuracy = predictionAndLabel.filter(new Function<Tuple2<Double, Double>, Boolean>() {
@Override public Boolean call(Tuple2<Double, Double> pl) {
- return pl._1() == pl._2();
+ return pl._1().equals(pl._2());
}
- }).count() / test.count();
+ }).count() / (double) test.count();
{% endhighlight %}
diff --git a/docs/mllib-statistics.md b/docs/mllib-statistics.md
index 10a5131c07414..ca8c29218f52d 100644
--- a/docs/mllib-statistics.md
+++ b/docs/mllib-statistics.md
@@ -380,6 +380,46 @@ for (ChiSqTestResult result : featureTestResults) {
{% endhighlight %}
+
+[`Statistics`](api/python/index.html#pyspark.mllib.stat.Statistics) provides methods to
+run Pearson's chi-squared tests. The following example demonstrates how to run and interpret
+hypothesis tests.
+
+{% highlight python %}
+from pyspark import SparkContext
+from pyspark.mllib.linalg import Vectors, Matrices
+from pyspark.mllib.regression import LabeledPoint
+from pyspark.mllib.stat import Statistics
+
+sc = SparkContext()
+
+vec = Vectors.dense(...) # a vector composed of the frequencies of events
+
+# compute the goodness of fit. If a second vector to test against is not supplied as a parameter,
+# the test runs against a uniform distribution.
+goodnessOfFitTestResult = Statistics.chiSqTest(vec)
+print goodnessOfFitTestResult # summary of the test including the p-value, degrees of freedom,
+ # test statistic, the method used, and the null hypothesis.
+
+mat = Matrices.dense(...) # a contingency matrix
+
+# conduct Pearson's independence test on the input contingency matrix
+independenceTestResult = Statistics.chiSqTest(mat)
+print independenceTestResult # summary of the test including the p-value, degrees of freedom...
+
+obs = sc.parallelize(...) # an RDD of LabeledPoint(label, features)
+
+# The contingency table is constructed from an RDD of LabeledPoint and used to conduct
+# the independence test. Returns an array containing the ChiSquaredTestResult for every feature
+# against the label.
+featureTestResults = Statistics.chiSqTest(obs)
+
+for i, result in enumerate(featureTestResults):
+ print "Column $d:" % (i + 1)
+ print result
+{% endhighlight %}
+
+
## Random data generation
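
Since the new Python snippet above mirrors the guide's existing Scala and Java examples, here is the equivalent sequence of calls in Scala for reference. The input values are small placeholders, and `sc` is assumed to be an existing `SparkContext`:

```scala
import org.apache.spark.mllib.linalg.{Matrices, Vectors}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.stat.Statistics

// Goodness of fit: with no second vector supplied, the test runs
// against a uniform distribution.
val vec = Vectors.dense(0.1, 0.15, 0.2, 0.3, 0.25)
println(Statistics.chiSqTest(vec))

// Independence test on a contingency matrix (entries are column-major).
val mat = Matrices.dense(3, 2, Array(1.0, 3.0, 5.0, 2.0, 4.0, 6.0))
println(Statistics.chiSqTest(mat))

// Per-feature independence tests against the label.
val obs = sc.parallelize(Seq(
  LabeledPoint(1.0, Vectors.dense(1.0, 0.0)),
  LabeledPoint(0.0, Vectors.dense(0.5, 1.5))))
Statistics.chiSqTest(obs).zipWithIndex.foreach { case (result, i) =>
  println(s"Column ${i + 1}:\n$result")
}
```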
diff --git a/docs/security.md b/docs/security.md
index ec0523184d665..1e206a139fb72 100644
--- a/docs/security.md
+++ b/docs/security.md
@@ -7,7 +7,6 @@ Spark currently supports authentication via a shared secret. Authentication can
* For Spark on [YARN](running-on-yarn.html) deployments, configuring `spark.authenticate` to `true` will automatically handle generating and distributing the shared secret. Each application will use a unique shared secret.
* For other types of Spark deployments, the Spark parameter `spark.authenticate.secret` should be configured on each of the nodes. This secret will be used by all the Master/Workers and applications.
-* **IMPORTANT NOTE:** *The experimental Netty shuffle path (`spark.shuffle.use.netty`) is not secured, so do not use Netty for shuffles if running with authentication.*
## Web UI
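
To make the non-YARN configuration described above concrete, a minimal sketch of enabling shared-secret authentication programmatically; the secret is a placeholder and would normally come from protected configuration rather than source code:

```scala
import org.apache.spark.{SparkConf, SparkContext}

// Both endpoints of a connection must agree on these settings; a mismatch
// fails authentication, as the test suites earlier in this patch verify.
val conf = new SparkConf()
  .setAppName("AuthenticatedApp")
  .set("spark.authenticate", "true")
  .set("spark.authenticate.secret", "placeholder-secret")
val sc = new SparkContext(conf)
```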
diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md
index d4ade939c3a6e..e399fecbbc78c 100644
--- a/docs/sql-programming-guide.md
+++ b/docs/sql-programming-guide.md
@@ -582,19 +582,27 @@ Configuration of Parquet can be done using the `setConf` method on SQLContext or
<tr>
  <td><code>spark.sql.parquet.cacheMetadata</code></td>
-  <td>false</td>
+  <td>true</td>
  <td>
    Turns on caching of Parquet schema metadata. Can speed up querying of static data.
  </td>
</tr>
<tr>
  <td><code>spark.sql.parquet.compression.codec</code></td>
-  <td>snappy</td>
+  <td>gzip</td>
  <td>
    Sets the compression codec use when writing Parquet files. Acceptable values include:
    uncompressed, snappy, gzip, lzo.
  </td>
</tr>
+<tr>
+  <td><code>spark.sql.hive.convertMetastoreParquet</code></td>
+  <td>true</td>
+  <td>
+    When set to false, Spark SQL will use the Hive SerDe for parquet tables instead of the built in
+    support.
+  </td>
+</tr>
## JSON Datasets
@@ -815,7 +823,7 @@ Configuration of in-memory caching can be done using the `setConf` method on SQL
<tr><th>Property Name</th><th>Default</th><th>Meaning</th></tr>
<tr>
  <td><code>spark.sql.inMemoryColumnarStorage.compressed</code></td>
-  <td>false</td>
+  <td>true</td>
  <td>
    When set to true Spark SQL will automatically select a compression codec for each column based
    on statistics of the data.
@@ -823,7 +831,7 @@ Configuration of in-memory caching can be done using the `setConf` method on SQL
<tr>
  <td><code>spark.sql.inMemoryColumnarStorage.batchSize</code></td>
-  <td>1000</td>
+  <td>10000</td>
  <td>
    Controls the size of batches for columnar caching. Larger batch sizes can improve memory utilization
    and compression, but risk OOMs when caching data.
@@ -841,7 +849,7 @@ that these options will be deprecated in future release as more optimizations ar
<tr><th>Property Name</th><th>Default</th><th>Meaning</th></tr>
<tr>
  <td><code>spark.sql.autoBroadcastJoinThreshold</code></td>
-  <td>10000</td>
+  <td>10485760 (10 MB)</td>
  <td>
    Configures the maximum size in bytes for a table that will be broadcast to all worker nodes when
    performing a join. By setting this value to -1 broadcasting can be disabled. Note that currently
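
As a brief sketch of overriding the defaults corrected above, using the `setConf` method the surrounding text describes (the values shown are the documented defaults, so these calls are only needed to deviate from them; `sc` is an assumed existing `SparkContext`):

```scala
import org.apache.spark.sql.SQLContext

val sqlContext = new SQLContext(sc)

sqlContext.setConf("spark.sql.inMemoryColumnarStorage.compressed", "true")
sqlContext.setConf("spark.sql.inMemoryColumnarStorage.batchSize", "10000")
sqlContext.setConf("spark.sql.autoBroadcastJoinThreshold", (10 * 1024 * 1024).toString)
```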
diff --git a/docs/streaming-programming-guide.md b/docs/streaming-programming-guide.md
index 8bbba88b31978..44a1f3ad7560b 100644
--- a/docs/streaming-programming-guide.md
+++ b/docs/streaming-programming-guide.md
@@ -68,7 +68,9 @@ import org.apache.spark._
import org.apache.spark.streaming._
import org.apache.spark.streaming.StreamingContext._
-// Create a local StreamingContext with two working thread and batch interval of 1 second
+// Create a local StreamingContext with two working threads and a batch interval of 1 second.
+// The master requires 2 cores to prevent a starvation scenario.
+
val conf = new SparkConf().setMaster("local[2]").setAppName("NetworkWordCount")
val ssc = new StreamingContext(conf, Seconds(1))
{% endhighlight %}
@@ -586,11 +588,13 @@ Every input DStream (except file stream) is associated with a single [Receiver](
A receiver is run within a Spark worker/executor as a long-running task, hence it occupies one of the cores allocated to the Spark Streaming application. Hence, it is important to remember that Spark Streaming application needs to be allocated enough cores to process the received data, as well as, to run the receiver(s). Therefore, few important points to remember are:
-##### Points to remember:
+##### Points to remember
{:.no_toc}
-- If the number of cores allocated to the application is less than or equal to the number of input DStreams / receivers, then the system will receive data, but not be able to process them.
-- When running locally, if you master URL is set to "local", then there is only one core to run tasks. That is insufficient for programs with even one input DStream (file streams are okay) as the receiver will occupy that core and there will be no core left to process the data.
-
+- If the number of threads allocated to the application is less than or equal to the number of input DStreams / receivers, then the system will receive data but will not be able to process it.
+- When running locally, if your master URL is set to "local", then there is only one core to run tasks. That is insufficient for programs using a DStream as the receiver (file streams are okay): the receiver will occupy the only core, so a "local" master URL in a streaming app is generally going to starve the processor.
+Thus in any streaming app, you generally will want to allocate more than one thread (i.e. set your master to "local[2]") when testing locally.
+See [Spark Properties](configuration.html#spark-properties).
+
### Basic Sources
{:.no_toc}
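
A minimal runnable sketch of the point these bullets make: with `local[2]`, one thread runs the receiver while the other processes batches, so the job below makes progress; with plain `local` the receiver would hold the only core and no batches would ever run. It assumes a text source listening on localhost:9999 (e.g. `nc -lk 9999`):

```scala
import org.apache.spark.SparkConf
import org.apache.spark.streaming.{Seconds, StreamingContext}

// "local[2]": one core for the socket receiver, one for batch processing.
val conf = new SparkConf().setMaster("local[2]").setAppName("NetworkWordCount")
val ssc = new StreamingContext(conf, Seconds(1))

val lines = ssc.socketTextStream("localhost", 9999)
lines.flatMap(_.split(" ")).map((_, 1)).reduceByKey(_ + _).print()

ssc.start()
ssc.awaitTermination()
```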
diff --git a/ec2/spark-ec2 b/ec2/spark-ec2
index 31f9771223e51..4aa908242eeaa 100755
--- a/ec2/spark-ec2
+++ b/ec2/spark-ec2
@@ -18,5 +18,9 @@
# limitations under the License.
#
-cd "`dirname $0`"
-PYTHONPATH="./third_party/boto-2.4.1.zip/boto-2.4.1:$PYTHONPATH" python ./spark_ec2.py "$@"
+# Preserve the user's CWD so that relative paths are passed correctly to
+#+ the underlying Python script.
+SPARK_EC2_DIR="$(dirname $0)"
+
+PYTHONPATH="${SPARK_EC2_DIR}/third_party/boto-2.4.1.zip/boto-2.4.1:$PYTHONPATH" \
+ python "${SPARK_EC2_DIR}/spark_ec2.py" "$@"
diff --git a/ec2/spark_ec2.py b/ec2/spark_ec2.py
index 0d6b82b4944f3..a5396c2375915 100755
--- a/ec2/spark_ec2.py
+++ b/ec2/spark_ec2.py
@@ -40,9 +40,11 @@
from boto import ec2
DEFAULT_SPARK_VERSION = "1.1.0"
+SPARK_EC2_DIR = os.path.dirname(os.path.realpath(__file__))
+MESOS_SPARK_EC2_BRANCH = "v4"
# A URL prefix from which to fetch AMI information
-AMI_PREFIX = "https://raw.github.com/mesos/spark-ec2/v2/ami-list"
+AMI_PREFIX = "https://raw.github.com/mesos/spark-ec2/{b}/ami-list".format(b=MESOS_SPARK_EC2_BRANCH)
class UsageError(Exception):
@@ -583,10 +585,23 @@ def setup_cluster(conn, master_nodes, slave_nodes, opts, deploy_ssh_key):
# NOTE: We should clone the repository before running deploy_files to
# prevent ec2-variables.sh from being overwritten
- ssh(master, opts, "rm -rf spark-ec2 && git clone https://github.com/mesos/spark-ec2.git -b v4")
+ ssh(
+ host=master,
+ opts=opts,
+ command="rm -rf spark-ec2"
+ + " && "
+ + "git clone https://github.com/mesos/spark-ec2.git -b {b}".format(b=MESOS_SPARK_EC2_BRANCH)
+ )
print "Deploying files to master..."
- deploy_files(conn, "deploy.generic", opts, master_nodes, slave_nodes, modules)
+ deploy_files(
+ conn=conn,
+ root_dir=SPARK_EC2_DIR + "/" + "deploy.generic",
+ opts=opts,
+ master_nodes=master_nodes,
+ slave_nodes=slave_nodes,
+ modules=modules
+ )
print "Running setup on master..."
setup_spark_cluster(master, opts)
@@ -723,6 +738,8 @@ def get_num_disks(instance_type):
# cluster (e.g. lists of masters and slaves). Files are only deployed to
# the first master instance in the cluster, and we expect the setup
# script to be run on that instance to copy them to other nodes.
+#
+# root_dir should be an absolute path to the directory with the files we want to deploy.
def deploy_files(conn, root_dir, opts, master_nodes, slave_nodes, modules):
active_master = master_nodes[0].public_dns_name
diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaGradientBoostedTrees.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaGradientBoostedTrees.java
new file mode 100644
index 0000000000000..1af2067b2b929
--- /dev/null
+++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaGradientBoostedTrees.java
@@ -0,0 +1,126 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.examples.mllib;
+
+import scala.Tuple2;
+
+import org.apache.spark.SparkConf;
+import org.apache.spark.api.java.JavaPairRDD;
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.api.java.function.Function;
+import org.apache.spark.api.java.function.Function2;
+import org.apache.spark.api.java.function.PairFunction;
+import org.apache.spark.mllib.regression.LabeledPoint;
+import org.apache.spark.mllib.tree.GradientBoosting;
+import org.apache.spark.mllib.tree.configuration.BoostingStrategy;
+import org.apache.spark.mllib.tree.model.WeightedEnsembleModel;
+import org.apache.spark.mllib.util.MLUtils;
+
+/**
+ * Classification and regression using gradient-boosted decision trees.
+ */
+public final class JavaGradientBoostedTrees {
+
+ private static void usage() {
+ System.err.println("Usage: JavaGradientBoostedTrees <libsvm format data file>" +
+ " <Classification/Regression>");
+ System.exit(-1);
+ }
+
+ public static void main(String[] args) {
+ String datapath = "data/mllib/sample_libsvm_data.txt";
+ String algo = "Classification";
+ if (args.length >= 1) {
+ datapath = args[0];
+ }
+ if (args.length >= 2) {
+ algo = args[1];
+ }
+ if (args.length > 2) {
+ usage();
+ }
+ SparkConf sparkConf = new SparkConf().setAppName("JavaGradientBoostedTrees");
+ JavaSparkContext sc = new JavaSparkContext(sparkConf);
+
+ JavaRDD<LabeledPoint> data = MLUtils.loadLibSVMFile(sc.sc(), datapath).toJavaRDD().cache();
+
+ // Set parameters.
+ // Note: All features are treated as continuous.
+ BoostingStrategy boostingStrategy = BoostingStrategy.defaultParams(algo);
+ boostingStrategy.setNumIterations(10);
+ boostingStrategy.weakLearnerParams().setMaxDepth(5);
+
+ if (algo.equals("Classification")) {
+ // Compute the number of classes from the data.
+ Integer numClasses = data.map(new Function<LabeledPoint, Double>() {
+ @Override public Double call(LabeledPoint p) {
+ return p.label();
+ }
+ }).countByValue().size();
+ boostingStrategy.setNumClassesForClassification(numClasses); // ignored for Regression
+
+ // Train a GradientBoosting model for classification.
+ final WeightedEnsembleModel model = GradientBoosting.trainClassifier(data, boostingStrategy);
+
+ // Evaluate model on training instances and compute training error
+ JavaPairRDD<Double, Double> predictionAndLabel =
+ data.mapToPair(new PairFunction<LabeledPoint, Double, Double>() {
+ @Override public Tuple2<Double, Double> call(LabeledPoint p) {
+ return new Tuple2<Double, Double>(model.predict(p.features()), p.label());
+ }
+ });
+ Double trainErr =
+ 1.0 * predictionAndLabel.filter(new Function<Tuple2<Double, Double>, Boolean>() {
+ @Override public Boolean call(Tuple2<Double, Double> pl) {
+ return !pl._1().equals(pl._2());
+ }
+ }).count() / data.count();
+ System.out.println("Training error: " + trainErr);
+ System.out.println("Learned classification tree model:\n" + model);
+ } else if (algo.equals("Regression")) {
+ // Train a GradientBoosting model for regression.
+ final WeightedEnsembleModel model = GradientBoosting.trainRegressor(data, boostingStrategy);
+
+ // Evaluate model on training instances and compute training error
+ JavaPairRDD<Double, Double> predictionAndLabel =
+ data.mapToPair(new PairFunction<LabeledPoint, Double, Double>() {
+ @Override public Tuple2<Double, Double> call(LabeledPoint p) {
+ return new Tuple2<Double, Double>(model.predict(p.features()), p.label());
+ }
+ });
+ Double trainMSE =
+ predictionAndLabel.map(new Function<Tuple2<Double, Double>, Double>() {
+ @Override public Double call(Tuple2<Double, Double> pl) {
+ Double diff = pl._1() - pl._2();
+ return diff * diff;
+ }
+ }).reduce(new Function2<Double, Double, Double>() {
+ @Override public Double call(Double a, Double b) {
+ return a + b;
+ }
+ }) / data.count();
+ System.out.println("Training Mean Squared Error: " + trainMSE);
+ System.out.println("Learned regression tree model:\n" + model);
+ } else {
+ usage();
+ }
+
+ sc.stop();
+ }
+}
diff --git a/examples/src/main/python/mllib/dataset_example.py b/examples/src/main/python/mllib/dataset_example.py
new file mode 100644
index 0000000000000..540dae785f6ea
--- /dev/null
+++ b/examples/src/main/python/mllib/dataset_example.py
@@ -0,0 +1,62 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""
+An example of how to use SchemaRDD as a dataset for ML. Run with::
+ bin/spark-submit examples/src/main/python/mllib/dataset_example.py
+"""
+
+import os
+import sys
+import tempfile
+import shutil
+
+from pyspark import SparkContext
+from pyspark.sql import SQLContext
+from pyspark.mllib.util import MLUtils
+from pyspark.mllib.stat import Statistics
+
+
+def summarize(dataset):
+ print "schema: %s" % dataset.schema().json()
+ labels = dataset.map(lambda r: r.label)
+ print "label average: %f" % labels.mean()
+ features = dataset.map(lambda r: r.features)
+ summary = Statistics.colStats(features)
+ print "features average: %r" % summary.mean()
+
+if __name__ == "__main__":
+ if len(sys.argv) > 2:
+ print >> sys.stderr, "Usage: dataset_example.py <libsvm file>"
+ exit(-1)
+ sc = SparkContext(appName="DatasetExample")
+ sqlCtx = SQLContext(sc)
+ if len(sys.argv) == 2:
+ input = sys.argv[1]
+ else:
+ input = "data/mllib/sample_libsvm_data.txt"
+ points = MLUtils.loadLibSVMFile(sc, input)
+ dataset0 = sqlCtx.inferSchema(points).setName("dataset0").cache()
+ summarize(dataset0)
+ tempdir = tempfile.NamedTemporaryFile(delete=False).name
+ os.unlink(tempdir)
+ print "Save dataset as a Parquet file to %s." % tempdir
+ dataset0.saveAsParquetFile(tempdir)
+ print "Load it back and summarize it again."
+ dataset1 = sqlCtx.parquetFile(tempdir).setName("dataset1").cache()
+ summarize(dataset1)
+ shutil.rmtree(tempdir)
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DatasetExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DatasetExample.scala
new file mode 100644
index 0000000000000..f8d83f4ec7327
--- /dev/null
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DatasetExample.scala
@@ -0,0 +1,121 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.examples.mllib
+
+import java.io.File
+
+import com.google.common.io.Files
+import scopt.OptionParser
+
+import org.apache.spark.{SparkConf, SparkContext}
+import org.apache.spark.mllib.linalg.Vector
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer
+import org.apache.spark.mllib.util.MLUtils
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.{Row, SQLContext, SchemaRDD}
+
+/**
+ * An example of how to use [[org.apache.spark.sql.SchemaRDD]] as a Dataset for ML. Run with
+ * {{{
+ * ./bin/run-example org.apache.spark.examples.mllib.DatasetExample [options]
+ * }}}
+ * If you use it as a template to create your own app, please use `spark-submit` to submit your app.
+ */
+object DatasetExample {
+
+ case class Params(
+ input: String = "data/mllib/sample_libsvm_data.txt",
+ dataFormat: String = "libsvm") extends AbstractParams[Params]
+
+ def main(args: Array[String]) {
+ val defaultParams = Params()
+
+ val parser = new OptionParser[Params]("DatasetExample") {
+ head("Dataset: an example app using SchemaRDD as a Dataset for ML.")
+ opt[String]("input")
+ .text(s"input path to dataset")
+ .action((x, c) => c.copy(input = x))
+ opt[String]("dataFormat")
+ .text("data format: libsvm (default), dense (deprecated in Spark v1.1)")
+ .action((x, c) => c.copy(dataFormat = x))
+ checkConfig { params =>
+ success
+ }
+ }
+
+ parser.parse(args, defaultParams).map { params =>
+ run(params)
+ }.getOrElse {
+ sys.exit(1)
+ }
+ }
+
+ def run(params: Params) {
+
+ val conf = new SparkConf().setAppName(s"DatasetExample with $params")
+ val sc = new SparkContext(conf)
+ val sqlContext = new SQLContext(sc)
+ import sqlContext._ // for implicit conversions
+
+ // Load input data
+ val origData: RDD[LabeledPoint] = params.dataFormat match {
+ case "dense" => MLUtils.loadLabeledPoints(sc, params.input)
+ case "libsvm" => MLUtils.loadLibSVMFile(sc, params.input)
+ }
+ println(s"Loaded ${origData.count()} instances from file: ${params.input}")
+
+ // Convert input data to SchemaRDD explicitly.
+ val schemaRDD: SchemaRDD = origData
+ println(s"Inferred schema:\n${schemaRDD.schema.prettyJson}")
+ println(s"Converted to SchemaRDD with ${schemaRDD.count()} records")
+
+ // Select columns, using implicit conversion to SchemaRDD.
+ val labelsSchemaRDD: SchemaRDD = origData.select('label)
+ val labels: RDD[Double] = labelsSchemaRDD.map { case Row(v: Double) => v }
+ val numLabels = labels.count()
+ val meanLabel = labels.fold(0.0)(_ + _) / numLabels
+ println(s"Selected label column with average value $meanLabel")
+
+ val featuresSchemaRDD: SchemaRDD = origData.select('features)
+ val features: RDD[Vector] = featuresSchemaRDD.map { case Row(v: Vector) => v }
+ val featureSummary = features.aggregate(new MultivariateOnlineSummarizer())(
+ (summary, feat) => summary.add(feat),
+ (sum1, sum2) => sum1.merge(sum2))
+ println(s"Selected features column with average values:\n ${featureSummary.mean.toString}")
+
+ val tmpDir = Files.createTempDir()
+ tmpDir.deleteOnExit()
+ val outputDir = new File(tmpDir, "dataset").toString
+ println(s"Saving to $outputDir as Parquet file.")
+ schemaRDD.saveAsParquetFile(outputDir)
+
+ println(s"Loading Parquet file with UDT from $outputDir.")
+ val newDataset = sqlContext.parquetFile(outputDir)
+
+ println(s"Schema from Parquet: ${newDataset.schema.prettyJson}")
+ val newFeatures = newDataset.select('features).map { case Row(v: Vector) => v }
+ val newFeaturesSummary = newFeatures.aggregate(new MultivariateOnlineSummarizer())(
+ (summary, feat) => summary.add(feat),
+ (sum1, sum2) => sum1.merge(sum2))
+ println(s"Selected features column with average values:\n ${newFeaturesSummary.mean.toString}")
+
+ sc.stop()
+ }
+
+}
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala
index 49751a30491d0..63f02cf7b98b9 100644
--- a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala
@@ -154,20 +154,30 @@ object DecisionTreeRunner {
}
}
- def run(params: Params) {
-
- val conf = new SparkConf().setAppName(s"DecisionTreeRunner with $params")
- val sc = new SparkContext(conf)
-
- println(s"DecisionTreeRunner with parameters:\n$params")
-
+ /**
+ * Load training and test data from files.
+ * @param input Path to input dataset.
+ * @param dataFormat "libsvm" or "dense"
+ * @param testInput Path to test dataset.
+ * @param algo Classification or Regression
+ * @param fracTest Fraction of input data to hold out for testing. Ignored if testInput given.
+ * @return (training dataset, test dataset, number of classes),
+ * where the number of classes is inferred from data (and set to 0 for Regression)
+ */
+ private[mllib] def loadDatasets(
+ sc: SparkContext,
+ input: String,
+ dataFormat: String,
+ testInput: String,
+ algo: Algo,
+ fracTest: Double): (RDD[LabeledPoint], RDD[LabeledPoint], Int) = {
// Load training data and cache it.
- val origExamples = params.dataFormat match {
- case "dense" => MLUtils.loadLabeledPoints(sc, params.input).cache()
- case "libsvm" => MLUtils.loadLibSVMFile(sc, params.input).cache()
+ val origExamples = dataFormat match {
+ case "dense" => MLUtils.loadLabeledPoints(sc, input).cache()
+ case "libsvm" => MLUtils.loadLibSVMFile(sc, input).cache()
}
// For classification, re-index classes if needed.
- val (examples, classIndexMap, numClasses) = params.algo match {
+ val (examples, classIndexMap, numClasses) = algo match {
case Classification => {
// classCounts: class --> # examples in class
val classCounts = origExamples.map(_.label).countByValue()
@@ -205,14 +215,14 @@ object DecisionTreeRunner {
}
// Create training, test sets.
- val splits = if (params.testInput != "") {
+ val splits = if (testInput != "") {
// Load testInput.
val numFeatures = examples.take(1)(0).features.size
- val origTestExamples = params.dataFormat match {
- case "dense" => MLUtils.loadLabeledPoints(sc, params.testInput)
- case "libsvm" => MLUtils.loadLibSVMFile(sc, params.testInput, numFeatures)
+ val origTestExamples = dataFormat match {
+ case "dense" => MLUtils.loadLabeledPoints(sc, testInput)
+ case "libsvm" => MLUtils.loadLibSVMFile(sc, testInput, numFeatures)
}
- params.algo match {
+ algo match {
case Classification => {
// classCounts: class --> # examples in class
val testExamples = {
@@ -229,17 +239,31 @@ object DecisionTreeRunner {
}
} else {
// Split input into training, test.
- examples.randomSplit(Array(1.0 - params.fracTest, params.fracTest))
+ examples.randomSplit(Array(1.0 - fracTest, fracTest))
}
val training = splits(0).cache()
val test = splits(1).cache()
+
val numTraining = training.count()
val numTest = test.count()
-
println(s"numTraining = $numTraining, numTest = $numTest.")
examples.unpersist(blocking = false)
+ (training, test, numClasses)
+ }
+
+ def run(params: Params) {
+
+ val conf = new SparkConf().setAppName(s"DecisionTreeRunner with $params")
+ val sc = new SparkContext(conf)
+
+ println(s"DecisionTreeRunner with parameters:\n$params")
+
+ // Load training and test data and cache it.
+ val (training, test, numClasses) = loadDatasets(sc, params.input, params.dataFormat,
+ params.testInput, params.algo, params.fracTest)
+
val impurityCalculator = params.impurity match {
case Gini => impurity.Gini
case Entropy => impurity.Entropy
@@ -338,7 +362,9 @@ object DecisionTreeRunner {
/**
* Calculates the mean squared error for regression.
*/
- private def meanSquaredError(tree: WeightedEnsembleModel, data: RDD[LabeledPoint]): Double = {
+ private[mllib] def meanSquaredError(
+ tree: WeightedEnsembleModel,
+ data: RDD[LabeledPoint]): Double = {
data.map { y =>
val err = tree.predict(y.features) - y.label
err * err
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostedTrees.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostedTrees.scala
new file mode 100644
index 0000000000000..9b6db01448be0
--- /dev/null
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostedTrees.scala
@@ -0,0 +1,146 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.examples.mllib
+
+import scopt.OptionParser
+
+import org.apache.spark.{SparkConf, SparkContext}
+import org.apache.spark.mllib.evaluation.MulticlassMetrics
+import org.apache.spark.mllib.tree.GradientBoosting
+import org.apache.spark.mllib.tree.configuration.{BoostingStrategy, Algo}
+import org.apache.spark.util.Utils
+
+/**
+ * An example runner for Gradient Boosting using decision trees as weak learners. Run with
+ * {{{
+ * ./bin/run-example org.apache.spark.examples.mllib.GradientBoostedTrees [options]
+ * }}}
+ * If you use it as a template to create your own app, please use `spark-submit` to submit your app.
+ *
+ * Note: This script treats all features as real-valued (not categorical).
+ * To include categorical features, modify categoricalFeaturesInfo.
+ */
+object GradientBoostedTrees {
+
+ case class Params(
+ input: String = null,
+ testInput: String = "",
+ dataFormat: String = "libsvm",
+ algo: String = "Classification",
+ maxDepth: Int = 5,
+ numIterations: Int = 10,
+ fracTest: Double = 0.2) extends AbstractParams[Params]
+
+ def main(args: Array[String]) {
+ val defaultParams = Params()
+
+ val parser = new OptionParser[Params]("GradientBoostedTrees") {
+ head("GradientBoostedTrees: an example decision tree app.")
+ opt[String]("algo")
+ .text(s"algorithm (${Algo.values.mkString(",")}), default: ${defaultParams.algo}")
+ .action((x, c) => c.copy(algo = x))
+ opt[Int]("maxDepth")
+ .text(s"max depth of the tree, default: ${defaultParams.maxDepth}")
+ .action((x, c) => c.copy(maxDepth = x))
+ opt[Int]("numIterations")
+ .text(s"number of iterations of boosting," + s" default: ${defaultParams.numIterations}")
+ .action((x, c) => c.copy(numIterations = x))
+ opt[Double]("fracTest")
+ .text(s"fraction of data to hold out for testing. If given option testInput, " +
+ s"this option is ignored. default: ${defaultParams.fracTest}")
+ .action((x, c) => c.copy(fracTest = x))
+ opt[String]("testInput")
+ .text(s"input path to test dataset. If given, option fracTest is ignored." +
+ s" default: ${defaultParams.testInput}")
+ .action((x, c) => c.copy(testInput = x))
+ opt[String]("")
+ .text("data format: libsvm (default), dense (deprecated in Spark v1.1)")
+ .action((x, c) => c.copy(dataFormat = x))
+ arg[String]("")
+ .text("input path to labeled examples")
+ .required()
+ .action((x, c) => c.copy(input = x))
+ checkConfig { params =>
+ if (params.fracTest < 0 || params.fracTest > 1) {
+ failure(s"fracTest ${params.fracTest} value incorrect; should be in [0,1].")
+ } else {
+ success
+ }
+ }
+ }
+
+ parser.parse(args, defaultParams).map { params =>
+ run(params)
+ }.getOrElse {
+ sys.exit(1)
+ }
+ }
+
+ def run(params: Params) {
+
+ val conf = new SparkConf().setAppName(s"GradientBoostedTrees with $params")
+ val sc = new SparkContext(conf)
+
+ println(s"GradientBoostedTrees with parameters:\n$params")
+
+ // Load training and test data and cache it.
+ val (training, test, numClasses) = DecisionTreeRunner.loadDatasets(sc, params.input,
+ params.dataFormat, params.testInput, Algo.withName(params.algo), params.fracTest)
+
+ val boostingStrategy = BoostingStrategy.defaultParams(params.algo)
+ boostingStrategy.numClassesForClassification = numClasses
+ boostingStrategy.numIterations = params.numIterations
+ boostingStrategy.weakLearnerParams.maxDepth = params.maxDepth
+
+ val randomSeed = Utils.random.nextInt()
+ if (params.algo == "Classification") {
+ val startTime = System.nanoTime()
+ val model = GradientBoosting.trainClassifier(training, boostingStrategy)
+ val elapsedTime = (System.nanoTime() - startTime) / 1e9
+ println(s"Training time: $elapsedTime seconds")
+ if (model.totalNumNodes < 30) {
+ println(model.toDebugString) // Print full model.
+ } else {
+ println(model) // Print model summary.
+ }
+ val trainAccuracy =
+ new MulticlassMetrics(training.map(lp => (model.predict(lp.features), lp.label)))
+ .precision
+ println(s"Train accuracy = $trainAccuracy")
+ val testAccuracy =
+ new MulticlassMetrics(test.map(lp => (model.predict(lp.features), lp.label))).precision
+ println(s"Test accuracy = $testAccuracy")
+ } else if (params.algo == "Regression") {
+ val startTime = System.nanoTime()
+ val model = GradientBoosting.trainRegressor(training, boostingStrategy)
+ val elapsedTime = (System.nanoTime() - startTime) / 1e9
+ println(s"Training time: $elapsedTime seconds")
+ if (model.totalNumNodes < 30) {
+ println(model.toDebugString) // Print full model.
+ } else {
+ println(model) // Print model summary.
+ }
+ val trainMSE = DecisionTreeRunner.meanSquaredError(model, training)
+ println(s"Train mean squared error = $trainMSE")
+ val testMSE = DecisionTreeRunner.meanSquaredError(model, test)
+ println(s"Test mean squared error = $testMSE")
+ }
+
+ sc.stop()
+ }
+}
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala
index 8796c28db8a66..91a0a860d6c71 100644
--- a/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala
@@ -106,9 +106,11 @@ object MovieLensALS {
Logger.getRootLogger.setLevel(Level.WARN)
+ val implicitPrefs = params.implicitPrefs
+
val ratings = sc.textFile(params.input).map { line =>
val fields = line.split("::")
- if (params.implicitPrefs) {
+ if (implicitPrefs) {
/*
* MovieLens ratings are on a scale of 1-5:
* 5: Must see
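
The change above copies `params.implicitPrefs` into a local `val` before the closure uses it; referencing `params` directly inside the `map` would capture, and attempt to serialize, the whole enclosing object. A hypothetical sketch of the pattern:

```scala
import org.apache.spark.rdd.RDD

// Illustrative only: Driver itself is not serializable.
class Driver(val threshold: Int) {
  def filterBig(data: RDD[Int]): RDD[Int] = {
    val t = threshold   // copy the field into a local val...
    data.filter(_ > t)  // ...so the closure captures an Int, not the whole Driver
  }
}
```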
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartitionBuilder.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartitionBuilder.scala
index 4520beb991515..2b6137be25547 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartitionBuilder.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartitionBuilder.scala
@@ -45,8 +45,8 @@ class EdgePartitionBuilder[@specialized(Long, Int, Double) ED: ClassTag, VD: Cla
// Copy edges into columnar structures, tracking the beginnings of source vertex id clusters and
// adding them to the index
if (edgeArray.length > 0) {
- index.update(srcIds(0), 0)
- var currSrcId: VertexId = srcIds(0)
+ index.update(edgeArray(0).srcId, 0)
+ var currSrcId: VertexId = edgeArray(0).srcId
var i = 0
while (i < edgeArray.size) {
srcIds(i) = edgeArray(i).srcId
diff --git a/mllib/pom.xml b/mllib/pom.xml
index fb7239e779aae..87a7ddaba97f2 100644
--- a/mllib/pom.xml
+++ b/mllib/pom.xml
@@ -45,6 +45,11 @@
<artifactId>spark-streaming_${scala.binary.version}</artifactId>
<version>${project.version}</version>
</dependency>
+ <dependency>
+ <groupId>org.apache.spark</groupId>
+ <artifactId>spark-sql_${scala.binary.version}</artifactId>
+ <version>${project.version}</version>
+ </dependency>
<dependency>
<groupId>org.eclipse.jetty</groupId>
<artifactId>jetty-server</artifactId>
index acdc67ddc660a..d832ae34b55e4 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
@@ -43,6 +43,7 @@ import org.apache.spark.mllib.tree.impurity._
import org.apache.spark.mllib.tree.model.DecisionTreeModel
import org.apache.spark.mllib.stat.{MultivariateStatisticalSummary, Statistics}
import org.apache.spark.mllib.stat.correlation.CorrelationNames
+import org.apache.spark.mllib.stat.test.ChiSqTestResult
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel
@@ -454,6 +455,31 @@ class PythonMLLibAPI extends Serializable {
Statistics.corr(x.rdd, y.rdd, getCorrNameOrDefault(method))
}
+ /**
+ * Java stub for mllib Statistics.chiSqTest()
+ */
+ def chiSqTest(observed: Vector, expected: Vector): ChiSqTestResult = {
+ if (expected == null) {
+ Statistics.chiSqTest(observed)
+ } else {
+ Statistics.chiSqTest(observed, expected)
+ }
+ }
+
+ /**
+ * Java stub for mllib Statistics.chiSqTest(observed: Matrix)
+ */
+ def chiSqTest(observed: Matrix): ChiSqTestResult = {
+ Statistics.chiSqTest(observed)
+ }
+
+ /**
+ * Java stub for mllib Statistics.chiSqTest(RDD[LabeledPoint])
+ */
+ def chiSqTest(data: JavaRDD[LabeledPoint]): Array[ChiSqTestResult] = {
+ Statistics.chiSqTest(data.rdd)
+ }
+
// used by the corr methods to retrieve the name of the correlation method passed in via pyspark
private def getCorrNameOrDefault(method: String) = {
if (method == null) CorrelationNames.defaultCorrName else method
@@ -736,7 +762,7 @@ private[spark] object SerDe extends Serializable {
def javaToPython(jRDD: JavaRDD[Any]): JavaRDD[Array[Byte]] = {
jRDD.rdd.mapPartitions { iter =>
      initialize() // make sure this is called in the executor
- new PythonRDD.AutoBatchedPickler(iter)
+ new SerDeUtil.AutoBatchedPickler(iter)
}
}
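
The three stubs above forward straight to the Scala Statistics API, so their behavior matches it overload for overload. A hedged usage sketch of that underlying API (the counts are made up):

    import org.apache.spark.mllib.linalg.{Matrices, Vectors}
    import org.apache.spark.mllib.stat.Statistics

    val observed = Vectors.dense(4.0, 6.0, 5.0)
    val uniformFit = Statistics.chiSqTest(observed)            // expected defaults to uniform
    val expected = Vectors.dense(1.0, 2.0, 3.0)
    val explicitFit = Statistics.chiSqTest(observed, expected) // matches the null check above
    val counts = Matrices.dense(2, 2, Array(13.0, 47.0, 40.0, 80.0))
    val independence = Statistics.chiSqTest(counts)            // contingency-table test
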
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/AreaUnderCurve.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/AreaUnderCurve.scala
index 7858ec602483f..078fbfbe4f0e1 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/AreaUnderCurve.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/AreaUnderCurve.scala
@@ -43,7 +43,7 @@ private[evaluation] object AreaUnderCurve {
*/
def of(curve: RDD[(Double, Double)]): Double = {
curve.sliding(2).aggregate(0.0)(
- seqOp = (auc: Double, points: Seq[(Double, Double)]) => auc + trapezoid(points),
+ seqOp = (auc: Double, points: Array[(Double, Double)]) => auc + trapezoid(points),
combOp = _ + _
)
}
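
With sliding(2) the aggregate sees each adjacent pair of (x, y) points exactly once, so summing trapezoids yields the area under the piecewise-linear curve. A minimal sketch of the trapezoid step, mirroring what the private helper computes:

    // Area of the trapezoid spanned by two consecutive curve points.
    def trapezoid(points: Seq[(Double, Double)]): Double = {
      require(points.length == 2, "a trapezoid needs exactly two points")
      val Seq((x1, y1), (x2, y2)) = points
      (y1 + y2) * (x2 - x1) / 2.0
    }

    trapezoid(Seq((0.0, 0.0), (1.0, 1.0)))  // 0.5: triangle under y = x
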
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
index 6af225b7f49f7..ac217edc619ab 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
@@ -17,22 +17,26 @@
package org.apache.spark.mllib.linalg
-import java.lang.{Double => JavaDouble, Integer => JavaInteger, Iterable => JavaIterable}
import java.util
+import java.lang.{Double => JavaDouble, Integer => JavaInteger, Iterable => JavaIterable}
import scala.annotation.varargs
import scala.collection.JavaConverters._
import breeze.linalg.{DenseVector => BDV, SparseVector => BSV, Vector => BV}
-import org.apache.spark.mllib.util.NumericParser
import org.apache.spark.SparkException
+import org.apache.spark.mllib.util.NumericParser
+import org.apache.spark.sql.catalyst.annotation.SQLUserDefinedType
+import org.apache.spark.sql.catalyst.expressions.{GenericMutableRow, Row}
+import org.apache.spark.sql.catalyst.types._
/**
* Represents a numeric vector, whose index type is Int and value type is Double.
*
* Note: Users should not implement this interface.
*/
+@SQLUserDefinedType(udt = classOf[VectorUDT])
sealed trait Vector extends Serializable {
/**
@@ -74,6 +78,65 @@ sealed trait Vector extends Serializable {
}
}
+/**
+ * User-defined type for [[Vector]] which allows easy interaction with SQL
+ * via [[org.apache.spark.sql.SchemaRDD]].
+ */
+private[spark] class VectorUDT extends UserDefinedType[Vector] {
+
+ override def sqlType: StructType = {
+ // type: 0 = sparse, 1 = dense
+ // We only use "values" for dense vectors, and "size", "indices", and "values" for sparse
+ // vectors. The "values" field is nullable because we might want to add binary vectors later,
+    // which would use "size" and "indices", but not "values".
+ StructType(Seq(
+ StructField("type", ByteType, nullable = false),
+ StructField("size", IntegerType, nullable = true),
+ StructField("indices", ArrayType(IntegerType, containsNull = false), nullable = true),
+ StructField("values", ArrayType(DoubleType, containsNull = false), nullable = true)))
+ }
+
+ override def serialize(obj: Any): Row = {
+ val row = new GenericMutableRow(4)
+ obj match {
+ case sv: SparseVector =>
+ row.setByte(0, 0)
+ row.setInt(1, sv.size)
+ row.update(2, sv.indices.toSeq)
+ row.update(3, sv.values.toSeq)
+ case dv: DenseVector =>
+ row.setByte(0, 1)
+ row.setNullAt(1)
+ row.setNullAt(2)
+ row.update(3, dv.values.toSeq)
+ }
+ row
+ }
+
+ override def deserialize(datum: Any): Vector = {
+ datum match {
+ case row: Row =>
+ require(row.length == 4,
+ s"VectorUDT.deserialize given row with length ${row.length} but requires length == 4")
+ val tpe = row.getByte(0)
+ tpe match {
+ case 0 =>
+ val size = row.getInt(1)
+ val indices = row.getAs[Iterable[Int]](2).toArray
+ val values = row.getAs[Iterable[Double]](3).toArray
+ new SparseVector(size, indices, values)
+ case 1 =>
+ val values = row.getAs[Iterable[Double]](3).toArray
+ new DenseVector(values)
+ }
+ }
+ }
+
+ override def pyUDT: String = "pyspark.mllib.linalg.VectorUDT"
+
+ override def userClass: Class[Vector] = classOf[Vector]
+}
+
/**
* Factory methods for [[org.apache.spark.mllib.linalg.Vector]].
* We don't use the name `Vector` because Scala imports
@@ -191,6 +254,7 @@ object Vectors {
/**
* A dense vector represented by a value array.
*/
+@SQLUserDefinedType(udt = classOf[VectorUDT])
class DenseVector(val values: Array[Double]) extends Vector {
override def size: Int = values.length
@@ -215,6 +279,7 @@ class DenseVector(val values: Array[Double]) extends Vector {
* @param indices index array, assume to be strictly increasing.
* @param values value array, must have the same length as the index array.
*/
+@SQLUserDefinedType(udt = classOf[VectorUDT])
class SparseVector(
override val size: Int,
val indices: Array[Int],
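
The UDT's contract is that serialize and deserialize are inverses over the row layout declared in sqlType; the new VectorsSuite test later in this patch checks exactly that. A hedged round-trip sketch (VectorUDT is private[spark], so this only compiles from Spark's own packages, as the test does):

    import org.apache.spark.mllib.linalg.{Vector, Vectors, VectorUDT}

    val udt = new VectorUDT()
    val sv = Vectors.sparse(4, Array(1, 3), Array(3.0, 4.0))
    val dv = Vectors.dense(1.0, 0.0, 2.0)
    for (v <- Seq[Vector](sv, dv)) {
      assert(udt.deserialize(udt.serialize(v)) == v)  // 4-field row round trip
    }
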
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/rdd/RDDFunctions.scala b/mllib/src/main/scala/org/apache/spark/mllib/rdd/RDDFunctions.scala
index b5e403bc8c14d..57c0768084e41 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/rdd/RDDFunctions.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/rdd/RDDFunctions.scala
@@ -20,6 +20,7 @@ package org.apache.spark.mllib.rdd
import scala.language.implicitConversions
import scala.reflect.ClassTag
+import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.HashPartitioner
import org.apache.spark.SparkContext._
import org.apache.spark.rdd.RDD
@@ -28,8 +29,8 @@ import org.apache.spark.util.Utils
/**
* Machine learning specific RDD functions.
*/
-private[mllib]
-class RDDFunctions[T: ClassTag](self: RDD[T]) {
+@DeveloperApi
+class RDDFunctions[T: ClassTag](self: RDD[T]) extends Serializable {
/**
   * Returns an RDD from grouping items of its parent RDD in fixed size blocks by passing a sliding
@@ -39,10 +40,10 @@ class RDDFunctions[T: ClassTag](self: RDD[T]) {
   * trigger a Spark job if the parent RDD has more than one partition and the window size is
* greater than 1.
*/
- def sliding(windowSize: Int): RDD[Seq[T]] = {
+ def sliding(windowSize: Int): RDD[Array[T]] = {
require(windowSize > 0, s"Sliding window size must be positive, but got $windowSize.")
if (windowSize == 1) {
- self.map(Seq(_))
+ self.map(Array(_))
} else {
new SlidingRDD[T](self, windowSize)
}
@@ -112,7 +113,7 @@ class RDDFunctions[T: ClassTag](self: RDD[T]) {
}
}
-private[mllib]
+@DeveloperApi
object RDDFunctions {
/** Implicit conversion from an RDD to RDDFunctions. */
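
Promoting RDDFunctions to a DeveloperApi makes sliding usable outside mllib via the implicit conversion above. A short usage sketch (assumes an active SparkContext named sc); note the element type is now Array[T], hence the Array pattern:

    import org.apache.spark.mllib.rdd.RDDFunctions._

    val deltas = sc.parallelize(1 to 5, 2)
      .sliding(2)                              // RDD[Array[Int]] after this patch
      .map { case Array(a, b) => b - a }
    deltas.collect()                           // Array(1, 1, 1, 1)
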
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/rdd/SlidingRDD.scala b/mllib/src/main/scala/org/apache/spark/mllib/rdd/SlidingRDD.scala
index dd80782c0f001..35e81fcb3de0d 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/rdd/SlidingRDD.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/rdd/SlidingRDD.scala
@@ -45,15 +45,16 @@ class SlidingRDDPartition[T](val idx: Int, val prev: Partition, val tail: Seq[T]
*/
private[mllib]
class SlidingRDD[T: ClassTag](@transient val parent: RDD[T], val windowSize: Int)
- extends RDD[Seq[T]](parent) {
+ extends RDD[Array[T]](parent) {
require(windowSize > 1, s"Window size must be greater than 1, but got $windowSize.")
- override def compute(split: Partition, context: TaskContext): Iterator[Seq[T]] = {
+ override def compute(split: Partition, context: TaskContext): Iterator[Array[T]] = {
val part = split.asInstanceOf[SlidingRDDPartition[T]]
(firstParent[T].iterator(part.prev, context) ++ part.tail)
.sliding(windowSize)
.withPartial(false)
+ .map(_.toArray)
}
override def getPreferredLocations(split: Partition): Seq[String] =
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoosting.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoosting.scala
index 1a847201ce157..f729344a682e2 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoosting.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoosting.scala
@@ -17,30 +17,49 @@
package org.apache.spark.mllib.tree
-import scala.collection.JavaConverters._
-
+import org.apache.spark.Logging
import org.apache.spark.annotation.Experimental
import org.apache.spark.api.java.JavaRDD
-import org.apache.spark.mllib.tree.configuration.{Strategy, BoostingStrategy}
-import org.apache.spark.Logging
-import org.apache.spark.mllib.tree.impl.TimeTracker
-import org.apache.spark.mllib.tree.loss.Losses
-import org.apache.spark.rdd.RDD
import org.apache.spark.mllib.regression.LabeledPoint
-import org.apache.spark.mllib.tree.model.{WeightedEnsembleModel, DecisionTreeModel}
import org.apache.spark.mllib.tree.configuration.Algo._
-import org.apache.spark.storage.StorageLevel
+import org.apache.spark.mllib.tree.configuration.BoostingStrategy
import org.apache.spark.mllib.tree.configuration.EnsembleCombiningStrategy.Sum
+import org.apache.spark.mllib.tree.impl.TimeTracker
+import org.apache.spark.mllib.tree.model.{WeightedEnsembleModel, DecisionTreeModel}
+import org.apache.spark.rdd.RDD
+import org.apache.spark.storage.StorageLevel
/**
* :: Experimental ::
- * A class that implements gradient boosting for regression and binary classification problems.
+ * A class that implements Stochastic Gradient Boosting
+ * for regression and binary classification problems.
+ *
+ * The implementation is based upon:
+ * J.H. Friedman. "Stochastic Gradient Boosting." 1999.
+ *
+ * Notes:
+ * - This currently can be run with several loss functions. However, only SquaredError is
+ * fully supported. Specifically, the loss function should be used to compute the gradient
+ * (to re-label training instances on each iteration) and to weight weak hypotheses.
+ * Currently, gradients are computed correctly for the available loss functions,
+ * but weak hypothesis weights are not computed correctly for LogLoss or AbsoluteError.
+ * Running with those losses will likely behave reasonably, but lacks the same guarantees.
+ *
* @param boostingStrategy Parameters for the gradient boosting algorithm
*/
@Experimental
class GradientBoosting (
private val boostingStrategy: BoostingStrategy) extends Serializable with Logging {
+ boostingStrategy.weakLearnerParams.algo = Regression
+ boostingStrategy.weakLearnerParams.impurity = impurity.Variance
+
+ // Ensure values for weak learner are the same as what is provided to the boosting algorithm.
+ boostingStrategy.weakLearnerParams.numClassesForClassification =
+ boostingStrategy.numClassesForClassification
+
+ boostingStrategy.assertValid()
+
/**
* Method to train a gradient boosting model
* @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
@@ -51,6 +70,7 @@ class GradientBoosting (
algo match {
case Regression => GradientBoosting.boost(input, boostingStrategy)
case Classification =>
+ // Map labels to -1, +1 so binary classification can be treated as regression.
val remappedInput = input.map(x => new LabeledPoint((x.label * 2) - 1, x.features))
GradientBoosting.boost(remappedInput, boostingStrategy)
case _ =>
@@ -118,120 +138,32 @@ object GradientBoosting extends Logging {
}
/**
- * Method to train a gradient boosting binary classification model.
- *
- * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
- * For classification, labels should take values {0, 1, ..., numClasses-1}.
- * For regression, labels are real numbers.
- * @param numEstimators Number of estimators used in boosting stages. In other words,
- * number of boosting iterations performed.
- * @param loss Loss function used for minimization during gradient boosting.
- * @param learningRate Learning rate for shrinking the contribution of each estimator. The
- * learning rate should be between in the interval (0, 1]
- * @param subsamplingRate Fraction of the training data used for learning the decision tree.
- * @param numClassesForClassification Number of classes for classification.
- * (Ignored for regression.)
- * @param categoricalFeaturesInfo A map storing information about the categorical variables and
- * the number of discrete values they take. For example,
- * an entry (n -> k) implies the feature n is categorical with k
- * categories 0, 1, 2, ... , k-1. It's important to note that
- * features are zero-indexed.
- * @param weakLearnerParams Parameters for the weak learner. (Currently only decision tree is
- * supported.)
- * @return WeightedEnsembleModel that can be used for prediction
+ * Java-friendly API for [[org.apache.spark.mllib.tree.GradientBoosting$#train]]
*/
- def trainClassifier(
- input: RDD[LabeledPoint],
- numEstimators: Int,
- loss: String,
- learningRate: Double,
- subsamplingRate: Double,
- numClassesForClassification: Int,
- categoricalFeaturesInfo: Map[Int, Int],
- weakLearnerParams: Strategy): WeightedEnsembleModel = {
- val lossType = Losses.fromString(loss)
- val boostingStrategy = new BoostingStrategy(Classification, numEstimators, lossType,
- learningRate, subsamplingRate, numClassesForClassification, categoricalFeaturesInfo,
- weakLearnerParams)
- new GradientBoosting(boostingStrategy).train(input)
- }
-
- /**
- * Method to train a gradient boosting regression model.
- *
- * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
- * For classification, labels should take values {0, 1, ..., numClasses-1}.
- * For regression, labels are real numbers.
- * @param numEstimators Number of estimators used in boosting stages. In other words,
- * number of boosting iterations performed.
- * @param loss Loss function used for minimization during gradient boosting.
- * @param learningRate Learning rate for shrinking the contribution of each estimator. The
- * learning rate should be between in the interval (0, 1]
- * @param subsamplingRate Fraction of the training data used for learning the decision tree.
- * @param numClassesForClassification Number of classes for classification.
- * (Ignored for regression.)
- * @param categoricalFeaturesInfo A map storing information about the categorical variables and
- * the number of discrete values they take. For example,
- * an entry (n -> k) implies the feature n is categorical with k
- * categories 0, 1, 2, ... , k-1. It's important to note that
- * features are zero-indexed.
- * @param weakLearnerParams Parameters for the weak learner. (Currently only decision tree is
- * supported.)
- * @return WeightedEnsembleModel that can be used for prediction
- */
- def trainRegressor(
- input: RDD[LabeledPoint],
- numEstimators: Int,
- loss: String,
- learningRate: Double,
- subsamplingRate: Double,
- numClassesForClassification: Int,
- categoricalFeaturesInfo: Map[Int, Int],
- weakLearnerParams: Strategy): WeightedEnsembleModel = {
- val lossType = Losses.fromString(loss)
- val boostingStrategy = new BoostingStrategy(Regression, numEstimators, lossType,
- learningRate, subsamplingRate, numClassesForClassification, categoricalFeaturesInfo,
- weakLearnerParams)
- new GradientBoosting(boostingStrategy).train(input)
+ def train(
+ input: JavaRDD[LabeledPoint],
+ boostingStrategy: BoostingStrategy): WeightedEnsembleModel = {
+ train(input.rdd, boostingStrategy)
}
/**
* Java-friendly API for [[org.apache.spark.mllib.tree.GradientBoosting$#trainClassifier]]
*/
def trainClassifier(
- input: RDD[LabeledPoint],
- numEstimators: Int,
- loss: String,
- learningRate: Double,
- subsamplingRate: Double,
- numClassesForClassification: Int,
- categoricalFeaturesInfo:java.util.Map[java.lang.Integer, java.lang.Integer],
- weakLearnerParams: Strategy): WeightedEnsembleModel = {
- trainClassifier(input, numEstimators, loss, learningRate, subsamplingRate,
- numClassesForClassification,
- categoricalFeaturesInfo.asInstanceOf[java.util.Map[Int, Int]].asScala.toMap,
- weakLearnerParams)
+ input: JavaRDD[LabeledPoint],
+ boostingStrategy: BoostingStrategy): WeightedEnsembleModel = {
+ trainClassifier(input.rdd, boostingStrategy)
}
/**
* Java-friendly API for [[org.apache.spark.mllib.tree.GradientBoosting$#trainRegressor]]
*/
def trainRegressor(
- input: RDD[LabeledPoint],
- numEstimators: Int,
- loss: String,
- learningRate: Double,
- subsamplingRate: Double,
- numClassesForClassification: Int,
- categoricalFeaturesInfo: java.util.Map[java.lang.Integer, java.lang.Integer],
- weakLearnerParams: Strategy): WeightedEnsembleModel = {
- trainRegressor(input, numEstimators, loss, learningRate, subsamplingRate,
- numClassesForClassification,
- categoricalFeaturesInfo.asInstanceOf[java.util.Map[Int, Int]].asScala.toMap,
- weakLearnerParams)
+ input: JavaRDD[LabeledPoint],
+ boostingStrategy: BoostingStrategy): WeightedEnsembleModel = {
+ trainRegressor(input.rdd, boostingStrategy)
}
-
/**
* Internal method for performing regression using trees as base learners.
* @param input training dataset
@@ -247,15 +179,17 @@ object GradientBoosting extends Logging {
timer.start("init")
// Initialize gradient boosting parameters
- val numEstimators = boostingStrategy.numEstimators
- val baseLearners = new Array[DecisionTreeModel](numEstimators)
- val baseLearnerWeights = new Array[Double](numEstimators)
+ val numIterations = boostingStrategy.numIterations
+ val baseLearners = new Array[DecisionTreeModel](numIterations)
+ val baseLearnerWeights = new Array[Double](numIterations)
val loss = boostingStrategy.loss
val learningRate = boostingStrategy.learningRate
val strategy = boostingStrategy.weakLearnerParams
// Cache input
- input.persist(StorageLevel.MEMORY_AND_DISK)
+ if (input.getStorageLevel == StorageLevel.NONE) {
+ input.persist(StorageLevel.MEMORY_AND_DISK)
+ }
timer.stop("init")
@@ -264,7 +198,7 @@ object GradientBoosting extends Logging {
logDebug("##########")
var data = input
- // 1. Initialize tree
+ // Initialize tree
timer.start("building tree 0")
val firstTreeModel = new DecisionTree(strategy).train(data)
baseLearners(0) = firstTreeModel
@@ -280,7 +214,7 @@ object GradientBoosting extends Logging {
point.features))
var m = 1
- while (m < numEstimators) {
+ while (m < numIterations) {
timer.start(s"building tree $m")
logDebug("###################################################")
logDebug("Gradient boosting tree iteration " + m)
@@ -289,6 +223,9 @@ object GradientBoosting extends Logging {
timer.stop(s"building tree $m")
// Create partial model
baseLearners(m) = model
+ // Note: The setting of baseLearnerWeights is incorrect for losses other than SquaredError.
+ // Technically, the weight should be optimized for the particular loss.
+ // However, the behavior should be reasonable, though not optimal.
baseLearnerWeights(m) = learningRate
// Note: A model of type regression is used since we require raw prediction
val partialModel = new WeightedEnsembleModel(baseLearners.slice(0, m + 1),
@@ -305,8 +242,6 @@ object GradientBoosting extends Logging {
logInfo("Internal timing for DecisionTree:")
logInfo(s"$timer")
-
- // 3. Output classifier
new WeightedEnsembleModel(baseLearners, baseLearnerWeights, boostingStrategy.algo, Sum)
}
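
After this rewrite, every knob rides inside BoostingStrategy (with tree settings under weakLearnerParams), so the train entry points shrink to (input, strategy). A hedged training sketch; trainingData stands in for any RDD[LabeledPoint] already loaded:

    import org.apache.spark.mllib.tree.GradientBoosting
    import org.apache.spark.mllib.tree.configuration.BoostingStrategy

    val boostingStrategy = BoostingStrategy.defaultParams("Regression")
    boostingStrategy.numIterations = 50            // was the numEstimators parameter
    boostingStrategy.learningRate = 0.1
    boostingStrategy.weakLearnerParams.maxDepth = 5
    // trainingData: RDD[LabeledPoint], assumed already loaded
    val model = GradientBoosting.trainRegressor(trainingData, boostingStrategy)
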
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala
index 501d9ff9ea9b7..abbda040bd528 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala
@@ -21,7 +21,6 @@ import scala.beans.BeanProperty
import org.apache.spark.annotation.Experimental
import org.apache.spark.mllib.tree.configuration.Algo._
-import org.apache.spark.mllib.tree.impurity.{Gini, Variance}
import org.apache.spark.mllib.tree.loss.{LogLoss, SquaredError, Loss}
/**
@@ -30,46 +29,58 @@ import org.apache.spark.mllib.tree.loss.{LogLoss, SquaredError, Loss}
* @param algo Learning goal. Supported:
* [[org.apache.spark.mllib.tree.configuration.Algo.Classification]],
* [[org.apache.spark.mllib.tree.configuration.Algo.Regression]]
- * @param numEstimators Number of estimators used in boosting stages. In other words,
- * number of boosting iterations performed.
+ * @param numIterations Number of iterations of boosting. In other words, the number of
+ * weak hypotheses used in the final model.
* @param loss Loss function used for minimization during gradient boosting.
* @param learningRate Learning rate for shrinking the contribution of each estimator. The
 *                     learning rate should be in the interval (0, 1].
- * @param subsamplingRate Fraction of the training data used for learning the decision tree.
* @param numClassesForClassification Number of classes for classification.
* (Ignored for regression.)
+ * This setting overrides any setting in [[weakLearnerParams]].
* Default value is 2 (binary classification).
- * @param categoricalFeaturesInfo A map storing information about the categorical variables and the
- * number of discrete values they take. For example, an entry (n ->
- * k) implies the feature n is categorical with k categories 0,
- * 1, 2, ... , k-1. It's important to note that features are
- * zero-indexed.
* @param weakLearnerParams Parameters for weak learners. Currently only decision trees are
* supported.
*/
@Experimental
case class BoostingStrategy(
// Required boosting parameters
- algo: Algo,
- @BeanProperty var numEstimators: Int,
+ @BeanProperty var algo: Algo,
+ @BeanProperty var numIterations: Int,
@BeanProperty var loss: Loss,
// Optional boosting parameters
@BeanProperty var learningRate: Double = 0.1,
- @BeanProperty var subsamplingRate: Double = 1.0,
@BeanProperty var numClassesForClassification: Int = 2,
- @BeanProperty var categoricalFeaturesInfo: Map[Int, Int] = Map[Int, Int](),
@BeanProperty var weakLearnerParams: Strategy) extends Serializable {
- require(learningRate <= 1, "Learning rate should be <= 1. Provided learning rate is " +
- s"$learningRate.")
- require(learningRate > 0, "Learning rate should be > 0. Provided learning rate is " +
- s"$learningRate.")
-
// Ensure values for weak learner are the same as what is provided to the boosting algorithm.
- weakLearnerParams.categoricalFeaturesInfo = categoricalFeaturesInfo
weakLearnerParams.numClassesForClassification = numClassesForClassification
- weakLearnerParams.subsamplingRate = subsamplingRate
+ /**
+ * Sets Algorithm using a String.
+ */
+ def setAlgo(algo: String): Unit = algo match {
+ case "Classification" => setAlgo(Classification)
+ case "Regression" => setAlgo(Regression)
+ }
+
+ /**
+ * Check validity of parameters.
+ * Throws exception if invalid.
+ */
+ private[tree] def assertValid(): Unit = {
+ algo match {
+ case Classification =>
+ require(numClassesForClassification == 2)
+ case Regression =>
+ // nothing
+ case _ =>
+ throw new IllegalArgumentException(
+ s"BoostingStrategy given invalid algo parameter: $algo." +
+ s" Valid settings are: Classification, Regression.")
+ }
+ require(learningRate > 0 && learningRate <= 1,
+ "Learning rate should be in range (0, 1]. Provided learning rate is " + s"$learningRate.")
+ }
}
@Experimental
@@ -82,28 +93,17 @@ object BoostingStrategy {
* [[org.apache.spark.mllib.tree.configuration.Algo.Regression]]
* @return Configuration for boosting algorithm
*/
- def defaultParams(algo: Algo): BoostingStrategy = {
- val treeStrategy = defaultWeakLearnerParams(algo)
+ def defaultParams(algo: String): BoostingStrategy = {
+ val treeStrategy = Strategy.defaultStrategy("Regression")
+ treeStrategy.maxDepth = 3
algo match {
- case Classification =>
- new BoostingStrategy(algo, 100, LogLoss, weakLearnerParams = treeStrategy)
- case Regression =>
- new BoostingStrategy(algo, 100, SquaredError, weakLearnerParams = treeStrategy)
+ case "Classification" =>
+ new BoostingStrategy(Algo.withName(algo), 100, LogLoss, weakLearnerParams = treeStrategy)
+ case "Regression" =>
+ new BoostingStrategy(Algo.withName(algo), 100, SquaredError,
+ weakLearnerParams = treeStrategy)
case _ =>
throw new IllegalArgumentException(s"$algo is not supported by the boosting.")
}
}
-
- /**
- * Returns default configuration for the weak learner (decision tree) algorithm
- * @param algo Learning goal. Supported:
- * [[org.apache.spark.mllib.tree.configuration.Algo.Classification]],
- * [[org.apache.spark.mllib.tree.configuration.Algo.Regression]]
- * @return Configuration for weak learner
- */
- def defaultWeakLearnerParams(algo: Algo): Strategy = {
- // Note: Regression tree used even for classification for GBT.
- new Strategy(Regression, Variance, 3)
- }
-
}
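
Because every field is now a @BeanProperty var (including algo), Java callers can configure a BoostingStrategy entirely through setters, and defaultParams accepts a plain String. A hedged sketch of that bean-style configuration:

    import org.apache.spark.mllib.tree.configuration.BoostingStrategy

    val strategy = BoostingStrategy.defaultParams("Classification")
    strategy.setAlgo("Regression")      // String overload defined above
    strategy.setNumIterations(200)      // generated by @BeanProperty
    strategy.setLearningRate(0.05)
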
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
index d09295c507d67..b5b1f82177edc 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
@@ -70,7 +70,7 @@ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
*/
@Experimental
class Strategy (
- val algo: Algo,
+ @BeanProperty var algo: Algo,
@BeanProperty var impurity: Impurity,
@BeanProperty var maxDepth: Int,
@BeanProperty var numClassesForClassification: Int = 2,
@@ -85,17 +85,9 @@ class Strategy (
@BeanProperty var checkpointDir: Option[String] = None,
@BeanProperty var checkpointInterval: Int = 10) extends Serializable {
- if (algo == Classification) {
- require(numClassesForClassification >= 2)
- }
- require(minInstancesPerNode >= 1,
- s"DecisionTree Strategy requires minInstancesPerNode >= 1 but was given $minInstancesPerNode")
- require(maxMemoryInMB <= 10240,
- s"DecisionTree Strategy requires maxMemoryInMB <= 10240, but was given $maxMemoryInMB")
-
- val isMulticlassClassification =
+ def isMulticlassClassification =
algo == Classification && numClassesForClassification > 2
- val isMulticlassWithCategoricalFeatures
+ def isMulticlassWithCategoricalFeatures
= isMulticlassClassification && (categoricalFeaturesInfo.size > 0)
/**
@@ -112,6 +104,23 @@ class Strategy (
categoricalFeaturesInfo.asInstanceOf[java.util.Map[Int, Int]].asScala.toMap)
}
+ /**
+ * Sets Algorithm using a String.
+ */
+ def setAlgo(algo: String): Unit = algo match {
+ case "Classification" => setAlgo(Classification)
+ case "Regression" => setAlgo(Regression)
+ }
+
+ /**
+ * Sets categoricalFeaturesInfo using a Java Map.
+ */
+ def setCategoricalFeaturesInfo(
+ categoricalFeaturesInfo: java.util.Map[java.lang.Integer, java.lang.Integer]): Unit = {
+ setCategoricalFeaturesInfo(
+ categoricalFeaturesInfo.asInstanceOf[java.util.Map[Int, Int]].asScala.toMap)
+ }
+
/**
* Check validity of parameters.
* Throws exception if invalid.
@@ -143,6 +152,26 @@ class Strategy (
s"DecisionTree Strategy given invalid categoricalFeaturesInfo setting:" +
s" feature $feature has $arity categories. The number of categories should be >= 2.")
}
+ require(minInstancesPerNode >= 1,
+ s"DecisionTree Strategy requires minInstancesPerNode >= 1 but was given $minInstancesPerNode")
+ require(maxMemoryInMB <= 10240,
+ s"DecisionTree Strategy requires maxMemoryInMB <= 10240, but was given $maxMemoryInMB")
}
+}
+
+@Experimental
+object Strategy {
+ /**
+ * Construct a default set of parameters for [[org.apache.spark.mllib.tree.DecisionTree]]
+ * @param algo "Classification" or "Regression"
+ */
+ def defaultStrategy(algo: String): Strategy = algo match {
+ case "Classification" =>
+ new Strategy(algo = Classification, impurity = Gini, maxDepth = 10,
+ numClassesForClassification = 2)
+ case "Regression" =>
+ new Strategy(algo = Regression, impurity = Variance, maxDepth = 10,
+ numClassesForClassification = 0)
+ }
}
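
Strategy gains the same Java-friendly surface: a String-keyed factory plus setters that accept java.util types. A hedged sketch combining the two additions:

    import org.apache.spark.mllib.tree.configuration.Strategy

    val treeStrategy = Strategy.defaultStrategy("Classification")
    treeStrategy.setMaxDepth(5)
    val catInfo = new java.util.HashMap[java.lang.Integer, java.lang.Integer]()
    catInfo.put(0, 2)  // feature 0: categorical with 2 values
    treeStrategy.setCategoricalFeaturesInfo(catInfo)  // Java Map overload above
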
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala
index cd651fe2d2ddf..93a84fe07b32a 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala
@@ -155,4 +155,15 @@ class VectorsSuite extends FunSuite {
throw new RuntimeException(s"copy returned ${dvCopy.getClass} on ${dv.getClass}.")
}
}
+
+ test("VectorUDT") {
+ val dv0 = Vectors.dense(Array.empty[Double])
+ val dv1 = Vectors.dense(1.0, 2.0)
+ val sv0 = Vectors.sparse(2, Array.empty, Array.empty)
+ val sv1 = Vectors.sparse(2, Array(1), Array(2.0))
+ val udt = new VectorUDT()
+ for (v <- Seq(dv0, dv1, sv0, sv1)) {
+ assert(v === udt.deserialize(udt.serialize(v)))
+ }
+ }
}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/rdd/RDDFunctionsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/rdd/RDDFunctionsSuite.scala
index 27a19f793242b..4ef67a40b9f49 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/rdd/RDDFunctionsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/rdd/RDDFunctionsSuite.scala
@@ -42,9 +42,9 @@ class RDDFunctionsSuite extends FunSuite with LocalSparkContext {
val data = Seq(Seq(1, 2, 3), Seq.empty[Int], Seq(4), Seq.empty[Int], Seq(5, 6, 7))
val rdd = sc.parallelize(data, data.length).flatMap(s => s)
assert(rdd.partitions.size === data.length)
- val sliding = rdd.sliding(3)
- val expected = data.flatMap(x => x).sliding(3).toList
- assert(sliding.collect().toList === expected)
+ val sliding = rdd.sliding(3).collect().toSeq.map(_.toSeq)
+ val expected = data.flatMap(x => x).sliding(3).toSeq.map(_.toSeq)
+ assert(sliding === expected)
}
test("treeAggregate") {
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostingSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostingSuite.scala
index 970fff82215e2..99a02eda60baf 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostingSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostingSuite.scala
@@ -22,9 +22,8 @@ import org.scalatest.FunSuite
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.configuration.Algo._
import org.apache.spark.mllib.tree.configuration.{BoostingStrategy, Strategy}
-import org.apache.spark.mllib.tree.impurity.{Variance, Gini}
+import org.apache.spark.mllib.tree.impurity.Variance
import org.apache.spark.mllib.tree.loss.{SquaredError, LogLoss}
-import org.apache.spark.mllib.tree.model.{WeightedEnsembleModel, DecisionTreeModel}
import org.apache.spark.mllib.util.LocalSparkContext
@@ -34,9 +33,8 @@ import org.apache.spark.mllib.util.LocalSparkContext
class GradientBoostingSuite extends FunSuite with LocalSparkContext {
test("Regression with continuous features: SquaredError") {
-
GradientBoostingSuite.testCombinations.foreach {
- case (numEstimators, learningRate, subsamplingRate) =>
+ case (numIterations, learningRate, subsamplingRate) =>
val arr = EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 50, 1000)
val rdd = sc.parallelize(arr)
val categoricalFeaturesInfo = Map.empty[Int, Int]
@@ -48,11 +46,11 @@ class GradientBoostingSuite extends FunSuite with LocalSparkContext {
val dt = DecisionTree.train(remappedInput, treeStrategy)
- val boostingStrategy = new BoostingStrategy(Regression, numEstimators, SquaredError,
- subsamplingRate, learningRate, 1, categoricalFeaturesInfo, treeStrategy)
+ val boostingStrategy = new BoostingStrategy(Regression, numIterations, SquaredError,
+ learningRate, 1, treeStrategy)
val gbt = GradientBoosting.trainRegressor(rdd, boostingStrategy)
- assert(gbt.weakHypotheses.size === numEstimators)
+ assert(gbt.weakHypotheses.size === numIterations)
val gbtTree = gbt.weakHypotheses(0)
EnsembleTestHelper.validateRegressor(gbt, arr, 0.02)
@@ -63,9 +61,8 @@ class GradientBoostingSuite extends FunSuite with LocalSparkContext {
}
test("Regression with continuous features: Absolute Error") {
-
GradientBoostingSuite.testCombinations.foreach {
- case (numEstimators, learningRate, subsamplingRate) =>
+ case (numIterations, learningRate, subsamplingRate) =>
val arr = EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 50, 1000)
val rdd = sc.parallelize(arr)
val categoricalFeaturesInfo = Map.empty[Int, Int]
@@ -77,11 +74,11 @@ class GradientBoostingSuite extends FunSuite with LocalSparkContext {
val dt = DecisionTree.train(remappedInput, treeStrategy)
- val boostingStrategy = new BoostingStrategy(Regression, numEstimators, SquaredError,
- subsamplingRate, learningRate, 1, categoricalFeaturesInfo, treeStrategy)
+ val boostingStrategy = new BoostingStrategy(Regression, numIterations, SquaredError,
+ learningRate, numClassesForClassification = 2, treeStrategy)
val gbt = GradientBoosting.trainRegressor(rdd, boostingStrategy)
- assert(gbt.weakHypotheses.size === numEstimators)
+ assert(gbt.weakHypotheses.size === numIterations)
val gbtTree = gbt.weakHypotheses(0)
EnsembleTestHelper.validateRegressor(gbt, arr, 0.02)
@@ -91,11 +88,9 @@ class GradientBoostingSuite extends FunSuite with LocalSparkContext {
}
}
-
test("Binary classification with continuous features: Log Loss") {
-
GradientBoostingSuite.testCombinations.foreach {
- case (numEstimators, learningRate, subsamplingRate) =>
+ case (numIterations, learningRate, subsamplingRate) =>
val arr = EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 50, 1000)
val rdd = sc.parallelize(arr)
val categoricalFeaturesInfo = Map.empty[Int, Int]
@@ -107,11 +102,11 @@ class GradientBoostingSuite extends FunSuite with LocalSparkContext {
val dt = DecisionTree.train(remappedInput, treeStrategy)
- val boostingStrategy = new BoostingStrategy(Classification, numEstimators, LogLoss,
- subsamplingRate, learningRate, 1, categoricalFeaturesInfo, treeStrategy)
+ val boostingStrategy = new BoostingStrategy(Classification, numIterations, LogLoss,
+ learningRate, numClassesForClassification = 2, treeStrategy)
val gbt = GradientBoosting.trainClassifier(rdd, boostingStrategy)
- assert(gbt.weakHypotheses.size === numEstimators)
+ assert(gbt.weakHypotheses.size === numIterations)
val gbtTree = gbt.weakHypotheses(0)
EnsembleTestHelper.validateClassifier(gbt, arr, 0.9)
@@ -126,7 +121,6 @@ class GradientBoostingSuite extends FunSuite with LocalSparkContext {
object GradientBoostingSuite {
// Combinations for estimators, learning rates and subsamplingRate
- val testCombinations
- = Array((10, 1.0, 1.0), (10, 0.1, 1.0), (10, 1.0, 0.75), (10, 0.1, 0.75))
+ val testCombinations = Array((10, 1.0, 1.0), (10, 0.1, 1.0), (10, 1.0, 0.75), (10, 0.1, 0.75))
}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/impl/BaggedPointSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/impl/BaggedPointSuite.scala
index c0a62e00432a3..5cb433232e714 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/impl/BaggedPointSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/impl/BaggedPointSuite.scala
@@ -30,7 +30,7 @@ class BaggedPointSuite extends FunSuite with LocalSparkContext {
test("BaggedPoint RDD: without subsampling") {
val arr = EnsembleTestHelper.generateOrderedLabeledPoints(1, 1000)
val rdd = sc.parallelize(arr)
- val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, 1.0, 1, false)
+ val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, 1.0, 1, false, 42)
baggedRDD.collect().foreach { baggedPoint =>
assert(baggedPoint.subsampleWeights.size == 1 && baggedPoint.subsampleWeights(0) == 1)
}
@@ -44,7 +44,7 @@ class BaggedPointSuite extends FunSuite with LocalSparkContext {
val arr = EnsembleTestHelper.generateOrderedLabeledPoints(1, 1000)
val rdd = sc.parallelize(arr)
seeds.foreach { seed =>
- val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, 1.0, numSubsamples, true)
+ val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, 1.0, numSubsamples, true, seed)
val subsampleCounts: Array[Array[Double]] = baggedRDD.map(_.subsampleWeights).collect()
EnsembleTestHelper.testRandomArrays(subsampleCounts, numSubsamples, expectedMean,
expectedStddev, epsilon = 0.01)
@@ -60,7 +60,7 @@ class BaggedPointSuite extends FunSuite with LocalSparkContext {
val arr = EnsembleTestHelper.generateOrderedLabeledPoints(1, 1000)
val rdd = sc.parallelize(arr)
seeds.foreach { seed =>
- val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, subsample, numSubsamples, true)
+ val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, subsample, numSubsamples, true, seed)
val subsampleCounts: Array[Array[Double]] = baggedRDD.map(_.subsampleWeights).collect()
EnsembleTestHelper.testRandomArrays(subsampleCounts, numSubsamples, expectedMean,
expectedStddev, epsilon = 0.01)
@@ -75,7 +75,7 @@ class BaggedPointSuite extends FunSuite with LocalSparkContext {
val arr = EnsembleTestHelper.generateOrderedLabeledPoints(1, 1000)
val rdd = sc.parallelize(arr)
seeds.foreach { seed =>
- val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, 1.0, numSubsamples, false)
+ val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, 1.0, numSubsamples, false, seed)
val subsampleCounts: Array[Array[Double]] = baggedRDD.map(_.subsampleWeights).collect()
EnsembleTestHelper.testRandomArrays(subsampleCounts, numSubsamples, expectedMean,
expectedStddev, epsilon = 0.01)
@@ -91,7 +91,7 @@ class BaggedPointSuite extends FunSuite with LocalSparkContext {
val arr = EnsembleTestHelper.generateOrderedLabeledPoints(1, 1000)
val rdd = sc.parallelize(arr)
seeds.foreach { seed =>
- val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, subsample, numSubsamples, false)
+ val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, subsample, numSubsamples, false, seed)
val subsampleCounts: Array[Array[Double]] = baggedRDD.map(_.subsampleWeights).collect()
EnsembleTestHelper.testRandomArrays(subsampleCounts, numSubsamples, expectedMean,
expectedStddev, epsilon = 0.01)
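
Threading an explicit seed through convertToBaggedRDD is what lets these tests iterate over fixed seeds: each partition can derive a deterministic RNG, so the bootstrap weights are reproducible. A sketch of the usual per-partition seeding pattern (scala.util.Random stands in for Spark's internal generator):

    def deterministicSample(rdd: org.apache.spark.rdd.RDD[Double], seed: Int) =
      rdd.mapPartitionsWithIndex { (partitionIndex, iter) =>
        val rng = new scala.util.Random(seed + partitionIndex) // same seed, same draws
        iter.map(x => (x, rng.nextInt(2)))                     // illustrative weight
      }
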
diff --git a/network/common/pom.xml b/network/common/pom.xml
index ea887148d98ba..6144548a8f998 100644
--- a/network/common/pom.xml
+++ b/network/common/pom.xml
@@ -50,6 +50,7 @@
       <groupId>com.google.guava</groupId>
       <artifactId>guava</artifactId>
+      <version>11.0.2</version>
       <scope>provided</scope>
diff --git a/network/common/src/main/java/org/apache/spark/network/TransportContext.java b/network/common/src/main/java/org/apache/spark/network/TransportContext.java
index a271841e4e56c..5bc6e5a2418a9 100644
--- a/network/common/src/main/java/org/apache/spark/network/TransportContext.java
+++ b/network/common/src/main/java/org/apache/spark/network/TransportContext.java
@@ -17,12 +17,16 @@
package org.apache.spark.network;
+import java.util.List;
+
+import com.google.common.collect.Lists;
import io.netty.channel.Channel;
import io.netty.channel.socket.SocketChannel;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.apache.spark.network.client.TransportClient;
+import org.apache.spark.network.client.TransportClientBootstrap;
import org.apache.spark.network.client.TransportClientFactory;
import org.apache.spark.network.client.TransportResponseHandler;
import org.apache.spark.network.protocol.MessageDecoder;
@@ -64,8 +68,17 @@ public TransportContext(TransportConf conf, RpcHandler rpcHandler) {
this.decoder = new MessageDecoder();
}
+ /**
+ * Initializes a ClientFactory which runs the given TransportClientBootstraps prior to returning
+ * a new Client. Bootstraps will be executed synchronously, and must run successfully in order
+ * to create a Client.
+ */
+  public TransportClientFactory createClientFactory(List<TransportClientBootstrap> bootstraps) {
+ return new TransportClientFactory(this, bootstraps);
+ }
+
public TransportClientFactory createClientFactory() {
- return new TransportClientFactory(this);
+    return createClientFactory(Lists.<TransportClientBootstrap>newArrayList());
}
/** Create a server which will attempt to bind to a specific port. */
diff --git a/network/common/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java b/network/common/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java
index 89ed79bc63903..5fa1527ddff92 100644
--- a/network/common/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java
+++ b/network/common/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java
@@ -30,6 +30,7 @@
import io.netty.channel.DefaultFileRegion;
import org.apache.spark.network.util.JavaUtils;
+import org.apache.spark.network.util.LimitedInputStream;
/**
* A {@link ManagedBuffer} backed by a segment in a file.
@@ -101,7 +102,7 @@ public InputStream createInputStream() throws IOException {
try {
is = new FileInputStream(file);
ByteStreams.skipFully(is, offset);
- return ByteStreams.limit(is, length);
+ return new LimitedInputStream(is, length);
} catch (IOException e) {
try {
if (is != null) {
diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java b/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java
index 01c143fff423c..4e944114e8176 100644
--- a/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java
+++ b/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java
@@ -18,11 +18,12 @@
package org.apache.spark.network.client;
import java.io.Closeable;
+import java.io.IOException;
import java.util.UUID;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
-import java.util.concurrent.TimeoutException;
+import com.google.common.base.Objects;
import com.google.common.base.Preconditions;
import com.google.common.base.Throwables;
import com.google.common.util.concurrent.SettableFuture;
@@ -117,8 +118,12 @@ public void operationComplete(ChannelFuture future) throws Exception {
serverAddr, future.cause());
logger.error(errorMsg, future.cause());
handler.removeFetchRequest(streamChunkId);
- callback.onFailure(chunkIndex, new RuntimeException(errorMsg, future.cause()));
channel.close();
+ try {
+ callback.onFailure(chunkIndex, new IOException(errorMsg, future.cause()));
+ } catch (Exception e) {
+ logger.error("Uncaught exception in RPC response callback handler!", e);
+ }
}
}
});
@@ -148,8 +153,12 @@ public void operationComplete(ChannelFuture future) throws Exception {
serverAddr, future.cause());
logger.error(errorMsg, future.cause());
handler.removeRpcRequest(requestId);
- callback.onFailure(new RuntimeException(errorMsg, future.cause()));
channel.close();
+ try {
+ callback.onFailure(new IOException(errorMsg, future.cause()));
+ } catch (Exception e) {
+ logger.error("Uncaught exception in RPC response callback handler!", e);
+ }
}
}
});
@@ -176,6 +185,8 @@ public void onFailure(Throwable e) {
try {
return result.get(timeoutMs, TimeUnit.MILLISECONDS);
+ } catch (ExecutionException e) {
+ throw Throwables.propagate(e.getCause());
} catch (Exception e) {
throw Throwables.propagate(e);
}
@@ -186,4 +197,12 @@ public void close() {
    // close is a local operation and should finish within milliseconds; timeout just to be safe
channel.close().awaitUninterruptibly(10, TimeUnit.SECONDS);
}
+
+ @Override
+ public String toString() {
+ return Objects.toStringHelper(this)
+ .add("remoteAdress", channel.remoteAddress())
+ .add("isActive", isActive())
+ .toString();
+ }
}
diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportClientBootstrap.java b/network/common/src/main/java/org/apache/spark/network/client/TransportClientBootstrap.java
new file mode 100644
index 0000000000000..65e8020e34121
--- /dev/null
+++ b/network/common/src/main/java/org/apache/spark/network/client/TransportClientBootstrap.java
@@ -0,0 +1,32 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.client;
+
+/**
+ * A bootstrap which is executed on a TransportClient before it is returned to the user.
+ * This enables an initial exchange of information (e.g., SASL authentication tokens) on a once-per-
+ * connection basis.
+ *
+ * Since connections (and TransportClients) are reused as much as possible, it is generally
+ * reasonable to perform an expensive bootstrapping operation, as they often share a lifespan with
+ * the JVM itself.
+ */
+public interface TransportClientBootstrap {
+ /** Performs the bootstrapping operation, throwing an exception on failure. */
+ public void doBootstrap(TransportClient client) throws RuntimeException;
+}
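
A bootstrap runs synchronously on every newly created TransportClient and never on reused ones, which makes it a natural place for per-connection handshakes. A hypothetical Scala implementation (HandshakeBootstrap and its payload are illustrative, not from the patch):

    import org.apache.spark.network.client.{TransportClient, TransportClientBootstrap}

    class HandshakeBootstrap(appId: String) extends TransportClientBootstrap {
      override def doBootstrap(client: TransportClient): Unit = {
        // Blocks until the server replies; a thrown exception fails the bootstrap
        // and the factory closes the half-initialized connection.
        client.sendRpcSync(appId.getBytes("UTF-8"), 10000)
      }
    }
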
diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java b/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java
index 0b4a1d8286407..397d3a8455c86 100644
--- a/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java
+++ b/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java
@@ -18,13 +18,17 @@
package org.apache.spark.network.client;
import java.io.Closeable;
+import java.io.IOException;
import java.lang.reflect.Field;
import java.net.InetSocketAddress;
import java.net.SocketAddress;
+import java.util.List;
import java.util.concurrent.ConcurrentHashMap;
-import java.util.concurrent.TimeoutException;
import java.util.concurrent.atomic.AtomicReference;
+import com.google.common.base.Preconditions;
+import com.google.common.base.Throwables;
+import com.google.common.collect.Lists;
import io.netty.bootstrap.Bootstrap;
import io.netty.buffer.PooledByteBufAllocator;
import io.netty.channel.Channel;
@@ -47,22 +51,29 @@
* Factory for creating {@link TransportClient}s by using createClient.
*
* The factory maintains a connection pool to other hosts and should return the same
- * {@link TransportClient} for the same remote host. It also shares a single worker thread pool for
- * all {@link TransportClient}s.
+ * TransportClient for the same remote host. It also shares a single worker thread pool for
+ * all TransportClients.
+ *
+ * TransportClients will be reused whenever possible. Prior to completing the creation of a new
+ * TransportClient, all given {@link TransportClientBootstrap}s will be run.
*/
public class TransportClientFactory implements Closeable {
private final Logger logger = LoggerFactory.getLogger(TransportClientFactory.class);
private final TransportContext context;
private final TransportConf conf;
+  private final List<TransportClientBootstrap> clientBootstraps;
   private final ConcurrentHashMap<SocketAddress, TransportClient> connectionPool;
   private final Class<? extends Channel> socketChannelClass;
private EventLoopGroup workerGroup;
- public TransportClientFactory(TransportContext context) {
- this.context = context;
+ public TransportClientFactory(
+ TransportContext context,
+      List<TransportClientBootstrap> clientBootstraps) {
+ this.context = Preconditions.checkNotNull(context);
this.conf = context.getConf();
+ this.clientBootstraps = Lists.newArrayList(Preconditions.checkNotNull(clientBootstraps));
    this.connectionPool = new ConcurrentHashMap<SocketAddress, TransportClient>();
IOMode ioMode = IOMode.valueOf(conf.ioMode());
@@ -72,21 +83,26 @@ public TransportClientFactory(TransportContext context) {
}
/**
- * Create a new BlockFetchingClient connecting to the given remote host / port.
+ * Create a new {@link TransportClient} connecting to the given remote host / port. This will
+ * reuse TransportClients if they are still active and are for the same remote address. Prior
+ * to the creation of a new TransportClient, we will execute all {@link TransportClientBootstrap}s
+ * that are registered with this factory.
*
- * This blocks until a connection is successfully established.
+ * This blocks until a connection is successfully established and fully bootstrapped.
*
* Concurrency: This method is safe to call from multiple threads.
*/
- public TransportClient createClient(String remoteHost, int remotePort) {
+ public TransportClient createClient(String remoteHost, int remotePort) throws IOException {
// Get connection from the connection pool first.
// If it is not found or not active, create a new one.
final InetSocketAddress address = new InetSocketAddress(remoteHost, remotePort);
TransportClient cachedClient = connectionPool.get(address);
if (cachedClient != null) {
if (cachedClient.isActive()) {
+ logger.trace("Returning cached connection to {}: {}", address, cachedClient);
return cachedClient;
} else {
+ logger.info("Found inactive connection to {}, closing it.", address);
connectionPool.remove(address, cachedClient); // Remove inactive clients.
}
}
@@ -104,34 +120,55 @@ public TransportClient createClient(String remoteHost, int remotePort) {
// Use pooled buffers to reduce temporary buffer allocation
bootstrap.option(ChannelOption.ALLOCATOR, createPooledByteBufAllocator());
-    final AtomicReference<TransportClient> client = new AtomicReference<TransportClient>();
+    final AtomicReference<TransportClient> clientRef = new AtomicReference<TransportClient>();
     bootstrap.handler(new ChannelInitializer<SocketChannel>() {
@Override
public void initChannel(SocketChannel ch) {
TransportChannelHandler clientHandler = context.initializePipeline(ch);
- client.set(clientHandler.getClient());
+ clientRef.set(clientHandler.getClient());
}
});
// Connect to the remote server
+ long preConnect = System.currentTimeMillis();
ChannelFuture cf = bootstrap.connect(address);
if (!cf.awaitUninterruptibly(conf.connectionTimeoutMs())) {
- throw new RuntimeException(
+ throw new IOException(
String.format("Connecting to %s timed out (%s ms)", address, conf.connectionTimeoutMs()));
} else if (cf.cause() != null) {
- throw new RuntimeException(String.format("Failed to connect to %s", address), cf.cause());
+ throw new IOException(String.format("Failed to connect to %s", address), cf.cause());
+ }
+
+ TransportClient client = clientRef.get();
+ assert client != null : "Channel future completed successfully with null client";
+
+ // Execute any client bootstraps synchronously before marking the Client as successful.
+ long preBootstrap = System.currentTimeMillis();
+ logger.debug("Connection to {} successful, running bootstraps...", address);
+ try {
+ for (TransportClientBootstrap clientBootstrap : clientBootstraps) {
+ clientBootstrap.doBootstrap(client);
+ }
+ } catch (Exception e) { // catch non-RuntimeExceptions too as bootstrap may be written in Scala
+ long bootstrapTime = System.currentTimeMillis() - preBootstrap;
+ logger.error("Exception while bootstrapping client after " + bootstrapTime + " ms", e);
+ client.close();
+ throw Throwables.propagate(e);
}
+ long postBootstrap = System.currentTimeMillis();
- // Successful connection -- in the event that two threads raced to create a client, we will
+ // Successful connection & bootstrap -- in the event that two threads raced to create a client,
// use the first one that was put into the connectionPool and close the one we made here.
- assert client.get() != null : "Channel future completed successfully with null client";
- TransportClient oldClient = connectionPool.putIfAbsent(address, client.get());
+ TransportClient oldClient = connectionPool.putIfAbsent(address, client);
if (oldClient == null) {
- return client.get();
+ logger.debug("Successfully created connection to {} after {} ms ({} ms spent in bootstraps)",
+ address, postBootstrap - preConnect, postBootstrap - preBootstrap);
+ return client;
} else {
- logger.debug("Two clients were created concurrently, second one will be disposed.");
- client.get().close();
+ logger.debug("Two clients were created concurrently after {} ms, second will be disposed.",
+ postBootstrap - preConnect);
+ client.close();
return oldClient;
}
}
@@ -162,7 +199,7 @@ public void close() {
*/
private PooledByteBufAllocator createPooledByteBufAllocator() {
return new PooledByteBufAllocator(
- PlatformDependent.directBufferPreferred(),
+ conf.preferDirectBufs() && PlatformDependent.directBufferPreferred(),
getPrivateStaticField("DEFAULT_NUM_HEAP_ARENA"),
getPrivateStaticField("DEFAULT_NUM_DIRECT_ARENA"),
getPrivateStaticField("DEFAULT_PAGE_SIZE"),
diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java b/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java
index d8965590b34da..2044afb0d85db 100644
--- a/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java
+++ b/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java
@@ -17,6 +17,7 @@
package org.apache.spark.network.client;
+import java.io.IOException;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
@@ -94,7 +95,7 @@ public void channelUnregistered() {
String remoteAddress = NettyUtils.getRemoteAddress(channel);
logger.error("Still have {} requests outstanding when connection from {} is closed",
numOutstandingRequests(), remoteAddress);
- failOutstandingRequests(new RuntimeException("Connection from " + remoteAddress + " closed"));
+ failOutstandingRequests(new IOException("Connection from " + remoteAddress + " closed"));
}
}
diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java b/network/common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java
index 4cb8becc3ed22..91d1e8a538a77 100644
--- a/network/common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java
+++ b/network/common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java
@@ -66,7 +66,7 @@ public void encode(ChannelHandlerContext ctx, Message in, List<Object> out)